Skip to content

Commit 36c0359

Browse files
committed
[IR] Expose ReplaceGlobalVars utility in the Python API
This is a follow-up PR to #17202, which added a general utility to replace `GlobalVar` instances across all TVM IR types. This PR exposes this new utility through the Python API, and explicitly tests its functionality.
1 parent d7e0af2 commit 36c0359

File tree

7 files changed

+414
-12
lines changed

7 files changed

+414
-12
lines changed

include/tvm/ir/replace_global_var.h renamed to include/tvm/ir/replace_global_vars.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ namespace transform {
4141
*
4242
* \return The updated IRModule
4343
*/
44-
TVM_DLL IRModule ReplaceGlobalVar(IRModule mod, Map<GlobalVar, GlobalVar> replacements);
44+
TVM_DLL IRModule ReplaceGlobalVars(IRModule mod, Map<GlobalVar, GlobalVar> replacements);
4545

4646
struct GlobalVarReplacer {
4747
using FType = NodeFunctor<BaseFunc(const ObjectRef&, Map<GlobalVar, GlobalVar>)>;

python/tvm/ir/module.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""IRModule that holds the functions and type definitions."""
18+
1819
from __future__ import annotations
1920

2021
from typing import Dict, Union
@@ -216,6 +217,33 @@ def get_global_vars(self):
216217
"""
217218
return _ffi_api.Module_GetGlobalVars(self)
218219

220+
def replace_global_vars(
221+
self,
222+
replacements: Dict[Union[str, _expr.GlobalVar], Union[str, _expr.GlobalVar]],
223+
) -> "IRModule":
224+
"""Replace GlobalVar instances within the module
225+
226+
Replace GlobalVars within the IRModule. Since the IRModule
227+
may contain internal references to a GlobalVar, either in TIR
228+
or in Relax, this method should be used whenever replacing or
229+
renaming a GlobalVar.
230+
231+
Parameters
232+
----------
233+
replacements: Dict[Union[str, _expr.GlobalVar], Union[str, _expr.GlobalVar]]
234+
235+
A dictionary where each key is a GlobalVar to be replaced,
236+
and the corresponding value is the GlobalVar with which to
237+
replace it.
238+
239+
Returns
240+
-------
241+
IRModule
242+
The updated module
243+
244+
"""
245+
return _ffi_api.Module_ReplaceGlobalVars(self, replacements)
246+
219247
def get_global_type_vars(self):
220248
"""Collect all global type vars defined in this module.
221249

src/ir/replace_global_var.cc renamed to src/ir/replace_global_vars.cc

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,22 @@
1818
*/
1919

2020
/*!
21-
* \file src/ir/replace_global_var.cc
21+
* \file src/ir/replace_global_vars.cc
2222
* \brief IRModule transform to replace GlobalVar instances across any IR type.
2323
*/
2424

25-
#include <tvm/ir/replace_global_var.h>
25+
#include <tvm/ir/replace_global_vars.h>
2626

2727
#include <vector>
2828

2929
namespace tvm {
3030
namespace transform {
3131

32-
IRModule ReplaceGlobalVar(IRModule mod, Map<GlobalVar, GlobalVar> replacements) {
32+
IRModule ReplaceGlobalVars(IRModule mod, Map<GlobalVar, GlobalVar> replacements) {
33+
if (replacements.empty()) {
34+
return mod;
35+
}
36+
3337
std::vector<GlobalVar> to_remove;
3438
IRModule updates;
3539

@@ -57,7 +61,38 @@ IRModule ReplaceGlobalVar(IRModule mod, Map<GlobalVar, GlobalVar> replacements)
5761
return mod;
5862
}
5963

60-
TVM_REGISTER_GLOBAL("transform.ReplaceGlobalVar").set_body_typed(ReplaceGlobalVar);
64+
TVM_REGISTER_GLOBAL("transform.ReplaceGlobalVars").set_body_typed(ReplaceGlobalVars);
65+
66+
IRModule ModuleReplaceGlobalVars(
67+
IRModule mod, Map<Variant<String, GlobalVar>, Variant<String, GlobalVar>> replacements) {
68+
Map<GlobalVar, GlobalVar> gvar_replacements;
69+
for (const auto& [before, after] : replacements) {
70+
GlobalVar gvar_before;
71+
if (auto gvar = before.as<GlobalVar>()) {
72+
gvar_before = gvar.value();
73+
} else if (auto str = before.as<String>()) {
74+
gvar_before = mod->GetGlobalVar(str.value());
75+
} else {
76+
LOG(FATAL) << "Variant<String,GlobalVar> must contain either String or GlobalVar";
77+
}
78+
79+
GlobalVar gvar_after;
80+
if (auto gvar = after.as<GlobalVar>()) {
81+
gvar_after = gvar.value();
82+
} else if (auto str = after.as<String>()) {
83+
gvar_after = gvar_before;
84+
gvar_after.CopyOnWrite()->name_hint = str.value();
85+
} else {
86+
LOG(FATAL) << "Variant<String,GlobalVar> must contain either String or GlobalVar";
87+
}
88+
89+
gvar_replacements.Set(gvar_before, gvar_after);
90+
}
91+
92+
return ReplaceGlobalVars(mod, gvar_replacements);
93+
}
94+
95+
TVM_REGISTER_GLOBAL("ir.Module_ReplaceGlobalVars").set_body_typed(ModuleReplaceGlobalVars);
6196

6297
} // namespace transform
6398
} // namespace tvm

src/relax/transform/attach_global_symbol.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
*/
2323

2424
#include <tvm/ir/module.h>
25-
#include <tvm/ir/replace_global_var.h>
25+
#include <tvm/ir/replace_global_vars.h>
2626
#include <tvm/relax/struct_info.h>
2727
#include <tvm/relax/transform.h>
2828
#include <tvm/tir/function.h>
@@ -72,7 +72,7 @@ Pass AttachGlobalSymbol() {
7272
mod.CopyOnWrite()->Update(updates);
7373

7474
if (gvar_updates.size()) {
75-
mod = tvm::transform::ReplaceGlobalVar(mod, gvar_updates);
75+
mod = tvm::transform::ReplaceGlobalVars(mod, gvar_updates);
7676
}
7777
}
7878
return mod;

src/relax/transform/replace_global_var.cc renamed to src/relax/transform/replace_global_vars.cc

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@
1919

2020
/*!
2121
*
22-
* \file src/relax/transform/replace_global_var.cc
22+
* \file src/relax/transform/replace_global_vars.cc
2323
*
2424
* \brief GlobalVar replacement across IR types
2525
*/
2626

2727
#include <tvm/ir/analysis.h>
28-
#include <tvm/ir/replace_global_var.h>
28+
#include <tvm/ir/replace_global_vars.h>
2929
#include <tvm/relax/analysis.h>
3030
#include <tvm/relax/expr_functor.h>
3131
#include <tvm/tir/expr_functor.h>
@@ -53,7 +53,24 @@ TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable)
5353
.set_dispatch<relax::FunctionNode>([](const ObjectRef& func,
5454
Map<GlobalVar, GlobalVar> replacements) -> BaseFunc {
5555
Mutator mutator(replacements);
56-
return Downcast<BaseFunc>(mutator(Downcast<Function>(func)));
56+
auto new_func = Downcast<Function>(mutator(Downcast<Function>(func)));
57+
58+
// If the function is externally exposed, and is being replaced
59+
// by a GlobalVar with a new name, then the function's
60+
// kGlobalSymbol must be updated to match.
61+
if (auto opt = new_func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
62+
auto name = opt.value();
63+
for (const auto& [before, after] : replacements) {
64+
if (before->name_hint == name) {
65+
if (after->name_hint != name) {
66+
new_func = WithAttr(new_func, tvm::attr::kGlobalSymbol, after->name_hint);
67+
}
68+
break;
69+
}
70+
}
71+
}
72+
73+
return new_func;
5774
});
5875

5976
TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable)

src/tir/transforms/replace_global_var.cc renamed to src/tir/transforms/replace_global_vars.cc

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919

2020
/*!
2121
*
22-
* \file src/tir/transforms/replace_global_var.cc
22+
* \file src/tir/transforms/replace_global_vars.cc
2323
*
2424
* \brief GlobalVar replacement across IR types
2525
*/
2626

27-
#include <tvm/ir/replace_global_var.h>
27+
#include <tvm/ir/replace_global_vars.h>
2828
#include <tvm/tir/function.h>
2929
#include <tvm/tir/stmt_functor.h>
3030

@@ -61,6 +61,22 @@ TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable)
6161
if (!new_body.same_as(func->body)) {
6262
func.CopyOnWrite()->body = new_body;
6363
}
64+
65+
// If the function is externally exposed, and is being replaced
66+
// by a GlobalVar with a new name, then the function's
67+
// kGlobalSymbol must be updated to match.
68+
if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
69+
auto name = opt.value();
70+
for (const auto& [before, after] : replacements) {
71+
if (before->name_hint == name) {
72+
if (after->name_hint != name) {
73+
func = WithAttr(func, tvm::attr::kGlobalSymbol, after->name_hint);
74+
}
75+
break;
76+
}
77+
}
78+
}
79+
6480
return func;
6581
});
6682

0 commit comments

Comments
 (0)