Skip to content

Commit 9191ecd

Browse files
committed
Fix some failing tests
* vload/vstore updates that were missed previously * int1 -> bool updates * fix gpu target tests Fixes a test and updates comments referencing old load/store api Change-Id: I26a0c480d2dedee442ca0116909a7751d1dfa9ac
1 parent ae89b1e commit 9191ecd

File tree

4 files changed

+44
-8
lines changed

4 files changed

+44
-8
lines changed

src/script/printer/tir/buffer.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
275275
ExprDoc buffer = d->AsDoc<ExprDoc>(store->buffer, p->Attr("buffer"));
276276
ExprDoc value = d->AsDoc<ExprDoc>(store->value, p->Attr("value"));
277277

278-
// Use .store(...) syntax when there is a predicate
278+
// Use .vstore(...) syntax when there is a predicate
279279
if (store->predicate.defined()) {
280280
ExprDoc indices = d->AsDoc<ExprDoc>(store->indices, p->Attr("indices"));
281281
ExprDoc predicate = d->AsDoc<ExprDoc>(store->predicate, p->Attr("predicate"));
@@ -293,7 +293,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
293293
"", [](tir::BufferLoad load, ObjectPath p, IRDocsifier d) -> Doc {
294294
ExprDoc buffer = d->AsDoc<ExprDoc>(load->buffer, p->Attr("buffer"));
295295

296-
// Use .load(...) syntax when there is a predicate
296+
// Use .vload(...) syntax when there is a predicate
297297
if (load->predicate.defined()) {
298298
ExprDoc indices = d->AsDoc<ExprDoc>(load->indices, p->Attr("indices"));
299299
ExprDoc predicate = d->AsDoc<ExprDoc>(load->predicate, p->Attr("predicate"));

tests/python/codegen/test_target_codegen.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from tvm.script import tir as T
2222

2323

24-
@tvm.testing.exclude_targets("llvm")
24+
@tvm.testing.parametrize_targets("c")
2525
def test_buffer_store_predicate_not_supported(target):
2626
@T.prim_func
2727
def func(b: T.handle):
@@ -34,7 +34,25 @@ def func(b: T.handle):
3434
tvm.build(func)
3535

3636

37-
@tvm.testing.exclude_targets("llvm")
37+
@tvm.testing.parametrize_targets("cuda", "opencl", "metal", "rocm", "vulkan -from_device=0")
38+
def test_buffer_store_predicate_not_supported_gpu(target):
39+
@T.prim_func
40+
def func(a: T.handle, b: T.handle):
41+
A = T.match_buffer(a, (2, 3), "float32")
42+
B = T.match_buffer(b, (6,), "float32")
43+
T.func_attr({"global_symbol": "main"})
44+
for i_0 in T.thread_binding(3, thread="threadIdx.x"):
45+
B.vstore(
46+
[T.Ramp(i_0, 1, 4)], T.Broadcast(1.0, 4), predicate=T.Broadcast(T.bool(True), 4)
47+
)
48+
49+
err_msg = "Predicated buffer store is not supported."
50+
with pytest.raises(tvm.TVMError, match=err_msg):
51+
with tvm.target.Target(target):
52+
tvm.build(func)
53+
54+
55+
@tvm.testing.parametrize_targets("c")
3856
def test_buffer_load_predicate_not_supported(target):
3957
@T.prim_func
4058
def func(a: T.handle, b: T.handle):
@@ -52,5 +70,23 @@ def func(a: T.handle, b: T.handle):
5270
tvm.build(func)
5371

5472

73+
@tvm.testing.parametrize_targets("cuda", "opencl", "metal", "rocm", "vulkan -from_device=0")
74+
def test_buffer_load_predicate_not_supported_gpu(target):
75+
@T.prim_func
76+
def func(a: T.handle, b: T.handle):
77+
A = T.match_buffer(a, (8,), "float32")
78+
B = T.match_buffer(b, (8,), "float32")
79+
for i_0 in T.thread_binding(3, thread="threadIdx.x"):
80+
B.vstore(
81+
[T.Ramp(0, 2, 4)],
82+
A.vload([T.Ramp(i_0, 1, 4)], predicate=T.Broadcast(T.bool(True), 4)),
83+
)
84+
85+
err_msg = "Predicated buffer load is not supported."
86+
with pytest.raises(tvm.TVMError, match=err_msg):
87+
with tvm.target.Target(target):
88+
tvm.build(func)
89+
90+
5591
if __name__ == "__main__":
5692
tvm.testing.main()

tests/python/tir-transform/test_tir_transform_vectorize.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,10 +186,10 @@ def main(a: T.handle, n: T.int32, x: T.int32):
186186
T.float32(1), extent
187187
)
188188
else:
189-
A.store(
190-
T.Broadcast(T.float32(2), T.vscale() * 4),
189+
A.vstore(
191190
[T.Ramp(0, 1, T.vscale() * 4)],
192-
predicate=T.get_active_lane_mask("int1xvscalex4", 0, n),
191+
T.Broadcast(T.float32(2), T.vscale() * 4),
192+
predicate=T.get_active_lane_mask("uint1xvscalex4", 0, n),
193193
)
194194

195195
with tvm.target.Target(target):

tests/python/tvmscript/test_tvmscript_ir_builder_tir.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ def test_ir_builder_tir_buffer_store_predicate():
472472
buffer_a = T.Buffer((30,), "float32")
473473
value = T.broadcast(0.11, T.vscale() * 4)
474474
index = T.ramp(0, 1, T.vscale() * 4)
475-
predicate = T.broadcast(1, T.vscale() * 4)
475+
predicate = T.broadcast(T.bool(True), T.vscale() * 4)
476476

477477
with IRBuilder() as ib:
478478
T.buffer_store(buffer_a, value, [index], predicate)

0 commit comments

Comments
 (0)