@@ -611,6 +611,60 @@ def elementwise_overcomputed_producer_reverse_inlined(
611611 C [vi , vj ] = A [vi , vj ] * 2.0 + 1.0
612612
613613
614+ @T .prim_func
615+ def elementwise_overcomputed_producer_simplify_predicate (
616+ A : T .Buffer ((128 , 128 ), "float32" ), C : T .Buffer ((127 , 127 ), "float32" )
617+ ) -> None :
618+ B = T .alloc_buffer ((128 , 128 ))
619+ for i in T .grid (16384 ):
620+ with T .block ("B" ):
621+ vi = T .axis .spatial (128 , i // 128 )
622+ vj = T .axis .spatial (128 , i % 128 )
623+ B [vi , vj ] = A [vi , vj ] * 2.0
624+ for i , j in T .grid (127 , 127 ):
625+ with T .block ("C" ):
626+ cvi , cvj = T .axis .remap ("SS" , [i , j ])
627+ C [cvi , cvj ] = B [cvi , cvj ] + 1.0
628+
629+
630+ @T .prim_func
631+ def elementwise_overcomputed_producer_simplify_predicate_reverse_inlined (
632+ A : T .Buffer ((128 , 128 ), "float32" ), C : T .Buffer ((127 , 127 ), "float32" )
633+ ) -> None :
634+ for i in T .grid (16384 ):
635+ with T .block ("B" ):
636+ vi = T .axis .spatial (128 , i // 128 )
637+ vj = T .axis .spatial (128 , i % 128 )
638+ T .where (i < 16255 and i % 128 < 127 )
639+ C [vi , vj ] = A [vi , vj ] * 2.0 + 1.0
640+
641+
642+ @T .prim_func
643+ def elementwise_overcomputed_producer_injective_load (
644+ A : T .Buffer ((128 , 128 ), "float32" ), C : T .Buffer ((127 , 127 ), "float32" )
645+ ) -> None :
646+ B = T .alloc_buffer ((8 , 8 , 16 , 16 ))
647+ for i0 , j0 , i1 , j1 in T .grid (8 , 8 , 16 , 16 ):
648+ with T .block ("B" ):
649+ vi , vj , vm , vn = T .axis .remap ("SSSS" , [i0 , j0 , i1 , j1 ])
650+ B [vi , vj , vm , vn ] = A [vi * 16 + vm , vj * 16 + vn ] * 2.0
651+ for i , j in T .grid (127 , 127 ):
652+ with T .block ("C" ):
653+ cvi , cvj = T .axis .remap ("SS" , [i , j ])
654+ C [cvi , cvj ] = B [cvi // 16 , cvj // 16 , cvi % 16 , cvj % 16 ] + 1.0
655+
656+
657+ @T .prim_func
658+ def elementwise_overcomputed_producer_injective_load_reverse_inlined (
659+ A : T .Buffer ((128 , 128 ), "float32" ), C : T .Buffer ((127 , 127 ), "float32" )
660+ ) -> None :
661+ for i0 , j0 , i1 , j1 in T .grid (8 , 8 , 16 , 16 ):
662+ with T .block ("B" ):
663+ vi , vj , vm , vn = T .axis .remap ("SSSS" , [i0 , j0 , i1 , j1 ])
664+ T .where (i0 * 16 + i1 < 127 and j0 * 16 + j1 < 127 )
665+ C [vm + vi * 16 , vn + vj * 16 ] = A [vi * 16 + vm , vj * 16 + vn ] * 2.0 + 1.0
666+
667+
614668@T .prim_func
615669def elementwise_producer_not_cover_consumer (
616670 A : T .Buffer ((128 , 128 ), "float32" ), D : T .Buffer ((256 , 128 ), "float32" )
@@ -1025,6 +1079,26 @@ def test_reverse_compute_inline_overcomputed_producer(use_block_name):
10251079 )
10261080
10271081
1082+ def test_reverse_compute_inline_overcomputed_producer_simplify_predicate (use_block_name ):
1083+ """Test reverse compute inline overcomputed producer where the predicate should be simplified"""
1084+ sch = tir .Schedule (elementwise_overcomputed_producer_simplify_predicate , debug_mask = "all" )
1085+ compute = "C" if use_block_name else sch .get_block ("C" )
1086+ sch .reverse_compute_inline (compute )
1087+ tvm .ir .assert_structural_equal (
1088+ elementwise_overcomputed_producer_simplify_predicate_reverse_inlined , sch .mod ["main" ]
1089+ )
1090+
1091+
1092+ def test_reverse_compute_inline_overcomputed_producer_injective_load (use_block_name ):
1093+ """Test reverse compute inline overcomputed producer with injective buffer load"""
1094+ sch = tir .Schedule (elementwise_overcomputed_producer_injective_load , debug_mask = "all" )
1095+ compute = "C" if use_block_name else sch .get_block ("C" )
1096+ sch .reverse_compute_inline (compute )
1097+ tvm .ir .assert_structural_equal (
1098+ elementwise_overcomputed_producer_injective_load_reverse_inlined , sch .mod ["main" ]
1099+ )
1100+
1101+
10281102def test_reverse_compute_inline_error_producer_not_cover_consumer (use_block_name ):
10291103 """Test reverse compute inline failure when the inlined block iter domains are not covered by
10301104 its producer
0 commit comments