# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring
import sys

import pytest

import tvm
from tvm import tir
from tvm.script import tir as T
from tvm.tir.schedule.testing import verify_trace_roundtrip


# fmt: off
# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks,not-callable

@T.prim_func
def cuda_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:  # pylint: disable=undefined-loop-variable
    """Baseline 2048x2048 float32 matmul, fully thread-bound for CUDA.

    C is read-modify-written directly; the read_at/write_at tests below
    transform this function and compare against the *_read_at_* /
    *_write_at_* fixtures with assert_structural_equal.
    """
    A = T.match_buffer(a, [2048, 2048], "float32")
    B = T.match_buffer(b, [2048, 2048], "float32")
    C = T.match_buffer(c, [2048, 2048], "float32")
    # 32x32 grid of blocks, each owning a 64x64 tile of C; 2x2 vthreads and
    # 8x8 threads split that tile so every thread computes a 4x4 sub-tile.
    for by in T.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in T.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in T.thread_binding(0, 2, thread="vthread.y"):
                for vx in T.thread_binding(0, 2, thread="vthread.x"):
                    for ty in T.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in T.thread_binding(0, 8, thread="threadIdx.x"):
                            # Reduction axis k = k0 * 8 + k1: 256 serial outer
                            # steps, 8 unrolled inner steps.
                            for k0 in T.serial(0, 256):
                                for k1 in T.unroll(0, 8):
                                    for _, i, j in T.grid(1, 4, 4):  # leading unit loop kept for the schedule's loop structure
                                        with T.block("C"):
                                            vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i)
                                            vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
                                            vk = T.axis.R(2048, k0 * 8 + k1)
                                            T.reads([C[vi, vj], A[vi, vk], B[vk, vj]])
                                            T.writes([C[vi, vj]])
                                            with T.init():
                                                C[vi, vj] = 0.0
                                            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


@T.prim_func
def cuda_matmul_read_at_a(a: T.handle, b: T.handle, c: T.handle) -> None:
    """Expected IR after `read_at(k0, C, 1, "shared")` on `cuda_matmul`:
    the A operand (read index 1) is staged tile-by-tile into shared memory.
    """
    A = T.match_buffer(a, [2048, 2048], dtype="float32")
    B = T.match_buffer(b, [2048, 2048], dtype="float32")
    C = T.match_buffer(c, [2048, 2048], dtype="float32")
    A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared")
    for by in T.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in T.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in T.thread_binding(0, 2, thread="vthread.y"):
                for vx in T.thread_binding(0, 2, thread="vthread.x"):
                    for ty in T.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in T.thread_binding(0, 8, thread="threadIdx.x"):
                            for k0 in T.serial(0, 256):
                                # Copy the 64x8 tile of A needed by this
                                # (by, k0) iteration into shared memory.
                                # "auto_copy" tags the block as a copy to be
                                # lowered by a later pass (not visible here).
                                with T.block("A_shared"):
                                    v0 = T.axis.S(32, by)
                                    v1 = T.axis.S(256, k0)
                                    T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]])
                                    T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]])
                                    T.block_attr({"auto_copy": 1})
                                    for ax0, ax1 in T.grid(64, 8):
                                        A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1]
                                for k1 in T.unroll(0, 8):
                                    for v_, i, j in T.grid(1, 4, 4):
                                        with T.block("C"):
                                            vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i)
                                            vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
                                            vk = T.axis.R(2048, k0 * 8 + k1)
                                            # A is now read through A_shared.
                                            T.reads([C[vi, vj], A_shared[vi, vk], B[vk, vj]])
                                            T.writes([C[vi, vj]])
                                            with T.init():
                                                C[vi, vj] = T.float32(0)
                                            C[vi, vj] = C[vi, vj] + A_shared[vi, vk] * B[vk, vj]


@T.prim_func
def cuda_matmul_read_at_ab(a: T.handle, b: T.handle, c: T.handle) -> None:
    """Expected IR after a further `read_at(k0, C, 2, "shared")` on
    `cuda_matmul_read_at_a`: both A and B operands are staged into shared.
    """
    A = T.match_buffer(a, [2048, 2048], dtype="float32")
    B = T.match_buffer(b, [2048, 2048], dtype="float32")
    C = T.match_buffer(c, [2048, 2048], dtype="float32")
    A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared")
    B_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared")
    for by in T.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in T.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in T.thread_binding(0, 2, thread="vthread.y"):
                for vx in T.thread_binding(0, 2, thread="vthread.x"):
                    for ty in T.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in T.thread_binding(0, 8, thread="threadIdx.x"):
                            for k0 in T.serial(0, 256):
                                # 64x8 tile of A for (by, k0) -> shared.
                                with T.block("A_shared"):
                                    v0 = T.axis.S(32, by)
                                    v1 = T.axis.S(256, k0)
                                    T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]])
                                    T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]])
                                    T.block_attr({"auto_copy": 1})
                                    for ax0, ax1 in T.grid(64, 8):
                                        A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1]
                                # 8x64 tile of B for (k0, bx) -> shared.
                                with T.block("B_shared"):
                                    v0 = T.axis.S(256, k0)
                                    v1 = T.axis.S(32, bx)
                                    T.reads([B[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]])
                                    T.writes([B_shared[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]])
                                    T.block_attr({"auto_copy": 1})
                                    for ax0, ax1 in T.grid(8, 64):
                                        B_shared[v0 * 8 + ax0, v1 * 64 + ax1] = B[v0 * 8 + ax0, v1 * 64 + ax1]
                                for k1 in T.unroll(0, 8):
                                    for v_, i, j in T.grid(1, 4, 4):
                                        with T.block("C"):
                                            vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i)
                                            vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
                                            vk = T.axis.R(2048, k0 * 8 + k1)
                                            # Both operands now come from shared.
                                            T.reads([C[vi, vj], A_shared[vi, vk], B_shared[vk, vj]])
                                            T.writes([C[vi, vj]])
                                            with T.init():
                                                C[vi, vj] = T.float32(0)
                                            C[vi, vj] = C[vi, vj] + A_shared[vi, vk] * B_shared[vk, vj]

@T.prim_func
def cuda_matmul_write_at_c(a: T.handle, b: T.handle, c: T.handle) -> None:
    """Expected IR after `write_at(tx, C, 0, "shared")` on
    `cuda_matmul_read_at_ab`: C is accumulated in a shared-memory buffer and
    copied out to global memory once per (by, bx) block.
    """
    A = T.match_buffer(a, [2048, 2048], dtype="float32")
    B = T.match_buffer(b, [2048, 2048], dtype="float32")
    C = T.match_buffer(c, [2048, 2048], dtype="float32")
    A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared")
    B_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared")
    C_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared")
    for by in T.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in T.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in T.thread_binding(0, 2, thread="vthread.y"):
                for vx in T.thread_binding(0, 2, thread="vthread.x"):
                    for ty in T.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in T.thread_binding(0, 8, thread="threadIdx.x"):
                            for k0 in T.serial(0, 256):
                                # 64x8 tile of A for (by, k0) -> shared.
                                with T.block("A_shared"):
                                    v0 = T.axis.S(32, by)
                                    v1 = T.axis.S(256, k0)
                                    T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]])
                                    T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]])
                                    T.block_attr({"auto_copy": 1})
                                    for ax0, ax1 in T.grid(64, 8):
                                        A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1]
                                # 8x64 tile of B for (k0, bx) -> shared.
                                with T.block("B_shared"):
                                    v0 = T.axis.S(256, k0)
                                    v1 = T.axis.S(32, bx)
                                    T.reads([B[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]])
                                    T.writes([B_shared[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]])
                                    T.block_attr({"auto_copy": 1})
                                    for ax0, ax1 in T.grid(8, 64):
                                        B_shared[v0 * 8 + ax0, v1 * 64 + ax1] = B[v0 * 8 + ax0, v1 * 64 + ax1]
                                for k1 in T.unroll(0, 8):
                                    for v_, i, j in T.grid(1, 4, 4):
                                        with T.block("C"):
                                            vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i)
                                            vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
                                            vk = T.axis.R(2048, k0 * 8 + k1)
                                            # Accumulation target is now C_shared.
                                            T.reads([C_shared[vi, vj], A_shared[vi, vk], B_shared[vk, vj]])
                                            T.writes([C_shared[vi, vj]])
                                            with T.init():
                                                C_shared[vi, vj] = T.float32(0)
                                            C_shared[vi, vj] = C_shared[vi, vj] + A_shared[vi, vk] * B_shared[vk, vj]
                            # Write the finished 64x64 tile back to global C.
                            with T.block("C_shared"):
                                v0 = T.axis.S(32, by)
                                v1 = T.axis.S(32, bx)
                                T.reads([C_shared[v0 * 64 : v0 * 64 + 64, v1 * 64 : v1 * 64 + 64]])
                                T.writes([C[v0 * 64 : v0 * 64 + 64, v1 * 64 : v1 * 64 + 64]])
                                T.block_attr({"auto_copy": 1})
                                for ax0, ax1 in T.grid(64, 64):
                                    C[v0 * 64 + ax0, v1 * 64 + ax1] = C_shared[v0 * 64 + ax0, v1 * 64 + ax1]


# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks,not-callable
# fmt: on


def test_read_at_global_to_shared_a():
    """read_at on loop k0 (read index 1 = A) must yield cuda_matmul_read_at_a."""
    sch = tir.Schedule(cuda_matmul, debug_mask="all")
    block_c = sch.get_block("C")
    # Loop order: by, bx, vy, vx, ty, tx, k0, k1, <unit>, i, j.
    k0 = sch.get_loops(block_c)[6]  # pylint: disable=invalid-name
    sch.read_at(k0, block_c, 1, "shared")
    tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_read_at_a)
    verify_trace_roundtrip(sch, cuda_matmul)


def test_read_at_global_to_shared_ab():
    """A second read_at on k0 (read index 2 = B) must yield cuda_matmul_read_at_ab."""
    sch = tir.Schedule(cuda_matmul_read_at_a, debug_mask="all")
    block_c = sch.get_block("C")
    # Loop order: by, bx, vy, vx, ty, tx, k0, k1, <unit>, i, j.
    k0 = sch.get_loops(block_c)[6]  # pylint: disable=invalid-name
    sch.read_at(k0, block_c, 2, "shared")
    tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_read_at_ab)
    verify_trace_roundtrip(sch, cuda_matmul_read_at_a)


def test_read_at_local_to_shared_c():
    """write_at on loop tx (write index 0 = C) must yield cuda_matmul_write_at_c."""
    # NOTE(review): despite the "read_at" in its name, this test exercises
    # `write_at`; a test_write_at_* name would be clearer. Kept as-is so that
    # test-selection filters relying on the current name keep working.
    sch = tir.Schedule(cuda_matmul_read_at_ab, debug_mask="all")
    block_c = sch.get_block("C")
    # Loop order: by, bx, vy, vx, ty, tx, k0, k1, <unit>, i, j.
    tx = sch.get_loops(block_c)[5]  # pylint: disable=invalid-name
    sch.write_at(tx, block_c, 0, "shared")
    tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_write_at_c)
    verify_trace_roundtrip(sch, cuda_matmul_read_at_ab)


if __name__ == "__main__":
    # Forward extra CLI arguments to pytest (e.g. -k/-v) and propagate its
    # exit status to the shell. (Also drops trailing web-UI paste residue
    # that made the file unparseable.)
    sys.exit(pytest.main([__file__] + sys.argv[1:]))