Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,9 @@ class Target : public ObjectRef {
*/
static Target WithHost(const Target& target, const Target& host);

/*! \return The target with the host stripped out */
Target WithoutHost() const;

/*!
* \brief Returns true if \p this target represents an external codegen. If so,
* \p this->kind->name can be used as the "Compiler" attribute on partitioned functions,
Expand Down
10 changes: 10 additions & 0 deletions src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,16 @@ Map<String, ObjectRef> TargetNode::Export() const {

Optional<Target> TargetNode::GetHost() const { return this->host.as<Target>(); }

Target Target::WithoutHost() const {
if ((*this)->GetHost()) {
auto output = make_object<TargetNode>(*get());
output->host = NullOpt;
return Target(output);
} else {
return *this;
}
}

int TargetNode::GetTargetDeviceType() const {
if (Optional<Integer> device_type = GetAttr<Integer>("target_device_type")) {
return Downcast<Integer>(device_type)->value;
Expand Down
3 changes: 3 additions & 0 deletions src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ namespace tir {

static constexpr const char* kDeviceContextVar = "device_api_context";

namespace {
class ReturnRewriter : public StmtMutator {
public:
explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {}
Expand Down Expand Up @@ -176,6 +177,8 @@ class SubroutineCallRewriter : public StmtExprMutator {
bool made_change_{false};
};

} // namespace

inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
}
Expand Down
4 changes: 4 additions & 0 deletions src/tir/transforms/make_unpacked_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
namespace tvm {
namespace tir {

namespace {

class SubroutineCallRewriter : public StmtExprMutator {
public:
static Optional<Stmt> Apply(const std::unordered_set<const GlobalVarNode*>& external_methods,
Expand Down Expand Up @@ -84,6 +86,8 @@ class SubroutineCallRewriter : public StmtExprMutator {
bool made_change_{false};
};

} // namespace

PrimFunc MakeUnpackedAPI(PrimFunc func) {
// A function with an explicit calling convention has already been
// lowered, and should not be modified.
Expand Down
30 changes: 25 additions & 5 deletions src/tir/transforms/primfunc_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,32 @@ namespace tvm {
namespace tir {
namespace transform {
transform::Pass BindTarget(Target target) {
auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
if (f->GetAttr<Integer>(tvm::tir::attr::kIsHostFunc) == 1) {
return WithAttr(std::move(WithoutAttr(std::move(f), tvm::tir::attr::kIsHostFunc)),
tvm::attr::kTarget, target->host.value_or(Target("llvm")));
Target without_host = target.WithoutHost();
Target target_host = Downcast<Target>(target->host.value_or(Target("llvm")));

auto fpass = [target, target_host, without_host](tir::PrimFunc func, IRModule m,
transform::PassContext ctx) {
bool is_externally_exposed = func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined();

if (auto func_target = func->GetAttr<Target>(tvm::attr::kTarget)) {
auto func_target_host = func_target.value()->GetHost();
auto target_host = target->GetHost();

if (target_host && !func_target_host && is_externally_exposed) {
auto new_target = Target::WithHost(func_target.value(), target_host.value());
func = WithAttr(std::move(func), tvm::attr::kTarget, new_target);
}
} else if (func->HasNonzeroAttr(tvm::tir::attr::kIsHostFunc)) {
func = WithAttr(std::move(func), tvm::attr::kTarget, target_host);
} else if (is_externally_exposed) {
func = WithAttr(std::move(func), tvm::attr::kTarget, target);
} else {
func = WithAttr(std::move(func), tvm::attr::kTarget, without_host);
}
return WithAttr(std::move(f), tvm::attr::kTarget, target);

func = WithoutAttr(std::move(func), tvm::tir::attr::kIsHostFunc);

return func;
};
return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.BindTarget", {});
}
Expand Down
112 changes: 112 additions & 0 deletions tests/python/unittest/test_tir_transform_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,118 @@ def test_bind_target():
assert after["func2"].attrs["target"] == target


class TestBindTarget(tvm.testing.CompareBeforeAfter):
"""BindTarget adds the "target" attribute"""

transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda"))

def before():
T.evaluate(0)

def expected():
T.func_attr({"target": T.target("cuda")})
T.evaluate(0)


class TestBindTargetWithHostToExposedFunction(tvm.testing.CompareBeforeAfter):
"""BindTarget adds the host target to externally-exposed functions"""

transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda", host="llvm"))

def before():
T.func_attr({"global_symbol": "main"})
T.evaluate(0)

def expected():
T.func_attr({"global_symbol": "main", "target": T.target("cuda", host="llvm")})
T.evaluate(0)


class TestBindTargetWithHostToInternalFunction(tvm.testing.CompareBeforeAfter):
"""Internal functions have a target annotation, but without the host

The host portion of the target annotation provides host
parameters, and is used to expose a function externally as part of
`MakePackedAPI` and `MakeUnpackedAPI`. For internal functions, no
external exposure is required, so the host attribute should not be
used.
"""

transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda", host="llvm"))

def before():
T.evaluate(0)

def expected():
T.func_attr({"target": T.target("cuda")})
T.evaluate(0)


class TestBindTargetIgnoresExisting(tvm.testing.CompareBeforeAfter):
"""BindTarget should not replace existing annotations"""

transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda"))

def before():
T.func_attr({"target": T.target("nvptx")})
T.evaluate(0)

expected = before


class TestBindTargetUpdatesHost(tvm.testing.CompareBeforeAfter):
"""BindTarget should update host for existing annotations"""

transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda", host="llvm -opt-level=0"))

def before():
T.func_attr({"global_symbol": "func", "target": T.target("nvptx")})
T.evaluate(0)

def expected():
T.func_attr(
{
"global_symbol": "func",
"target": T.target("nvptx", host="llvm -opt-level=0"),
}
)
T.evaluate(0)


class TestBindTargetMultipleFunctions(tvm.testing.CompareBeforeAfter):
"""BindTarget may apply to multiple functions in a module"""

transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda"))

def before(self):
@tvm.script.ir_module
class mod:
@T.prim_func
def func1():
T.evaluate(0)

@T.prim_func
def func2():
T.evaluate(0)

return mod

def expected(self):
@tvm.script.ir_module
class mod:
@T.prim_func
def func1():
T.func_attr({"target": T.target("cuda")})
T.evaluate(0)

@T.prim_func
def func2():
T.func_attr({"target": T.target("cuda")})
T.evaluate(0)

return mod


def test_filter_primfunc():
mod = MockModule
assert mod
Expand Down