@@ -472,5 +472,105 @@ def add_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"),
472472 )
473473
474474
def test_reshape():
    """Run ``_check`` for ``relax.reshape`` with a re-laid-out input buffer.

    ``Before`` contains a PrimFunc tagged ``operator_name = "relax.reshape"``
    that reads a ``(850, 2048)`` float16 buffer and writes a
    ``(850, 1, 2048)`` one.  ``Expected`` instead calls
    ``relax_reshape_replacement``, whose input is ``(850, 2, 1024)`` and is
    produced by an ``R.layout_transform`` using the index map
    ``(i, j) -> (i, j // 1024, j % 1024)``.

    NOTE(review): ``_check`` is defined elsewhere in this file; presumably it
    applies the transformation under test to ``Before`` and compares the
    result with ``Expected`` -- confirm against its definition.
    """
    # fmt: off
    # Replacement implementation handed to `_check`.  Unlike the PrimFuncs in
    # the two modules below, it carries no "operator_name" attribute.
    @T.prim_func(private=True)
    def reshape_new(
        A: T.Buffer((T.int64(850), T.int64(2), T.int64(1024)), "float16"),
        T_reshape: T.Buffer((T.int64(850), T.int64(1), T.int64(2048)), "float16"),
    ):
        for ax0, ax1, ax2 in T.grid(T.int64(850), T.int64(1), T.int64(2048)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(A[v_ax0, v_ax2 // T.int64(1024), v_ax2 % T.int64(1024)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = A[
                    v_ax0, v_ax2 // T.int64(1024), v_ax2 % T.int64(1024)
                ]

    # Input module: reshape reads straight from the 2-D (850, 2048) buffer.
    @I.ir_module
    class Before:
        @T.prim_func(private=True)
        def reshape(
            A: T.Buffer((T.int64(850), T.int64(2048)), "float16"),
            T_reshape: T.Buffer((T.int64(850), T.int64(1), T.int64(2048)), "float16"),
        ):
            T.func_attr({"operator_name": "relax.reshape"})
            for ax0, ax1, ax2 in T.grid(T.int64(850), T.int64(1), T.int64(2048)):
                with T.block("T_reshape"):
                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                    T.reads(
                        A[
                            (v_ax2 // T.int64(2048) + v_ax0 + v_ax1) % T.int64(850),
                            v_ax2 % T.int64(2048),
                        ]
                    )
                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                    T_reshape[v_ax0, v_ax1, v_ax2] = A[
                        (v_ax2 // T.int64(2048) + v_ax0 + v_ax1) % T.int64(850),
                        v_ax2 % T.int64(2048),
                    ]

        @R.function
        def main(
            x: R.Tensor((850, 2048), dtype="float16")
        ) -> R.Tensor((850, 1, 2048), dtype="float16"):
            cls = Before
            with R.dataflow():
                lv = R.call_tir(
                    cls.reshape, (x,), out_sinfo=R.Tensor((850, 1, 2048), dtype="float16")
                )
                gv: R.Tensor((850, 1, 2048), dtype="float16") = lv
                R.output(gv)
            return gv

    # Reference module: the input is first rewritten to (850, 2, 1024) via
    # R.layout_transform and then consumed by the replacement PrimFunc.
    @I.ir_module
    class Expected:
        @T.prim_func(private=True)
        def relax_reshape_replacement(
            A: T.Buffer((T.int64(850), T.int64(2), T.int64(1024)), "float16"),
            T_reshape: T.Buffer((T.int64(850), T.int64(1), T.int64(2048)), "float16"),
        ):
            T.func_attr({"operator_name": "relax.reshape"})
            for ax0, ax1, ax2 in T.grid(T.int64(850), T.int64(1), T.int64(2048)):
                with T.block("T_reshape"):
                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                    T.reads(A[v_ax0, v_ax2 // T.int64(1024), v_ax2 % T.int64(1024)])
                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                    T_reshape[v_ax0, v_ax1, v_ax2] = A[
                        v_ax0, v_ax2 // T.int64(1024), v_ax2 % T.int64(1024)
                    ]

        @R.function
        def main(
            x: R.Tensor((850, 2048), dtype="float16")
        ) -> R.Tensor((850, 1, 2048), dtype="float16"):
            cls = Expected
            with R.dataflow():
                lv: R.Tensor((850, 2, 1024), dtype="float16") = R.layout_transform(
                    x,
                    index_map=T.index_map(lambda i, j: (i, j // 1024, j % 1024)),
                    pad_value=None,
                    axis_separators=[],
                )
                lv_1 = R.call_tir(
                    cls.relax_reshape_replacement,
                    (lv,),
                    out_sinfo=R.Tensor((850, 1, 2048), dtype="float16"),
                )
                gv: R.Tensor((850, 1, 2048), dtype="float16") = lv_1
                R.output(gv)
            return gv

    # fmt: on
    # Same mapping as the T.index_map used inside Expected.main: the 2048-wide
    # axis is split into (2, 1024).
    split_last_axis = lambda i, j: (i, j // 1024, j % 1024)
    _check(
        Before,
        Expected,
        operator_name="relax.reshape",
        replacement_primfunc=reshape_new,
        # One entry per PrimFunc buffer (input, output); the output entry is
        # None -- presumably "keep the layout as is", verify against _check.
        layout_changes=[split_last_axis, None],
    )
573+
574+
# When this file is executed directly, delegate to tvm.testing.main().
if __name__ == "__main__":
    tvm.testing.main()
0 commit comments