Skip to content

Commit 81056cc

Browse files
authored
[TIR] Preserve existing kTarget function attribute in BindTarget (#14942)
* [Bugfix] Avoid symbol conflicts in MakePackedAPI/MakeUnpackedAPI PRs #14913 and #14914 made analogous changes to `MakePackedAPI` and `MakeUnpackedAPI` to handle subroutine calls. Both PRs introduced the same symbol, `tvm::tir::SubroutineCallRewriter`, a local utility to update internal calls to a modified function. While each PR passed CI individually, and was therefore able to merge, having both changes caused a duplicate symbol. This commit updates `MakePackedAPI` and `MakeUnpackedAPI` to place their local utilities into anonymous namespaces, avoiding the conflict. * [Target] Added WithoutHost method * [TIR] Preserve existing kTarget function attribute in BindTarget Previously, if a function already has a `tvm::attr::kTarget` attribute, it will be overwritten by the `tir.BindTarget` transform. This commit updates the behavior such that `tir.BindTarget` adds annotations to functions that are missing a target annotation, but preserves any existing target annotations. This is part of a series of commits to simplify the handling of multi-target builds.
1 parent 86ba26d commit 81056cc

File tree

4 files changed

+150
-5
lines changed

4 files changed

+150
-5
lines changed

include/tvm/target/target.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,9 @@ class Target : public ObjectRef {
228228
*/
229229
static Target WithHost(const Target& target, const Target& host);
230230

231+
/*! \return The target with the host stripped out */
232+
Target WithoutHost() const;
233+
231234
/*!
232235
* \brief Returns true if \p this target represents an external codegen. If so,
233236
* \p this->kind->name can be used as the "Compiler" attribute on partitioned functions,

src/target/target.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,16 @@ Map<String, ObjectRef> TargetNode::Export() const {
662662

663663
Optional<Target> TargetNode::GetHost() const { return this->host.as<Target>(); }
664664

665+
Target Target::WithoutHost() const {
666+
if ((*this)->GetHost()) {
667+
auto output = make_object<TargetNode>(*get());
668+
output->host = NullOpt;
669+
return Target(output);
670+
} else {
671+
return *this;
672+
}
673+
}
674+
665675
int TargetNode::GetTargetDeviceType() const {
666676
if (Optional<Integer> device_type = GetAttr<Integer>("target_device_type")) {
667677
return Downcast<Integer>(device_type)->value;

src/tir/transforms/primfunc_utils.cc

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,32 @@ namespace tvm {
3030
namespace tir {
3131
namespace transform {
3232
transform::Pass BindTarget(Target target) {
33-
auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
34-
if (f->GetAttr<Integer>(tvm::tir::attr::kIsHostFunc) == 1) {
35-
return WithAttr(std::move(WithoutAttr(std::move(f), tvm::tir::attr::kIsHostFunc)),
36-
tvm::attr::kTarget, target->host.value_or(Target("llvm")));
33+
Target without_host = target.WithoutHost();
34+
Target target_host = Downcast<Target>(target->host.value_or(Target("llvm")));
35+
36+
auto fpass = [target, target_host, without_host](tir::PrimFunc func, IRModule m,
37+
transform::PassContext ctx) {
38+
bool is_externally_exposed = func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
39+
40+
if (auto func_target = func->GetAttr<Target>(tvm::attr::kTarget)) {
41+
auto func_target_host = func_target.value()->GetHost();
42+
auto target_host = target->GetHost();
43+
44+
if (target_host && !func_target_host && is_externally_exposed) {
45+
auto new_target = Target::WithHost(func_target.value(), target_host.value());
46+
func = WithAttr(std::move(func), tvm::attr::kTarget, new_target);
47+
}
48+
} else if (func->HasNonzeroAttr(tvm::tir::attr::kIsHostFunc)) {
49+
func = WithAttr(std::move(func), tvm::attr::kTarget, target_host);
50+
} else if (is_externally_exposed) {
51+
func = WithAttr(std::move(func), tvm::attr::kTarget, target);
52+
} else {
53+
func = WithAttr(std::move(func), tvm::attr::kTarget, without_host);
3754
}
38-
return WithAttr(std::move(f), tvm::attr::kTarget, target);
55+
56+
func = WithoutAttr(std::move(func), tvm::tir::attr::kIsHostFunc);
57+
58+
return func;
3959
};
4060
return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.BindTarget", {});
4161
}

tests/python/unittest/test_tir_transform_helpers.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,118 @@ def test_bind_target():
8585
assert after["func2"].attrs["target"] == target
8686

8787

88+
class TestBindTarget(tvm.testing.CompareBeforeAfter):
89+
"""BindTarget adds the "target" attribute"""
90+
91+
transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda"))
92+
93+
def before():
94+
T.evaluate(0)
95+
96+
def expected():
97+
T.func_attr({"target": T.target("cuda")})
98+
T.evaluate(0)
99+
100+
101+
class TestBindTargetWithHostToExposedFunction(tvm.testing.CompareBeforeAfter):
102+
"""BindTarget adds the host target to externally-exposed functions"""
103+
104+
transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda", host="llvm"))
105+
106+
def before():
107+
T.func_attr({"global_symbol": "main"})
108+
T.evaluate(0)
109+
110+
def expected():
111+
T.func_attr({"global_symbol": "main", "target": T.target("cuda", host="llvm")})
112+
T.evaluate(0)
113+
114+
115+
class TestBindTargetWithHostToInternalFunction(tvm.testing.CompareBeforeAfter):
116+
"""Internal functions have a target annotation, but without the host
117+
118+
The host portion of the target annotation provides host
119+
parameters, and is used to expose a function externally as part of
120+
`MakePackedAPI` and `MakeUnpackedAPI`. For internal functions, no
121+
external exposure is required, so the host attribute should not be
122+
used.
123+
"""
124+
125+
transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda", host="llvm"))
126+
127+
def before():
128+
T.evaluate(0)
129+
130+
def expected():
131+
T.func_attr({"target": T.target("cuda")})
132+
T.evaluate(0)
133+
134+
135+
class TestBindTargetIgnoresExisting(tvm.testing.CompareBeforeAfter):
136+
"""BindTarget should not replace existing annotations"""
137+
138+
transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda"))
139+
140+
def before():
141+
T.func_attr({"target": T.target("nvptx")})
142+
T.evaluate(0)
143+
144+
expected = before
145+
146+
147+
class TestBindTargetUpdatesHost(tvm.testing.CompareBeforeAfter):
148+
"""BindTarget should update host for existing annotations"""
149+
150+
transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda", host="llvm -opt-level=0"))
151+
152+
def before():
153+
T.func_attr({"global_symbol": "func", "target": T.target("nvptx")})
154+
T.evaluate(0)
155+
156+
def expected():
157+
T.func_attr(
158+
{
159+
"global_symbol": "func",
160+
"target": T.target("nvptx", host="llvm -opt-level=0"),
161+
}
162+
)
163+
T.evaluate(0)
164+
165+
166+
class TestBindTargetMultipleFunctions(tvm.testing.CompareBeforeAfter):
167+
"""BindTarget may apply to multiple functions in a module"""
168+
169+
transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda"))
170+
171+
def before(self):
172+
@tvm.script.ir_module
173+
class mod:
174+
@T.prim_func
175+
def func1():
176+
T.evaluate(0)
177+
178+
@T.prim_func
179+
def func2():
180+
T.evaluate(0)
181+
182+
return mod
183+
184+
def expected(self):
185+
@tvm.script.ir_module
186+
class mod:
187+
@T.prim_func
188+
def func1():
189+
T.func_attr({"target": T.target("cuda")})
190+
T.evaluate(0)
191+
192+
@T.prim_func
193+
def func2():
194+
T.func_attr({"target": T.target("cuda")})
195+
T.evaluate(0)
196+
197+
return mod
198+
199+
88200
def test_filter_primfunc():
89201
mod = MockModule
90202
assert mod

0 commit comments

Comments
 (0)