@@ -34,16 +34,37 @@ namespace tir {
3434
3535class DeviceRegionAnnotater : public StmtMutator {
3636 public:
37+ static Stmt Apply (Target host_target, Target device_target, Stmt body) {
38+ DeviceRegionAnnotater mutator (device_target);
39+ body = mutator (body);
40+
41+ bool same_host_and_device = host_target->str () == device_target->str ();
42+
43+ // If no region was found that must be on the device, but the
44+ // device and host differ (e.g. `T.target('c', host='llvm')`),
45+ // then the entire region should be annotated. This preserves the
46+ // host-side handling of DLTensor arguments, while ensuring that
47+ // any device targets are used for the codegen.
48+ if (!mutator.found_target_region_ && !same_host_and_device) {
49+ body = AttrStmt (device_target, tvm::attr::kTarget , 0 , body);
50+ }
51+
52+ return body;
53+ }
54+
55+ private:
3756 explicit DeviceRegionAnnotater (Target device_target) : device_target_(device_target) {}
3857
3958 Stmt VisitStmt_ (const AttrStmtNode* op) final {
4059 if (op->attr_key == tvm::attr::kTarget ) {
4160 // If a target attribute already exists, use it as-is.
61+ found_target_region_ = true ;
4262 return GetRef<Stmt>(op);
4363 } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope ||
4464 op->attr_key == attr::device_scope) {
4565 // These attributes are only allowed in device-side code, so
4666 // they should be annotated with the function's default target.
67+ found_target_region_ = true ;
4768 Stmt body = GetRef<Stmt>(op);
4869 return AttrStmt (device_target_, tvm::attr::kTarget , 0 , body);
4970 } else {
@@ -52,8 +73,8 @@ class DeviceRegionAnnotater : public StmtMutator {
5273 }
5374 }
5475
55- private:
5676 Target device_target_;
77+ bool found_target_region_{false };
5778};
5879
5980namespace transform {
@@ -64,9 +85,12 @@ Pass AnnotateDeviceRegions() {
6485 ICHECK (opt_target) << " AnnotateDeviceRegions: Require the target attribute" ;
6586 Target target = opt_target.value ();
6687
67- if (target->GetHost ()) {
68- DeviceRegionAnnotater mutator (target.WithoutHost ());
69- func.CopyOnWrite ()->body = mutator (func->body );
88+ if (auto opt_host = target->GetHost ()) {
89+ auto new_body =
90+ DeviceRegionAnnotater::Apply (opt_host.value (), target.WithoutHost (), func->body );
91+ if (!new_body.same_as (func->body )) {
92+ func.CopyOnWrite ()->body = new_body;
93+ }
7094 }
7195 return func;
7296 };
0 commit comments