Skip to content

Commit 4c2dd50

Browse files
committed
Annotate entire body as target region if no subregions found
Otherwise, in cases of a custom codegen, the device specification may be dropped entirely.
1 parent 84cb6fe commit 4c2dd50

File tree

1 file changed

+28
-4
lines changed

1 file changed

+28
-4
lines changed

src/tir/transforms/annotate_device_regions.cc

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,37 @@ namespace tir {
3434

3535
class DeviceRegionAnnotater : public StmtMutator {
3636
public:
37+
static Stmt Apply(Target host_target, Target device_target, Stmt body) {
38+
DeviceRegionAnnotater mutator(device_target);
39+
body = mutator(body);
40+
41+
bool same_host_and_device = host_target->str() == device_target->str();
42+
43+
// If no region was found that must be on the device, but the
44+
// device and host differ (e.g. `T.target('c', host='llvm')`),
45+
// then the entire region should be annotated. This preserves the
46+
// host-side handling of DLTensor arguments, while ensuring that
47+
// any device targets are used for the codegen.
48+
if (!mutator.found_target_region_ && !same_host_and_device) {
49+
body = AttrStmt(device_target, tvm::attr::kTarget, 0, body);
50+
}
51+
52+
return body;
53+
}
54+
55+
private:
3756
explicit DeviceRegionAnnotater(Target device_target) : device_target_(device_target) {}
3857

3958
Stmt VisitStmt_(const AttrStmtNode* op) final {
4059
if (op->attr_key == tvm::attr::kTarget) {
4160
// If a target attribute already exists, use it as-is.
61+
found_target_region_ = true;
4262
return GetRef<Stmt>(op);
4363
} else if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope ||
4464
op->attr_key == attr::device_scope) {
4565
// These attributes are only allowed in device-side code, so
4666
// they should be annotated with the function's default target.
67+
found_target_region_ = true;
4768
Stmt body = GetRef<Stmt>(op);
4869
return AttrStmt(device_target_, tvm::attr::kTarget, 0, body);
4970
} else {
@@ -52,8 +73,8 @@ class DeviceRegionAnnotater : public StmtMutator {
5273
}
5374
}
5475

55-
private:
5676
Target device_target_;
77+
bool found_target_region_{false};
5778
};
5879

5980
namespace transform {
@@ -64,9 +85,12 @@ Pass AnnotateDeviceRegions() {
6485
ICHECK(opt_target) << "AnnotateDeviceRegions: Require the target attribute";
6586
Target target = opt_target.value();
6687

67-
if (target->GetHost()) {
68-
DeviceRegionAnnotater mutator(target.WithoutHost());
69-
func.CopyOnWrite()->body = mutator(func->body);
88+
if (auto opt_host = target->GetHost()) {
89+
auto new_body =
90+
DeviceRegionAnnotater::Apply(opt_host.value(), target.WithoutHost(), func->body);
91+
if (!new_body.same_as(func->body)) {
92+
func.CopyOnWrite()->body = new_body;
93+
}
7094
}
7195
return func;
7296
};

0 commit comments

Comments
 (0)