Skip to content

Commit b4025fd

Browse files
committed
[TIR] Cleanup of MakePackedAPI
Prior to this commit, the `RequiresPackedAPI` function checked whether a function needed the packed func API. This was used both to generate a list of call-sites to update, and as part of the updates to `PrimFunc` signatures. However, the function that updates the `PrimFunc` signature could still return the original function unmodified, breaking internal method calls. This occurred for functions with a `kTarget` attribute without a host. This commit updates `MakePackedAPI` to first update all `PrimFunc` signatures that require the packed func API, then use the result to determine which call-sites must be updated. This resolves the discrepancy for host-less target annotations, and removes the possibility of similar discrepancies in the future.
1 parent ab3faea commit b4025fd

File tree

1 file changed

+29
-39
lines changed

1 file changed

+29
-39
lines changed

src/tir/transforms/make_packed_api.cc

Lines changed: 29 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -183,33 +183,17 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
183183
return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
184184
}
185185

186-
/* \brief Return the global_symbol of the function, if it should be updated
187-
*
188-
* \param func The function to be inspected
189-
*
190-
* \returns The global_symbol to be used for the function at call
191-
* sites, or NullOpt if the function is to remain unchanged.
192-
*/
193-
Optional<String> RequiresPackedAPI(const PrimFunc& func) {
186+
PrimFunc MakePackedAPI(PrimFunc func) {
194187
// A function with an explicit calling convention has already been
195188
// lowered, and should not be modified.
196189
if (auto opt = func->GetAttr<Integer>(tvm::attr::kCallingConv)) {
197190
if (CallingConv(opt.value()->value) != CallingConv::kDefault) {
198-
return NullOpt;
191+
return func;
199192
}
200193
}
201194

202195
// Internal function calls do not need the PackedFunc API
203196
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
204-
if (!global_symbol.defined()) {
205-
return NullOpt;
206-
}
207-
208-
return global_symbol;
209-
}
210-
211-
PrimFunc MakePackedAPI(PrimFunc func) {
212-
auto global_symbol = RequiresPackedAPI(func);
213197
if (!global_symbol.defined()) {
214198
return func;
215199
}
@@ -218,7 +202,8 @@ PrimFunc MakePackedAPI(PrimFunc func) {
218202
Target target = [&]() {
219203
auto opt = func->GetAttr<Target>(tvm::attr::kTarget);
220204
ICHECK(opt) << "MakePackedAPI required the function to be annotated with tvm::attr::kTarget ("
221-
<< tvm::attr::kTarget << "), but the function only has attributes " << func->attrs;
205+
<< tvm::attr::kTarget << "), but the function " << name_hint
206+
<< " only has attributes" << func->attrs;
222207
return opt.value();
223208
}();
224209
int target_device_type = target->GetTargetDeviceType();
@@ -377,38 +362,43 @@ namespace transform {
377362
Pass MakePackedAPI() {
378363
auto pass_func = [](IRModule mod, PassContext ctx) {
379364
Map<GlobalVar, String> packed_func_methods;
380-
for (const auto& [gvar, base_func] : mod->functions) {
381-
if (auto opt = base_func.as<PrimFunc>()) {
382-
auto prim_func = opt.value();
383-
if (auto global_symbol = RequiresPackedAPI(prim_func)) {
384-
packed_func_methods.Set(gvar, global_symbol.value());
385-
}
386-
}
387-
}
388365

389-
IRModuleNode* mptr = mod.CopyOnWrite();
390366
IRModule updates;
391-
392-
for (const auto& [gvar, base_func] : mptr->functions) {
367+
for (const auto& [gvar, base_func] : mod->functions) {
393368
if (auto opt = base_func.as<PrimFunc>()) {
394-
auto func = opt.value();
395-
auto orig_func = func;
396-
397-
if (auto body = SubroutineCallRewriter::Apply(packed_func_methods, func->body)) {
398-
func.CopyOnWrite()->body = body.value();
399-
}
400-
401-
func = MakePackedAPI(std::move(func));
369+
auto orig_func = opt.value();
370+
auto func = MakePackedAPI(orig_func);
402371

403372
if (!func.same_as(orig_func)) {
404373
updates->Add(gvar, func);
374+
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol).value();
375+
packed_func_methods.Set(gvar, global_symbol);
405376
}
406377
}
407378
}
408-
409379
if (updates->functions.size()) {
410380
mod.CopyOnWrite()->Update(updates);
411381
}
382+
383+
if (packed_func_methods.size()) {
384+
IRModule updates;
385+
for (const auto& [gvar, base_func] : mod->functions) {
386+
if (auto opt = base_func.as<PrimFunc>()) {
387+
auto func = opt.value();
388+
auto orig_func = func;
389+
390+
if (auto body = SubroutineCallRewriter::Apply(packed_func_methods, func->body)) {
391+
func.CopyOnWrite()->body = body.value();
392+
updates->Add(gvar, func);
393+
}
394+
}
395+
}
396+
397+
if (updates->functions.size()) {
398+
mod.CopyOnWrite()->Update(updates);
399+
}
400+
}
401+
412402
return mod;
413403
};
414404

0 commit comments

Comments
 (0)