@@ -265,5 +265,86 @@ def constant_binds_wrapped():
265265 assert_structural_equal (constant_binds , constant_binds_wrapped )
266266
267267
def test_func_call():
    """Check that a plain Python helper can be called from inside a prim_func.

    Two prim_funcs are defined: one computes its buffer indices by calling a
    regular Python function, the other spells the same index arithmetic out by
    hand. The test passes when both are structurally equal.
    """

    def shared_16x16_to_ldmatrix_32x8_layout(i, j):
        """Return the ``(thread_id, local_id)`` pair computed from ``(i, j)``."""
        thread_id = (i % 8) * 4 + (j % 8) // 2
        local_id = (j // 8) * 4 + (i // 8) * 2 + (j % 2)
        return thread_id, local_id

    # Variant that delegates the index computation to the Python helper above.
    @T.prim_func
    def mma_sync_m16n16k16_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
        B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
        C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")

        with T.block("root"):
            T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
            T.writes(C[0:32, 0:8])
            for i, j, k in T.grid(16, 16, 16):
                with T.block("C"):
                    i, j, k = T.axis.remap("SSR", [i, j, k])
                    thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j)
                    thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k)
                    thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(k, j)

                    T.reads(
                        C[thread_id_C, local_id_C],
                        A[thread_id_A, local_id_A],
                        B[thread_id_B, local_id_B],
                    )
                    T.writes(C[thread_id_C, local_id_C])

                    C[thread_id_C, local_id_C] += (
                        A[thread_id_A, local_id_A] * B[thread_id_B, local_id_B]
                    )

    # Reference variant with every index expression written out manually.
    @T.prim_func
    def mma_sync_m16n16k16_desc_manual(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
        B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
        C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")

        with T.block("root"):
            T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
            T.writes(C[0:32, 0:8])
            for i, j, k in T.grid(16, 16, 16):
                with T.block("C"):
                    i, j, k = T.axis.remap("SSR", [i, j, k])
                    T.reads(
                        C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2],
                        A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2],
                        B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2],
                    )
                    T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
                    C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = (
                        C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2]
                        + A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2]
                        * B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2]
                    )

    assert_structural_equal(mma_sync_m16n16k16_desc, mma_sync_m16n16k16_desc_manual)

    # The following is an example of an error message from calling an invalid function

    # error: Error occured when invoking the function sqrt:
    #        loop of ufunc does not support argument 0 of type Var which has no callable sqrt method
    #  --> test_tvmscript_syntax_sugar.py:334:19
    #      |
    #  334 |             ind = sqrt(i)
    #      |                   ^^^^^^^
    # note: run with `TVM_BACKTRACE=1` environment variable to display a backtrace.

    # Uncomment to see the error above.
    # def sqrt(x):
    #     import numpy as np
    #     return np.sqrt(x)

    # @T.prim_func
    # def loop(a: T.handle) -> None:
    #     A = T.match_buffer(a, (128,))
    #     for i in T.serial(128):
    #         ind = sqrt(i)
    #         A[i] = A[ind]
347+
348+
if __name__ == "__main__":
    # Run this file's tests under pytest, forwarding any extra command-line
    # arguments, and use pytest's result as the process exit status.
    pytest_args = [__file__, *sys.argv[1:]]
    sys.exit(pytest.main(pytest_args))
0 commit comments