Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,19 @@ namespace tir {
// These information are needed during codegen.
class BuiltinLower : public StmtExprMutator {
public:
// Entry point: lower TVM builtin intrinsics within a PrimFunc.
//
// If the function carries a `tvm::attr::kTarget` attribute, that
// target's default device type is used to seed the mutator's
// device_type_; otherwise the mutator starts with no device type
// (NullOpt) and must obtain one from a "device_type" AttrStmt in
// the body.
static PrimFunc Build(PrimFunc func) {
Optional<PrimExpr> device_type = NullOpt;
if (auto target = func->GetAttr<Target>(tvm::attr::kTarget)) {
// Fall back to the target kind's default device type when no
// explicit "device_type" attribute scopes the allocation.
device_type = Integer(target.value()->kind->default_device_type);
}

BuiltinLower mutator(device_type);
// CopyOnWrite: mutate this PrimFunc in place only if uniquely owned,
// so other references to the original function are unaffected.
// NOTE(review): VisitBodyAndRealizeAlloca is defined elsewhere in this
// class; presumably it visits the body and materializes pending stack
// allocations at the function scope — confirm against the full file.
func.CopyOnWrite()->body = mutator.VisitBodyAndRealizeAlloca(func->body);
return func;
}

explicit BuiltinLower(Optional<PrimExpr> device_type = NullOpt) : device_type_(device_type) {}

// NOTE: Right now, we make the following scoping requirement
// for memory allocated by the following primitives
// - tvm_stack_make_array
Expand Down Expand Up @@ -284,13 +297,17 @@ class BuiltinLower : public StmtExprMutator {

Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::device_id) {
ICHECK(!device_id_);
auto cache = device_id_;
device_id_ = op->value;
return this->VisitStmt(op->body);
Stmt out = this->VisitStmt(op->body);
device_id_ = cache;
return out;
} else if (op->attr_key == attr::device_type) {
ICHECK(!device_type_);
auto cache = device_type_;
device_type_ = op->value;
return this->VisitStmt(op->body);
Stmt out = this->VisitStmt(op->body);
device_type_ = cache;
return out;
} else {
return StmtExprMutator::VisitStmt_(op);
}
Expand Down Expand Up @@ -656,13 +673,12 @@ class BuiltinLower : public StmtExprMutator {
namespace transform {

Pass LowerTVMBuiltin() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
if (IsHostFunc(f).value_or(false)) {
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
f.CopyOnWrite()->body = BuiltinLower().Build(f->body);
VLOG(2) << "LowerTVMBuiltin: " << f;
auto pass_func = [](PrimFunc func, IRModule m, PassContext ctx) {
if (IsHostFunc(func).value_or(false)) {
func = BuiltinLower::Build(func);
VLOG(2) << "LowerTVMBuiltin: " << func;
}
return f;
return func;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {});
}
Expand Down
41 changes: 36 additions & 5 deletions tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ class TestLowerDeviceAllocate(tvm.testing.CompareBeforeAfter):

def before():
T.func_attr({"target": T.target("llvm")})
T.attr("dummy", "device_type", 2) # kDLCuda
T.attr("dummy", "device_type", tvm.runtime.Device.kDLCUDA)
T.attr("dummy", "device_id", 0)
ptr = T.allocate([16], "float32")
buf = T.decl_buffer(16, "float32", data=ptr)
Expand Down Expand Up @@ -246,7 +246,7 @@ class TestLowerCPUAllocation(tvm.testing.CompareBeforeAfter):

def before():
T.func_attr({"target": T.target("llvm")})
T.attr("dummy", "device_type", 1) # kDLCPU
T.attr("dummy", "device_type", tvm.runtime.Device.kDLCPU)
T.attr("dummy", "device_id", 0)
ptr = T.allocate([16], "float32")
buf = T.decl_buffer(16, "float32", data=ptr)
Expand All @@ -260,11 +260,13 @@ def expected():


class TestLowerAllocateRequiresDeviceID(tvm.testing.CompareBeforeAfter):
"""If device id is missing, error."""

transform = tvm.tir.transform.LowerTVMBuiltin()

def before():
T.func_attr({"target": T.target("llvm")})
T.attr("dummy", "device_id", 0)
T.attr("dummy", "device_type", tvm.runtime.Device.kDLCUDA)
ptr = T.allocate([16], "float32")
buf = T.decl_buffer(16, "float32", data=ptr)
buf[0] = 0.0
Expand All @@ -273,16 +275,45 @@ def before():


class TestLowerAllocateRequiresDeviceType(tvm.testing.CompareBeforeAfter):
"""If device type is missing, error.

The device type can be inferred either from the `"device_type"`
statement attribute, or from the `"target"` function attribute.
Here, we provide neither. The `"tir.is_host_func"` attribute is
provided as otherwise the function would be skipped altogether by
LowerTVMBuiltin.
"""

transform = tvm.tir.transform.LowerTVMBuiltin()

def before():
T.func_attr({"target": T.target("llvm")})
T.func_attr({"tir.is_host_func": True})
T.attr("dummy", "device_id", 0)
ptr = T.allocate([1024 * 1024], "float32")
buf = T.decl_buffer(1024 * 1024, "float32", data=ptr)
buf[0] = 0.0

expected = tvm.TVMError


class TestLowerCPUAllocWithFunctionAttr(tvm.testing.CompareBeforeAfter):
"""CPU allocations can be handled at codegen time

Like `TestLowerCPUAllocation`, but the device type is taken from
the function attribute. The `AttrStmt` can override the device
type for allocations within its scope, but it defaults to the
function's target.
"""

transform = tvm.tir.transform.LowerTVMBuiltin()

def before():
T.func_attr({"target": T.target("llvm")})
ptr = T.allocate([16], "float32")
buf = T.decl_buffer(16, "float32", data=ptr)
buf[0] = 0.0

expected = tvm.TVMError
expected = before


if __name__ == "__main__":
Expand Down