Skip to content

Commit a2ae70e

Browse files
committed
Merge branch 'tir_check_function_attr_for_device_type_pr_16727' into HEAD
2 parents b51541c + b362f86 commit a2ae70e

File tree

2 files changed

+60
-13
lines changed

2 files changed

+60
-13
lines changed

src/tir/transforms/lower_tvm_builtin.cc

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,19 @@ namespace tir {
3838
// These information are needed during codegen.
3939
class BuiltinLower : public StmtExprMutator {
4040
public:
41+
static PrimFunc Build(PrimFunc func) {
42+
Optional<PrimExpr> device_type = NullOpt;
43+
if (auto target = func->GetAttr<Target>(tvm::attr::kTarget)) {
44+
device_type = Integer(target.value()->kind->default_device_type);
45+
}
46+
47+
BuiltinLower mutator(device_type);
48+
func.CopyOnWrite()->body = mutator.VisitBodyAndRealizeAlloca(func->body);
49+
return func;
50+
}
51+
52+
BuiltinLower(Optional<PrimExpr> device_type = NullOpt) : device_type_(device_type) {}
53+
4154
// NOTE: Right now, we make the following scoping requirement
4255
// for memory allocated by the following primitives
4356
// - tvm_stack_make_array
@@ -284,13 +297,17 @@ class BuiltinLower : public StmtExprMutator {
284297

285298
Stmt VisitStmt_(const AttrStmtNode* op) final {
286299
if (op->attr_key == attr::device_id) {
287-
ICHECK(!device_id_);
300+
auto cache = device_id_;
288301
device_id_ = op->value;
289-
return this->VisitStmt(op->body);
302+
Stmt out = this->VisitStmt(op->body);
303+
device_id_ = cache;
304+
return out;
290305
} else if (op->attr_key == attr::device_type) {
291-
ICHECK(!device_type_);
306+
auto cache = device_type_;
292307
device_type_ = op->value;
293-
return this->VisitStmt(op->body);
308+
Stmt out = this->VisitStmt(op->body);
309+
device_type_ = cache;
310+
return out;
294311
} else {
295312
return StmtExprMutator::VisitStmt_(op);
296313
}
@@ -656,13 +673,12 @@ class BuiltinLower : public StmtExprMutator {
656673
namespace transform {
657674

658675
Pass LowerTVMBuiltin() {
659-
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
660-
if (IsHostFunc(f).value_or(false)) {
661-
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
662-
f.CopyOnWrite()->body = BuiltinLower().Build(f->body);
663-
VLOG(2) << "LowerTVMBuiltin: " << f;
676+
auto pass_func = [](PrimFunc func, IRModule m, PassContext ctx) {
677+
if (IsHostFunc(func).value_or(false)) {
678+
func = BuiltinLower::Build(func);
679+
VLOG(2) << "LowerTVMBuiltin: " << func;
664680
}
665-
return f;
681+
return func;
666682
};
667683
return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {});
668684
}

tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,11 +260,13 @@ def expected():
260260

261261

262262
class TestLowerAllocateRequiresDeviceID(tvm.testing.CompareBeforeAfter):
263+
"""If device id is missing, error."""
264+
263265
transform = tvm.tir.transform.LowerTVMBuiltin()
264266

265267
def before():
266268
T.func_attr({"target": T.target("llvm")})
267-
T.attr("dummy", "device_id", 0)
269+
T.attr("dummy", "device_type", 2) # kDLCuda
268270
ptr = T.allocate([16], "float32")
269271
buf = T.decl_buffer(16, "float32", data=ptr)
270272
buf[0] = 0.0
@@ -273,16 +275,45 @@ def before():
273275

274276

275277
class TestLowerAllocateRequiresDeviceType(tvm.testing.CompareBeforeAfter):
278+
"""If device type is missing, error.
279+
280+
The device type can be inferred either from the `"device_type"`
281+
statement attribute, or from the `"target"` function attribute.
282+
Here, we provide neither. The `"tir.is_host_func"` attribute is
283+
provided as otherwise the function would be skipped altogether by
284+
LowerTVMBuiltin.
285+
"""
286+
276287
transform = tvm.tir.transform.LowerTVMBuiltin()
277288

278289
def before():
279-
T.func_attr({"target": T.target("llvm")})
290+
T.func_attr({"tir.is_host_func": True})
280291
T.attr("dummy", "device_id", 0)
292+
ptr = T.allocate([1024 * 1024], "float32")
293+
buf = T.decl_buffer(1024 * 1024, "float32", data=ptr)
294+
buf[0] = 0.0
295+
296+
expected = tvm.TVMError
297+
298+
299+
class TestLowerCPUAllocWithFunctionAttr(tvm.testing.CompareBeforeAfter):
300+
"""CPU allocations can be handled at codegen time
301+
302+
Like `TestLowerCPUAllocation`, but the device type is taken from
303+
the function attribute. The `AttrStmt` can override the device
304+
type for allocations within its scope, but it defaults to the
305+
function's target.
306+
"""
307+
308+
transform = tvm.tir.transform.LowerTVMBuiltin()
309+
310+
def before():
311+
T.func_attr({"target": T.target("llvm")})
281312
ptr = T.allocate([16], "float32")
282313
buf = T.decl_buffer(16, "float32", data=ptr)
283314
buf[0] = 0.0
284315

285-
expected = tvm.TVMError
316+
expected = before
286317

287318

288319
if __name__ == "__main__":

0 commit comments

Comments
 (0)