diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 25c10dd6828d..6b681c07e5d5 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -715,8 +715,13 @@ Pass ConvertSSA() { tir::IRConvertSSA converter; Map functions; bool made_change = false; + // FIXME: This is just a temporal workaround to ensure free vars + // in device function have the same pointer as the host function for (auto [gvar, base_func] : mod->functions) { if (auto* ptr = base_func.as()) { + if (!ptr->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + continue; + } auto updated = converter.VisitPrimFunc(GetRef(ptr)); if (!updated.same_as(base_func)) { made_change = true; @@ -725,6 +730,19 @@ Pass ConvertSSA() { } functions.Set(gvar, base_func); } + for (auto [gvar, base_func] : mod->functions) { + if (auto* ptr = base_func.as()) { + if (ptr->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + continue; + } + auto updated = converter.VisitPrimFunc(GetRef(ptr)); + if (!updated.same_as(base_func)) { + made_change = true; + base_func = updated; + } + functions.Set(gvar, base_func); + } + } if (made_change) { mod.CopyOnWrite()->functions = std::move(functions); }