Skip to content

Commit 6d77ec9

Browse files
committed
lint fixes
1 parent 7488ea1 commit 6d77ec9

File tree

1 file changed

+9
-27
lines changed

1 file changed

+9
-27
lines changed

tests/python/relax/test_dataflow_rewriter.py

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -257,12 +257,8 @@ def before(x: R.Tensor([16], "float32")):
257257

258258
@R.function(private=True)
259259
def expected(x: R.Tensor([16], "float32")):
260-
y = R.call_pure_packed(
261-
"my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32")
262-
)
263-
z = R.call_pure_packed(
264-
"my_optimized_add_impl", y, y, sinfo_args=R.Tensor([16], "float32")
265-
)
260+
y = R.call_pure_packed("my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32"))
261+
z = R.call_pure_packed("my_optimized_add_impl", y, y, sinfo_args=R.Tensor([16], "float32"))
266262
return z
267263

268264
after = Rewriter(before)
@@ -316,12 +312,8 @@ def expected(
316312
B: R.Tensor([16], "float32"),
317313
C: R.Tensor([16], "float32"),
318314
):
319-
D = R.call_pure_packed(
320-
"my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")
321-
)
322-
E = R.call_pure_packed(
323-
"my_optimized_mul_impl", C, D, sinfo_args=R.Tensor([16], "float32")
324-
)
315+
D = R.call_pure_packed("my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32"))
316+
E = R.call_pure_packed("my_optimized_mul_impl", C, D, sinfo_args=R.Tensor([16], "float32"))
325317
return E
326318

327319
rewriter = RewriteAdd | RewriteMultiply
@@ -457,9 +449,7 @@ def pattern(A: R.Tensor([16], "float32")):
457449

458450
@R.function
459451
def replacement(A: R.Tensor([16], "float32")):
460-
return R.call_tir(
461-
RewriteMul.subroutine_mul, [A], out_sinfo=R.Tensor([16], "float32")
462-
)
452+
return R.call_tir(RewriteMul.subroutine_mul, [A], out_sinfo=R.Tensor([16], "float32"))
463453

464454
@T.prim_func(private=True)
465455
def subroutine_mul(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
@@ -537,9 +527,7 @@ def pattern(A: R.Tensor([16], "float32")):
537527

538528
@R.function
539529
def replacement(A: R.Tensor([16], "float32")):
540-
return R.call_tir(
541-
RewriteMul.subroutine, [A], out_sinfo=R.Tensor([16], "float32")
542-
)
530+
return R.call_tir(RewriteMul.subroutine, [A], out_sinfo=R.Tensor([16], "float32"))
543531

544532
@T.prim_func(private=True)
545533
def subroutine(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
@@ -559,9 +547,7 @@ class Expected:
559547
@R.function
560548
def main(A: R.Tensor([16], "float32")):
561549
B = Expected.subroutine(A)
562-
C = R.call_tir(
563-
Expected.subroutine_1, [B], out_sinfo=R.Tensor([16], "float32")
564-
)
550+
C = R.call_tir(Expected.subroutine_1, [B], out_sinfo=R.Tensor([16], "float32"))
565551
return C
566552

567553
@R.function(private=True)
@@ -1212,9 +1198,7 @@ def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
12121198
)
12131199

12141200
@R.function(private=True)
1215-
def before(
1216-
A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")
1217-
):
1201+
def before(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")):
12181202
if cond:
12191203
out = A + B
12201204
else:
@@ -1223,9 +1207,7 @@ def before(
12231207
return out
12241208

12251209
@R.function(private=True)
1226-
def expected(
1227-
A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")
1228-
):
1210+
def expected(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")):
12291211
if cond:
12301212
out = R.call_pure_packed(
12311213
"my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")

0 commit comments

Comments
 (0)