Skip to content

Commit d9c2b9c

Browse files
sisleyliBin Li
andauthored
[TIR][BugFix]Ensure the Var's scope is correct (#15406)
* [TIR][BugFix]Ensure Var's scope is correct * add new testcases * fix lint --------- Co-authored-by: Bin Li <[email protected]>
1 parent e2c8d7b commit d9c2b9c

File tree

2 files changed

+81
-3
lines changed

2 files changed

+81
-3
lines changed

src/tir/transforms/unsupported_dtype_legalize.cc

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,11 @@ class ComputeLegalizePlanner : public StmtExprVisitor {
7979
// remap all intermediate constant buffer to promote data types (fp16/fp32)
8080
if (MatchDType(op->dtype) && op->ConstantAllocationSize() != 0) {
8181
DataType dtype = promote_dtype_.with_lanes(op->dtype.lanes());
82-
Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype)));
82+
String storage_scope = "global";
83+
if (auto* ptr_type = op->buffer_var->type_annotation.as<PointerTypeNode>()) {
84+
storage_scope = ptr_type->storage_scope;
85+
}
86+
Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype), storage_scope));
8387
(*var_remap_)[op->buffer_var] = buffer_var;
8488
}
8589
return StmtExprVisitor::VisitStmt_(op);
@@ -496,7 +500,11 @@ class StorageLegalizer : public StmtExprMutator {
496500
Stmt VisitStmt_(const AllocateNode* op) final {
497501
if (MatchDType(op->dtype)) {
498502
DataType dtype = GetStorageUIntDType(op->dtype);
499-
Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype)));
503+
String storage_scope = "global";
504+
if (auto* ptr_type = op->buffer_var->type_annotation.as<PointerTypeNode>()) {
505+
storage_scope = ptr_type->storage_scope;
506+
}
507+
Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype), storage_scope));
500508
var_remap_[op->buffer_var] = buffer_var;
501509
return VisitStmt(Allocate(buffer_var, dtype, op->extents, op->condition, op->body));
502510
} else {
@@ -637,7 +645,8 @@ class StorageLegalizer : public StmtExprMutator {
637645
if (auto* elem_type = ptr_type->element_type.as<PrimTypeNode>()) {
638646
if (MatchDType(elem_type->dtype)) {
639647
Var new_var =
640-
Var(var->name_hint, PointerType(PrimType(GetStorageUIntDType(elem_type->dtype))));
648+
Var(var->name_hint, PointerType(PrimType(GetStorageUIntDType(elem_type->dtype)),
649+
ptr_type->storage_scope));
641650
var_remap_[var] = new_var;
642651
return new_var;
643652
}

tests/python/unittest/test_tir_transform_bf16_legalize.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,5 +114,74 @@ def test_bf16_storage_legalize():
114114
tvm.ir.assert_structural_equal(after, expected)
115115

116116

117+
def test_bf16_storage_scope():
118+
def get_before():
119+
@tvm.script.ir_module
120+
class Before:
121+
@T.prim_func
122+
def main(
123+
Aptr: T.handle("bfloat16", storage_scope="shared"),
124+
Bptr: T.handle("bfloat16", storage_scope="local"),
125+
Dptr: T.handle("bfloat16"),
126+
):
127+
T.func_attr({"global_symbol": "main"})
128+
A = T.decl_buffer((100,), "bfloat16", data=Aptr)
129+
B = T.decl_buffer((100,), "bfloat16", data=Bptr)
130+
D = T.decl_buffer((100,), "bfloat16", data=Dptr)
131+
C = T.decl_buffer((100,), "bfloat16")
132+
for i in T.grid(100):
133+
C[i] = A[i] + B[i]
134+
D[i] = T.exp(C[i])
135+
136+
return Before
137+
138+
def after_compute_legalize():
139+
@tvm.script.ir_module
140+
class After:
141+
@T.prim_func
142+
def main(
143+
Aptr: T.handle("bfloat16", storage_scope="shared"),
144+
Bptr: T.handle("bfloat16", storage_scope="local"),
145+
Dptr: T.handle("bfloat16"),
146+
):
147+
T.func_attr({"global_symbol": "main"})
148+
A = T.decl_buffer((100,), "bfloat16", data=Aptr)
149+
B = T.decl_buffer((100,), "bfloat16", data=Bptr)
150+
D = T.decl_buffer((100,), "bfloat16", data=Dptr)
151+
C = T.decl_buffer((100,), "float32")
152+
for i in T.grid(100):
153+
C[i] = bf16tof32(A[i]) + bf16tof32(B[i])
154+
D[i] = f32tobf16(T.exp(C[i]))
155+
156+
return After
157+
158+
def after_storage_legalize():
159+
@tvm.script.ir_module
160+
class After:
161+
@T.prim_func
162+
def main(
163+
Aptr: T.handle("uint16", storage_scope="shared"),
164+
Bptr: T.handle("uint16", storage_scope="local"),
165+
Dptr: T.handle("uint16"),
166+
):
167+
T.func_attr({"global_symbol": "main"})
168+
A = T.decl_buffer((100,), "uint16", data=Aptr)
169+
B = T.decl_buffer((100,), "uint16", data=Bptr)
170+
D = T.decl_buffer((100,), "uint16", data=Dptr)
171+
C = T.decl_buffer((100,), "float32")
172+
for i in T.grid(100):
173+
C[i] = u16tof32(A[i]) + u16tof32(B[i])
174+
D[i] = f32tou16(T.exp(C[i]))
175+
176+
return After
177+
178+
before = get_before()
179+
after_compute = tvm.tir.transform.BF16ComputeLegalize()(before)
180+
after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute)
181+
tvm.ir.assert_structural_equal(after_compute, after_compute_legalize())
182+
tvm.ir.assert_structural_equal(after_storage, after_storage_legalize())
183+
184+
117185
if __name__ == "__main__":
118186
test_bf16_storage_legalize()
187+
test_bf16_storage_scope()

0 commit comments

Comments
 (0)