Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
*/

/*!
* \file tvm/ir/replace_global_var.h
* \file tvm/ir/replace_global_vars.h
*
* \brief A utility to replace GlobalVar instances across all TVM IR
* types in an IRMdoule.
*/
#ifndef TVM_IR_REPLACE_GLOBAL_VAR_H_
#define TVM_IR_REPLACE_GLOBAL_VAR_H_
#ifndef TVM_IR_REPLACE_GLOBAL_VARS_H_
#define TVM_IR_REPLACE_GLOBAL_VARS_H_

#include <tvm/ir/module.h>

Expand All @@ -41,7 +41,7 @@ namespace transform {
*
* \return The updated IRModule
*/
TVM_DLL IRModule ReplaceGlobalVar(IRModule mod, Map<GlobalVar, GlobalVar> replacements);
TVM_DLL IRModule ReplaceGlobalVars(IRModule mod, Map<GlobalVar, GlobalVar> replacements);

struct GlobalVarReplacer {
using FType = NodeFunctor<BaseFunc(const ObjectRef&, Map<GlobalVar, GlobalVar>)>;
Expand All @@ -54,4 +54,4 @@ struct GlobalVarReplacer {
} // namespace transform
} // namespace tvm

#endif // TVM_IR_REPLACE_GLOBAL_VAR_H_
#endif // TVM_IR_REPLACE_GLOBAL_VARS_H_
28 changes: 28 additions & 0 deletions python/tvm/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""IRModule that holds the functions and type definitions."""

from __future__ import annotations

from typing import Dict, Union
Expand Down Expand Up @@ -216,6 +217,33 @@ def get_global_vars(self):
"""
return _ffi_api.Module_GetGlobalVars(self)

def replace_global_vars(
self,
replacements: Dict[Union[str, _expr.GlobalVar], Union[str, _expr.GlobalVar]],
) -> "IRModule":
"""Replace GlobalVar instances within the module

Replace GlobalVars within the IRModule. Since the IRModule
may contain internal references to a GlobalVar, either in TIR
or in Relax, this method should be used whenever replacing or
renaming a GlobalVar.

Parameters
----------
replacements: Dict[Union[str, _expr.GlobalVar], Union[str, _expr.GlobalVar]]

A dictionary where each key is a GlobalVar to be replaced,
and the corresponding value is the GlobalVar with which to
replace it.

Returns
-------
IRModule
The updated module

"""
return _ffi_api.Module_ReplaceGlobalVars(self, replacements)

def get_global_type_vars(self):
"""Collect all global type vars defined in this module.

Expand Down
43 changes: 39 additions & 4 deletions src/ir/replace_global_var.cc → src/ir/replace_global_vars.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,22 @@
*/

/*!
* \file src/ir/replace_global_var.cc
* \file src/ir/replace_global_vars.cc
* \brief IRModule transform to replace GlobalVar instances across any IR type.
*/

#include <tvm/ir/replace_global_var.h>
#include <tvm/ir/replace_global_vars.h>

#include <vector>

namespace tvm {
namespace transform {

IRModule ReplaceGlobalVar(IRModule mod, Map<GlobalVar, GlobalVar> replacements) {
IRModule ReplaceGlobalVars(IRModule mod, Map<GlobalVar, GlobalVar> replacements) {
if (replacements.empty()) {
return mod;
}

std::vector<GlobalVar> to_remove;
IRModule updates;

Expand Down Expand Up @@ -57,7 +61,38 @@ IRModule ReplaceGlobalVar(IRModule mod, Map<GlobalVar, GlobalVar> replacements)
return mod;
}

TVM_REGISTER_GLOBAL("transform.ReplaceGlobalVar").set_body_typed(ReplaceGlobalVar);
TVM_REGISTER_GLOBAL("transform.ReplaceGlobalVars").set_body_typed(ReplaceGlobalVars);

IRModule ModuleReplaceGlobalVars(
IRModule mod, Map<Variant<String, GlobalVar>, Variant<String, GlobalVar>> replacements) {
Map<GlobalVar, GlobalVar> gvar_replacements;
for (const auto& [before, after] : replacements) {
GlobalVar gvar_before;
if (auto gvar = before.as<GlobalVar>()) {
gvar_before = gvar.value();
} else if (auto str = before.as<String>()) {
gvar_before = mod->GetGlobalVar(str.value());
} else {
LOG(FATAL) << "Variant<String,GlobalVar> must contain either String or GlobalVar";
}

GlobalVar gvar_after;
if (auto gvar = after.as<GlobalVar>()) {
gvar_after = gvar.value();
} else if (auto str = after.as<String>()) {
gvar_after = gvar_before;
gvar_after.CopyOnWrite()->name_hint = str.value();
} else {
LOG(FATAL) << "Variant<String,GlobalVar> must contain either String or GlobalVar";
}

gvar_replacements.Set(gvar_before, gvar_after);
}

return ReplaceGlobalVars(mod, gvar_replacements);
}

TVM_REGISTER_GLOBAL("ir.Module_ReplaceGlobalVars").set_body_typed(ModuleReplaceGlobalVars);

} // namespace transform
} // namespace tvm
4 changes: 2 additions & 2 deletions src/relax/transform/attach_global_symbol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
*/

#include <tvm/ir/module.h>
#include <tvm/ir/replace_global_var.h>
#include <tvm/ir/replace_global_vars.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/transform.h>
#include <tvm/tir/function.h>
Expand Down Expand Up @@ -72,7 +72,7 @@ Pass AttachGlobalSymbol() {
mod.CopyOnWrite()->Update(updates);

if (gvar_updates.size()) {
mod = tvm::transform::ReplaceGlobalVar(mod, gvar_updates);
mod = tvm::transform::ReplaceGlobalVars(mod, gvar_updates);
}
}
return mod;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@

/*!
*
* \file src/relax/transform/replace_global_var.cc
* \file src/relax/transform/replace_global_vars.cc
*
* \brief GlobalVar replacement across IR types
*/

#include <tvm/ir/analysis.h>
#include <tvm/ir/replace_global_var.h>
#include <tvm/ir/replace_global_vars.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/tir/expr_functor.h>
Expand Down Expand Up @@ -53,7 +53,24 @@ TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable)
.set_dispatch<relax::FunctionNode>([](const ObjectRef& func,
Map<GlobalVar, GlobalVar> replacements) -> BaseFunc {
Mutator mutator(replacements);
return Downcast<BaseFunc>(mutator(Downcast<Function>(func)));
auto new_func = Downcast<Function>(mutator(Downcast<Function>(func)));

// If the function is externally exposed, and is being replaced
// by a GlobalVar with a new name, then the function's
// kGlobalSymbol must be updated to match.
if (auto opt = new_func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
auto name = opt.value();
for (const auto& [before, after] : replacements) {
if (before->name_hint == name) {
if (after->name_hint != name) {
new_func = WithAttr(new_func, tvm::attr::kGlobalSymbol, after->name_hint);
}
break;
}
}
}

return new_func;
});

TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@

/*!
*
* \file src/tir/transforms/replace_global_var.cc
* \file src/tir/transforms/replace_global_vars.cc
*
* \brief GlobalVar replacement across IR types
*/

#include <tvm/ir/replace_global_var.h>
#include <tvm/ir/replace_global_vars.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>

Expand Down Expand Up @@ -61,6 +61,22 @@ TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable)
if (!new_body.same_as(func->body)) {
func.CopyOnWrite()->body = new_body;
}

// If the function is externally exposed, and is being replaced
// by a GlobalVar with a new name, then the function's
// kGlobalSymbol must be updated to match.
if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
auto name = opt.value();
for (const auto& [before, after] : replacements) {
if (before->name_hint == name) {
if (after->name_hint != name) {
func = WithAttr(func, tvm::attr::kGlobalSymbol, after->name_hint);
}
break;
}
}
}

return func;
});

Expand Down
Loading