@@ -391,63 +391,63 @@ def test_meta_schedule_te2primfunc_argument_order_and_lowering():
391391 class _fused_layout_transform :
392392 @T .prim_func
393393 def main ( # type: ignore
394- placeholder : T .Buffer [(1 , 3 , 16 , 16 ), "float32" ], # type: ignore
395- T_layout_trans : T .Buffer [(1 , 1 , 16 , 16 , 3 ), "float32" ], # type: ignore
394+ placeholder : T .Buffer [(T . int64 ( 1 ), T . int64 ( 3 ), T . int64 ( 16 ), T . int64 ( 16 ) ), "float32" ], # type: ignore
395+ T_layout_trans : T .Buffer [(T . int64 ( 1 ), T . int64 ( 1 ), T . int64 ( 16 ), T . int64 ( 16 ), T . int64 ( 3 ) ), "float32" ], # type: ignore
396396 ) -> None : # type: ignore
397397 # function attr dict
398398 T .func_attr ({"global_symbol" : "main" , "tir.noalias" : True })
399399 # body
400400 # with T.block("root")
401- for i0 , i1 , i2 , i3 , i4 in T .grid (1 , 1 , 16 , 16 , 3 ):
401+ for i0 , i1 , i2 , i3 , i4 in T .grid (T . int64 ( 1 ), T . int64 ( 1 ), T . int64 ( 16 ), T . int64 ( 16 ), T . int64 ( 3 ) ):
402402 with T .block ("T_layout_trans" ):
403403 ax0 , ax1 , ax2 , ax3 , ax4 = T .axis .remap ("SSSSS" , [i0 , i1 , i2 , i3 , i4 ])
404- T .reads (placeholder [ax0 , ax1 * 3 + ax4 , ax2 , ax3 ])
404+ T .reads (placeholder [ax0 , ax1 * T . int64 ( 3 ) + ax4 , ax2 , ax3 ])
405405 T .writes (T_layout_trans [ax0 , ax1 , ax2 , ax3 , ax4 ])
406406 T_layout_trans [ax0 , ax1 , ax2 , ax3 , ax4 ] = T .if_then_else (
407- ax0 < 1 and ax1 * 3 + ax4 < 3 and ax2 < 16 and ax3 < 16 , # type: ignore
408- placeholder [ax0 , ax1 * 3 + ax4 , ax2 , ax3 ],
407+ ax0 < T . int64 ( 1 ) and ax1 * T . int64 ( 3 ) + ax4 < T . int64 ( 3 ) and ax2 < T . int64 ( 16 ) and ax3 < T . int64 ( 16 ) , # type: ignore
408+ placeholder [ax0 , ax1 * T . int64 ( 3 ) + ax4 , ax2 , ax3 ],
409409 T .float32 (0 ),
410410 dtype = "float32" ,
411411 )
412412
413413 @tvm .script .ir_module
414414 class _fused_layout_transform_1 :
415415 @T .prim_func
416- def main (placeholder : T .Buffer [(1 , 2 , 16 , 16 , 4 ) , "float32" ], T_layout_trans : T .Buffer [(1 , 8 , 16 , 16 ), "float32" ]) -> None : # type: ignore
416+ def main (placeholder : T .Buffer [(T . int64 ( 1 ), T . int64 ( 2 ), T . int64 ( 16 ), T . int64 ( 16 ), T . int64 ( 4 )) , "float32" ], T_layout_trans : T .Buffer [(T . int64 ( 1 ), T . int64 ( 8 ), T . int64 ( 16 ), T . int64 ( 16 ) ), "float32" ]) -> None : # type: ignore
417417 # function attr dict
418418 T .func_attr ({"global_symbol" : "main" , "tir.noalias" : True })
419419 # body
420420 # with T.block("root")
421- for i0 , i1 , i2 , i3 in T .grid (1 , 8 , 16 , 16 ):
421+ for i0 , i1 , i2 , i3 in T .grid (T . int64 ( 1 ), T . int64 ( 8 ), T . int64 ( 16 ), T . int64 ( 16 ) ):
422422 with T .block ("T_layout_trans" ):
423423 ax0 , ax1 , ax2 , ax3 = T .axis .remap ("SSSS" , [i0 , i1 , i2 , i3 ])
424- T .reads (placeholder [ax0 , ax1 // 4 , ax2 , ax3 , ax1 % 4 ]) # type: ignore
424+ T .reads (placeholder [ax0 , ax1 // T . int64 ( 4 ) , ax2 , ax3 , ax1 % T . int64 ( 4 ) ]) # type: ignore
425425 T .writes (T_layout_trans [ax0 , ax1 , ax2 , ax3 ])
426- T_layout_trans [ax0 , ax1 , ax2 , ax3 ] = T .if_then_else (ax0 < 1 and ax1 < 8 and ax2 < 16 and ax3 < 16 , placeholder [ax0 , ax1 // 4 , ax2 , ax3 , ax1 % 4 ], T .float32 (0 ), dtype = "float32" ) # type: ignore
426+ T_layout_trans [ax0 , ax1 , ax2 , ax3 ] = T .if_then_else (ax0 < T . int64 ( 1 ) and ax1 < T . int64 ( 8 ) and ax2 < T . int64 ( 16 ) and ax3 < T . int64 ( 16 ) , placeholder [ax0 , ax1 // T . int64 ( 4 ) , ax2 , ax3 , ax1 % T . int64 ( 4 ) ], T .float32 (0 ), dtype = "float32" ) # type: ignore
427427
428428 @tvm .script .ir_module
429429 class _fused_nn_contrib_conv2d_NCHWc :
430430 @T .prim_func
431- def main (placeholder : T .Buffer [(1 , 1 , 16 , 16 , 3 ) , "float32" ], placeholder_1 : T .Buffer [(2 , 1 , 5 , 5 , 3 , 4 ) , "float32" ], conv2d_NCHWc : T .Buffer [(1 , 2 , 16 , 16 , 4 ), "float32" ]) -> None : # type: ignore
431+ def main (placeholder : T .Buffer [(T . int64 ( 1 ), T . int64 ( 1 ), T . int64 ( 16 ), T . int64 ( 16 ), T . int64 ( 3 )) , "float32" ], placeholder_1 : T .Buffer [(T . int64 ( 2 ), T . int64 ( 1 ), T . int64 ( 5 ), T . int64 ( 5 ), T . int64 ( 3 ), T . int64 ( 4 )) , "float32" ], conv2d_NCHWc : T .Buffer [(T . int64 ( 1 ), T . int64 ( 2 ), T . int64 ( 16 ), T . int64 ( 16 ), T . int64 ( 4 ) ), "float32" ]) -> None : # type: ignore
432432 # function attr dict
433433 T .func_attr ({"global_symbol" : "main" , "tir.noalias" : True })
434434 # body
435435 # with T.block("root")
436- data_pad = T .alloc_buffer ([1 , 1 , 20 , 20 , 3 ], dtype = "float32" )
437- for i0 , i1 , i2 , i3 , i4 in T .grid (1 , 1 , 20 , 20 , 3 ):
436+ data_pad = T .alloc_buffer ([T . int64 ( 1 ), T . int64 ( 1 ), T . int64 ( 20 ), T . int64 ( 20 ), T . int64 ( 3 ) ], dtype = "float32" )
437+ for i0 , i1 , i2 , i3 , i4 in T .grid (T . int64 ( 1 ), T . int64 ( 1 ), T . int64 ( 20 ), T . int64 ( 20 ), T . int64 ( 3 ) ):
438438 with T .block ("data_pad" ):
439439 i0_1 , i1_1 , i2_1 , i3_1 , i4_1 = T .axis .remap ("SSSSS" , [i0 , i1 , i2 , i3 , i4 ])
440- T .reads (placeholder [i0_1 , i1_1 , i2_1 - 2 , i3_1 - 2 , i4_1 ])
440+ T .reads (placeholder [i0_1 , i1_1 , i2_1 - T . int64 ( 2 ) , i3_1 - T . int64 ( 2 ) , i4_1 ])
441441 T .writes (data_pad [i0_1 , i1_1 , i2_1 , i3_1 , i4_1 ])
442- data_pad [i0_1 , i1_1 , i2_1 , i3_1 , i4_1 ] = T .if_then_else (2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18 , placeholder [i0_1 , i1_1 , i2_1 - 2 , i3_1 - 2 , i4_1 ], T .float32 (0 ), dtype = "float32" ) # type: ignore # pylint: disable=R1716
443- for i0 , i1 , i2 , i3 , i4 , i5 , i6 , i7 in T .grid (1 , 2 , 16 , 16 , 4 , 3 , 5 , 5 ):
442+ data_pad [i0_1 , i1_1 , i2_1 , i3_1 , i4_1 ] = T .if_then_else (T . int64 ( 2 ) <= i2_1 and i2_1 < T . int64 ( 18 ) and T . int64 ( 2 ) <= i3_1 and i3_1 < T . int64 ( 18 ) , placeholder [i0_1 , i1_1 , i2_1 - T . int64 ( 2 ) , i3_1 - T . int64 ( 2 ) , i4_1 ], T .float32 (0 ), dtype = "float32" ) # type: ignore # pylint: disable=R1716
443+ for i0 , i1 , i2 , i3 , i4 , i5 , i6 , i7 in T .grid (T . int64 ( 1 ), T . int64 ( 2 ), T . int64 ( 16 ), T . int64 ( 16 ), T . int64 ( 4 ), T . int64 ( 3 ), T . int64 ( 5 ), T . int64 ( 5 ) ):
444444 with T .block ("conv2d_NCHWc" ):
445445 n , oc_chunk , oh , ow , oc_block , ic , kh , kw = T .axis .remap ("SSSSSRRR" , [i0 , i1 , i2 , i3 , i4 , i5 , i6 , i7 ])
446- T .reads (data_pad [n , ic // 3 , oh + kh , ow + kw , ic % 3 ], placeholder_1 [oc_chunk , ic // 3 , kh , kw , ic % 3 , oc_block ]) # type: ignore
446+ T .reads (data_pad [n , ic // T . int64 ( 3 ) , oh + kh , ow + kw , ic % T . int64 ( 3 ) ], placeholder_1 [oc_chunk , ic // T . int64 ( 3 ) , kh , kw , ic % T . int64 ( 3 ) , oc_block ]) # type: ignore
447447 T .writes (conv2d_NCHWc [n , oc_chunk , oh , ow , oc_block ])
448448 with T .init ():
449449 conv2d_NCHWc [n , oc_chunk , oh , ow , oc_block ] = T .float32 (0 )
450- conv2d_NCHWc [n , oc_chunk , oh , ow , oc_block ] = conv2d_NCHWc [n , oc_chunk , oh , ow , oc_block ] + data_pad [n , ic // 3 , oh + kh , ow + kw , ic % 3 ] * placeholder_1 [oc_chunk , ic // 3 , kh , kw , ic % 3 , oc_block ] # type: ignore
450+ conv2d_NCHWc [n , oc_chunk , oh , ow , oc_block ] = conv2d_NCHWc [n , oc_chunk , oh , ow , oc_block ] + data_pad [n , ic // T . int64 ( 3 ) , oh + kh , ow + kw , ic % T . int64 ( 3 ) ] * placeholder_1 [oc_chunk , ic // T . int64 ( 3 ) , kh , kw , ic % T . int64 ( 3 ) , oc_block ] # type: ignore
451451
452452 # fmt: on
453453 # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
0 commit comments