Skip to content

Commit 0abe9f7

Browse files
committed
[TIR] Cleanup of MakePackedAPI
Prior to this commit, the `RequiresPackedAPI` function checked whether a function needed the packed func API. This was used both to generate a list of call-sites to update, and as part of the updates to `PrimFunc` signatures. However, the function that updates the `PrimFunc` signature could still return the original function unmodified, breaking internal method calls. This occurred for functions with a `kTarget` attribute without a host. This commit updates `MakePackedAPI` to first update all `PrimFunc` signatures that require the packed func API, then use the result to determine which call-sites must be updated. This resolves the discrepancy for host-less target annotations, and removes the possibility of similar discrepancies in the future.
1 parent d866f5c commit 0abe9f7

File tree

1 file changed

+29
-39
lines changed

1 file changed

+29
-39
lines changed

src/tir/transforms/make_packed_api.cc

Lines changed: 29 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -181,33 +181,17 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
181181
return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
182182
}
183183

184-
/* \brief Return the global_symbol of the function, if it should be updated
185-
*
186-
* \param func The function to be inspected
187-
*
188-
* \returns The global_symbol to be used for the function at call
189-
* sites, or NullOpt if the function is to remain unchanged.
190-
*/
191-
Optional<String> RequiresPackedAPI(const PrimFunc& func) {
184+
PrimFunc MakePackedAPI(PrimFunc func) {
192185
// A function with an explicit calling convention has already been
193186
// lowered, and should not be modified.
194187
if (auto opt = func->GetAttr<Integer>(tvm::attr::kCallingConv)) {
195188
if (CallingConv(opt.value()->value) != CallingConv::kDefault) {
196-
return NullOpt;
189+
return func;
197190
}
198191
}
199192

200193
// Internal function calls do not need the PackedFunc API
201194
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
202-
if (!global_symbol.defined()) {
203-
return NullOpt;
204-
}
205-
206-
return global_symbol;
207-
}
208-
209-
PrimFunc MakePackedAPI(PrimFunc func) {
210-
auto global_symbol = RequiresPackedAPI(func);
211195
if (!global_symbol.defined()) {
212196
return func;
213197
}
@@ -216,7 +200,8 @@ PrimFunc MakePackedAPI(PrimFunc func) {
216200
Target target = [&]() {
217201
auto opt = func->GetAttr<Target>(tvm::attr::kTarget);
218202
ICHECK(opt) << "MakePackedAPI required the function to be annotated with tvm::attr::kTarget ("
219-
<< tvm::attr::kTarget << "), but the function only has attributes " << func->attrs;
203+
<< tvm::attr::kTarget << "), but the function " << name_hint
204+
<< " only has attributes" << func->attrs;
220205
return opt.value();
221206
}();
222207
int target_device_type = target->GetTargetDeviceType();
@@ -375,38 +360,43 @@ namespace transform {
375360
Pass MakePackedAPI() {
376361
auto pass_func = [](IRModule mod, PassContext ctx) {
377362
Map<GlobalVar, String> packed_func_methods;
378-
for (const auto& [gvar, base_func] : mod->functions) {
379-
if (auto opt = base_func.as<PrimFunc>()) {
380-
auto prim_func = opt.value();
381-
if (auto global_symbol = RequiresPackedAPI(prim_func)) {
382-
packed_func_methods.Set(gvar, global_symbol.value());
383-
}
384-
}
385-
}
386363

387-
IRModuleNode* mptr = mod.CopyOnWrite();
388364
IRModule updates;
389-
390-
for (const auto& [gvar, base_func] : mptr->functions) {
365+
for (const auto& [gvar, base_func] : mod->functions) {
391366
if (auto opt = base_func.as<PrimFunc>()) {
392-
auto func = opt.value();
393-
auto orig_func = func;
394-
395-
if (auto body = SubroutineCallRewriter::Apply(packed_func_methods, func->body)) {
396-
func.CopyOnWrite()->body = body.value();
397-
}
398-
399-
func = MakePackedAPI(std::move(func));
367+
auto orig_func = opt.value();
368+
auto func = MakePackedAPI(orig_func);
400369

401370
if (!func.same_as(orig_func)) {
402371
updates->Add(gvar, func);
372+
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol).value();
373+
packed_func_methods.Set(gvar, global_symbol);
403374
}
404375
}
405376
}
406-
407377
if (updates->functions.size()) {
408378
mod.CopyOnWrite()->Update(updates);
409379
}
380+
381+
if (packed_func_methods.size()) {
382+
IRModule updates;
383+
for (const auto& [gvar, base_func] : mod->functions) {
384+
if (auto opt = base_func.as<PrimFunc>()) {
385+
auto func = opt.value();
386+
auto orig_func = func;
387+
388+
if (auto body = SubroutineCallRewriter::Apply(packed_func_methods, func->body)) {
389+
func.CopyOnWrite()->body = body.value();
390+
updates->Add(gvar, func);
391+
}
392+
}
393+
}
394+
395+
if (updates->functions.size()) {
396+
mod.CopyOnWrite()->Update(updates);
397+
}
398+
}
399+
410400
return mod;
411401
};
412402

0 commit comments

Comments
 (0)