Skip to content

Commit 3f615dc

Browse files
authored
[Bugfix][Relax] BlockBuilder may not assume unique input functions (#16805)
Prior to this commit, the implementation of `relax::BlockBuilder::AddFunction` implicitly assumed that the input `IRModule` does not contain duplicate copies of the same function. This commit updates the implementation, removing the reliance on this assumption. This commit resolves the error by tracking all `GlobalVar`s that map to the same function, rather than just one. A well-formed IRModule may contain duplicate function definitions. This is rare, as most functions can be disambiguated by the function attribute `tvm::attr::kGlobalSymbol`. However, private functions do not have this attribute, and a well-formed IRModule may contain multiple copies of the same function. The regression test added in this PR calls `BlockBuilder::UpdateFunction` and `BlockBuilder::AddFunction` in a specific order to reproduce this issue. In practice, this failure was sporadic, depending on the order in which a transformation pass visited functions in a module. This was first observed in `VMShapeLower`, with sporadic errors depending on the order of iteration over `mod->functions`.
1 parent b5fda2d commit 3f615dc

File tree

2 files changed

+95
-10
lines changed

2 files changed

+95
-10
lines changed

src/relax/ir/block_builder.cc

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
#include <memory>
3737
#include <unordered_map>
38+
#include <unordered_set>
3839
#include <vector>
3940

4041
#include "../../node/ndarray_hash_equal.h"
@@ -102,32 +103,49 @@ class BlockBuilderImpl : public BlockBuilderNode {
102103

103104
context_mod_->Add(gvar, func);
104105

105-
ctx_func_dedup_map_->emplace(func, gvar);
106+
(*ctx_func_dedup_map_)[func].insert(gvar);
106107
return gvar;
107108
} else {
108-
return it->second;
109+
ICHECK(it->second.size()) << "Values contained in de-duplication map must be non-empty sets, "
110+
<< "but found an empty set for function of type "
111+
<< func->GetTypeKey();
112+
// To provide deterministic results, return the GlobalVar that
113+
// comes first in lexicographic order.
114+
return *std::min_element(
115+
it->second.begin(), it->second.end(),
116+
[](const GlobalVar& a, const GlobalVar& b) { return a->name_hint < b->name_hint; });
109117
}
110118
}
111119

112120
void UpdateFunction(const GlobalVar& gv, BaseFunc function) final {
113121
context_mod_.CopyOnWrite();
114122

115-
// invalidate old dedup map
123+
// Remove function from the de-duplication map.
116124
if (ctx_func_dedup_map_ != nullptr) {
117125
auto it = context_mod_->functions.find(gv);
118126
if (it != context_mod_->functions.end()) {
119127
BaseFunc old_func = (*it).second;
120128
auto ptr = ctx_func_dedup_map_->find(old_func);
121-
ICHECK(ptr != ctx_func_dedup_map_->end());
122-
ctx_func_dedup_map_->erase(ptr);
129+
ICHECK(ptr != ctx_func_dedup_map_->end())
130+
<< "BlockBuilder::UpdateFunction is updating " << gv
131+
<< ", which appears in the BlockBuilder's context_mod_, "
132+
<< "but does not appear in the de-duplication map";
133+
ICHECK(ptr->second.count(gv))
134+
<< "BlockBuilder::UpdateFunction is updating " << gv
135+
<< ", but the de-duplication map for the previous value of this function "
136+
<< "does not include " << gv;
137+
ptr->second.erase(gv);
138+
if (ptr->second.empty()) {
139+
ctx_func_dedup_map_->erase(ptr);
140+
}
123141
}
124142
}
125143

126144
context_mod_->Update(gv, function);
127145

128146
// add new dedup map item.
129147
if (ctx_func_dedup_map_ != nullptr) {
130-
ctx_func_dedup_map_->emplace(function, gv);
148+
(*ctx_func_dedup_map_)[function].insert(gv);
131149
}
132150
}
133151

@@ -399,7 +417,8 @@ class BlockBuilderImpl : public BlockBuilderNode {
399417
* We use a custom hash to avoid hashing constants that may be bound to each BaseFunc.
400418
*/
401419
std::unique_ptr<
402-
std::unordered_map<BaseFunc, GlobalVar, StructuralHashIgnoreNDarray, StructuralEqual>>
420+
std::unordered_map<BaseFunc, std::unordered_set<GlobalVar, ObjectPtrHash, ObjectPtrEqual>,
421+
StructuralHashIgnoreNDarray, StructuralEqual>>
403422
ctx_func_dedup_map_ = nullptr;
404423

405424
/*!
@@ -408,11 +427,12 @@ class BlockBuilderImpl : public BlockBuilderNode {
408427
void LazyInitCtxFuncDedupMap() {
409428
if (ctx_func_dedup_map_ != nullptr) return;
410429
ctx_func_dedup_map_ = std::make_unique<
411-
std::unordered_map<BaseFunc, GlobalVar, StructuralHashIgnoreNDarray, StructuralEqual>>();
430+
std::unordered_map<BaseFunc, std::unordered_set<GlobalVar, ObjectPtrHash, ObjectPtrEqual>,
431+
StructuralHashIgnoreNDarray, StructuralEqual>>();
412432
for (const auto& kv : context_mod_->functions) {
413433
const GlobalVar gv = kv.first;
414434
const BaseFunc func = kv.second;
415-
ctx_func_dedup_map_->emplace(func, gv);
435+
(*ctx_func_dedup_map_)[func].insert(gv);
416436
}
417437
}
418438

tests/python/relax/test_blockbuilder_core.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from tvm import relax as rx, relay
2525
from tvm.ir.base import assert_structural_equal
2626
from tvm.relax import ExternFunc
27-
from tvm.script import relax as R, tir as T
27+
from tvm.script import ir as I, relax as R, tir as T
2828
from tvm.tir.function import PrimFunc
2929

3030

@@ -925,5 +925,70 @@ def test_error_when_unwrapping_dataflowvar():
925925
bb.emit_func_output(out)
926926

927927

928+
def test_deduplication_when_input_contains_duplicates():
929+
"""De-duplication of IRModules
930+
931+
A well-formed IRModule may contain duplicate function definitions.
932+
This is rare, as most functions can be disambiguated by the
933+
function attribute `tvm::attr::kGlobalSymbol`. However, private
934+
functions do not have this attribute, and a well-formed IRModule
935+
may contain multiple copies of the same function.
936+
937+
This is a regression test. Previous implementation de-duplicated
938+
using a `Dict[Function, GlobalVar]`, which has the failure mode
939+
shown below. This was resolved by de-duplicating using a
940+
`Dict[Function, Set[GlobalVar]]` instead.
941+
942+
"""
943+
944+
@I.ir_module
945+
class Module:
946+
@R.function
947+
def main(A: R.Tensor):
948+
B = Module.subroutine_a(A)
949+
C = Module.subroutine_b(B)
950+
return C
951+
952+
@R.function(private=True)
953+
def subroutine_a(arg: R.Tensor) -> R.Tensor:
954+
return R.add(arg, arg)
955+
956+
@R.function(private=True)
957+
def subroutine_b(arg: R.Tensor) -> R.Tensor:
958+
return R.add(arg, arg)
959+
960+
@R.function(private=True)
961+
def subroutine_c(arg: R.Tensor) -> R.Tensor:
962+
return R.multiply(arg, arg)
963+
964+
# This test case is only valid when the two subroutines are
965+
# structurally equal, and therefore allowed to be de-duplicated by
966+
# the BlockBuilder.
967+
tvm.ir.assert_structural_equal(Module["subroutine_a"], Module["subroutine_b"])
968+
969+
gvar_a = Module.get_global_var("subroutine_a")
970+
gvar_b = Module.get_global_var("subroutine_b")
971+
subroutine_c = Module["subroutine_c"]
972+
973+
bb = rx.BlockBuilder(Module)
974+
975+
# Add a function to the module. What we add doesn't matter, as
976+
# this is only to initialize the de-duplication map.
977+
bb.add_func(subroutine_c, "_unused")
978+
# With the previous `Dict[Function, GlobalVar]` implementation, the
979+
# deduplication table mapped `subroutine_ab` to either `gvar_a` or `gvar_b`.
980+
981+
# Update gvar_a.
982+
bb.update_func(gvar_a, subroutine_c)
983+
# With the previous implementation, the deduplication map no
984+
# longer had an entry for `subroutine_ab` after this update.
985+
986+
# Update gvar_b. With the previous `Dict[Function, GlobalVar]`
987+
# implementation, the deduplication map was present (because we
988+
# called `add_func`) but no longer contained an entry for
989+
# `subroutine_ab` (because it was just removed), so this raised an error.
990+
bb.update_func(gvar_b, subroutine_c)
991+
992+
928993
if __name__ == "__main__":
929994
tvm.testing.main()

0 commit comments

Comments
 (0)