Skip to content

Commit 7131411

Browse files
authored
[Bugfix][TIR][VTA] Update host-side target, even without device func (#14982)
This resolves an issue introduced by the combination of #14918 and #14945. The bug occurred for targets that do not require device-side codegen, but do require a `device_type` other than `kDLCPU`. It wasn't caught by CI, as the issue only occurred with the combination of both PRs. 1. #14918 updated `SplitHostDevice` to only modify the `"target"` attribute when a device-side function has been extracted. 2. For VTA, there is no device-side function, as everything is done through host-side API calls. 3. From (1) and (2), the VTA examples kept the target `T.target("ext_dev", host="llvm")` after the `SplitHostDevice` pass, instead of being updated to `T.target("llvm")`. 4. #14945 restricted CombineContextCall to only apply to host-side passes. 5. From (4) and (5), the `CombineContextCall` pass was no longer applied to the VTA context calls. This PR fixes `SplitHostDevice`, updating the target from `T.target("ext_dev", host="llvm")` to `T.target("llvm")`, even if no device sections have been extracted from the function.
1 parent 4267fbf commit 7131411

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

src/tir/transforms/split_host_device.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,12 @@ PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& g
108108

109109
HostDeviceSplitter splitter(device_mod, name_prefix);
110110

111-
auto body = splitter(func->body);
112-
113-
if (!body.same_as(func->body)) {
111+
if (auto body = splitter(func->body); !body.same_as(func->body)) {
114112
func.CopyOnWrite()->body = body;
115-
auto target_host = target->GetHost().value_or(Target("llvm"));
116-
func = WithAttr(std::move(func), tvm::attr::kTarget, target_host);
113+
}
114+
115+
if (auto target_host = target->GetHost()) {
116+
func = WithAttr(std::move(func), tvm::attr::kTarget, target_host.value());
117117
}
118118

119119
return func;

tests/python/unittest/test_tir_transform_split_host_device.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,5 +168,21 @@ def main_kernel(n: T.int32):
168168
return mod
169169

170170

171+
class TestSplitHostDevice(BaseCompare):
172+
"""Like TestSplitHostDevice, but no device regions to extract
173+
174+
Even if there are no device regions, the host-side function should
175+
still have its "target" attribute updated.
176+
"""
177+
178+
def before():
179+
T.func_attr({"target": T.target("ext_dev", host="llvm")})
180+
T.evaluate(0)
181+
182+
def expected():
183+
T.func_attr({"target": T.target("llvm")})
184+
T.evaluate(0)
185+
186+
171187
if __name__ == "__main__":
172188
tvm.testing.main()

0 commit comments

Comments
 (0)