Skip to content

Commit 7062789

Browse files
committed
Removing more usage of preflattened from python files
1 parent f1579c7 commit 7062789

File tree

4 files changed

+10
-124
lines changed

4 files changed

+10
-124
lines changed

tests/python/unittest/test_aot_legalize_packed_call.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,12 @@
2626
class Module:
2727
@T.prim_func
2828
def tvm_test_cpacked(
29-
A: T.handle, B: T.handle, C: T.handle, device_context: T.handle
29+
A: T.Buffer[(1,), "float32"],
30+
B: T.Buffer[(1,), "float32"],
31+
C: T.Buffer[(1,), "float32"],
32+
device_context: T.Buffer[(1,), "float32"],
3033
) -> T.handle:
31-
A_0 = T.match_buffer(A, (1,), dtype="float32")
32-
A_0pre = T.preflattened_buffer(A_0, (1,), dtype="float32")
33-
B_0 = T.match_buffer(B, (1,), dtype="float32")
34-
B_0pre = T.preflattened_buffer(B_0, (1,), dtype="float32")
35-
C_0 = T.match_buffer(C, (1,), dtype="float32")
36-
C_0pre = T.preflattened_buffer(C_0, (1,), dtype="float32")
37-
T.evaluate(C)
34+
T.evaluate(C.data)
3835

3936
@T.prim_func
4037
def tir_packed_call() -> None:
@@ -59,15 +56,12 @@ def tir_packed_call() -> None:
5956
class Expected:
6057
@T.prim_func
6158
def tvm_test_cpacked(
62-
A: T.handle, B: T.handle, C: T.handle, device_context: T.handle
59+
A: T.Buffer[(1,), "float32"],
60+
B: T.Buffer[(1,), "float32"],
61+
C: T.Buffer[(1,), "float32"],
62+
device_context: T.handle,
6363
) -> T.handle:
64-
A_0 = T.match_buffer(A, (1,), dtype="float32")
65-
A_0pre = T.preflattened_buffer(A_0, (1,), dtype="float32")
66-
B_0 = T.match_buffer(B, (1,), dtype="float32")
67-
B_0pre = T.preflattened_buffer(B_0, (1,), dtype="float32")
68-
C_0 = T.match_buffer(C, (1,), dtype="float32")
69-
C_0pre = T.preflattened_buffer(C_0, (1,), dtype="float32")
70-
T.evaluate(C)
64+
T.evaluate(C.data)
7165

7266
@T.prim_func
7367
def tir_packed_call() -> None:

0 commit comments

Comments
 (0)