@@ -257,12 +257,8 @@ def before(x: R.Tensor([16], "float32")):
257257
258258 @R .function (private = True )
259259 def expected (x : R .Tensor ([16 ], "float32" )):
260- y = R .call_pure_packed (
261- "my_optimized_add_impl" , x , x , sinfo_args = R .Tensor ([16 ], "float32" )
262- )
263- z = R .call_pure_packed (
264- "my_optimized_add_impl" , y , y , sinfo_args = R .Tensor ([16 ], "float32" )
265- )
260+ y = R .call_pure_packed ("my_optimized_add_impl" , x , x , sinfo_args = R .Tensor ([16 ], "float32" ))
261+ z = R .call_pure_packed ("my_optimized_add_impl" , y , y , sinfo_args = R .Tensor ([16 ], "float32" ))
266262 return z
267263
268264 after = Rewriter (before )
@@ -316,12 +312,8 @@ def expected(
316312 B : R .Tensor ([16 ], "float32" ),
317313 C : R .Tensor ([16 ], "float32" ),
318314 ):
319- D = R .call_pure_packed (
320- "my_optimized_add_impl" , A , B , sinfo_args = R .Tensor ([16 ], "float32" )
321- )
322- E = R .call_pure_packed (
323- "my_optimized_mul_impl" , C , D , sinfo_args = R .Tensor ([16 ], "float32" )
324- )
315+ D = R .call_pure_packed ("my_optimized_add_impl" , A , B , sinfo_args = R .Tensor ([16 ], "float32" ))
316+ E = R .call_pure_packed ("my_optimized_mul_impl" , C , D , sinfo_args = R .Tensor ([16 ], "float32" ))
325317 return E
326318
327319 rewriter = RewriteAdd | RewriteMultiply
@@ -457,9 +449,7 @@ def pattern(A: R.Tensor([16], "float32")):
457449
458450 @R .function
459451 def replacement (A : R .Tensor ([16 ], "float32" )):
460- return R .call_tir (
461- RewriteMul .subroutine_mul , [A ], out_sinfo = R .Tensor ([16 ], "float32" )
462- )
452+ return R .call_tir (RewriteMul .subroutine_mul , [A ], out_sinfo = R .Tensor ([16 ], "float32" ))
463453
464454 @T .prim_func (private = True )
465455 def subroutine_mul (A : T .Buffer (16 , "float32" ), B : T .Buffer (16 , "float32" )):
@@ -537,9 +527,7 @@ def pattern(A: R.Tensor([16], "float32")):
537527
538528 @R .function
539529 def replacement (A : R .Tensor ([16 ], "float32" )):
540- return R .call_tir (
541- RewriteMul .subroutine , [A ], out_sinfo = R .Tensor ([16 ], "float32" )
542- )
530+ return R .call_tir (RewriteMul .subroutine , [A ], out_sinfo = R .Tensor ([16 ], "float32" ))
543531
544532 @T .prim_func (private = True )
545533 def subroutine (A : T .Buffer (16 , "float32" ), B : T .Buffer (16 , "float32" )):
@@ -559,9 +547,7 @@ class Expected:
559547 @R .function
560548 def main (A : R .Tensor ([16 ], "float32" )):
561549 B = Expected .subroutine (A )
562- C = R .call_tir (
563- Expected .subroutine_1 , [B ], out_sinfo = R .Tensor ([16 ], "float32" )
564- )
550+ C = R .call_tir (Expected .subroutine_1 , [B ], out_sinfo = R .Tensor ([16 ], "float32" ))
565551 return C
566552
567553 @R .function (private = True )
@@ -1212,9 +1198,7 @@ def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
12121198 )
12131199
12141200 @R .function (private = True )
1215- def before (
1216- A : R .Tensor ([16 ], "float32" ), B : R .Tensor ([16 ], "float32" ), cond : R .Prim ("bool" )
1217- ):
1201+ def before (A : R .Tensor ([16 ], "float32" ), B : R .Tensor ([16 ], "float32" ), cond : R .Prim ("bool" )):
12181202 if cond :
12191203 out = A + B
12201204 else :
@@ -1223,9 +1207,7 @@ def before(
12231207 return out
12241208
12251209 @R .function (private = True )
1226- def expected (
1227- A : R .Tensor ([16 ], "float32" ), B : R .Tensor ([16 ], "float32" ), cond : R .Prim ("bool" )
1228- ):
1210+ def expected (A : R .Tensor ([16 ], "float32" ), B : R .Tensor ([16 ], "float32" ), cond : R .Prim ("bool" )):
12291211 if cond :
12301212 out = R .call_pure_packed (
12311213 "my_optimized_add_impl" , A , B , sinfo_args = R .Tensor ([16 ], "float32" )
0 commit comments