Skip to content

Commit fac9520

Browse files
authored
[Unity][Transform] Raise error in FuseOpsByPattern for SSA violation (#16421)
Internally, `FuseOpsByPattern` makes a mapping from relax variables to the fused group containing that variable. If the input module violates SSA, this map may be ill-formed. While not strictly necessary for FuseOps to handle ill-formed inputs, checking it at this level provides better error handling than propagating it to downstream passes. This commit checks for ill-formed inputs that would produce invalid fused outputs and raises an error.
1 parent bde28ae commit fac9520

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

src/relax/transform/fuse_ops.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1286,7 +1286,14 @@ IRModule FuseOpsByPattern(const tvm::Array<transform::FusionPattern>& patterns,
12861286
pattern->annotation_patterns,
12871287
pattern->check.value_or(nullptr), entry.second,
12881288
&arena, pattern->attrs_getter.value_or(nullptr));
1289-
group_map.insert(map.begin(), map.end());
1289+
for (const auto& [key, value] : map) {
1290+
CHECK(!group_map.count(key))
1291+
<< "ValueError: "
1292+
<< "IRModule is invalid. "
1293+
<< "The object " << GetRef<ObjectRef>(key) << " appears in multiple partitions, "
1294+
<< "which can occur when the IRModule was not single-site assignment";
1295+
group_map.insert({key, value});
1296+
}
12901297
}
12911298
mod = MakeGroupedFunctions(mod, group_map, /*lift_constants*/ !bind_constants);
12921299
}

tests/python/relax/test_transform_fuse_ops_by_pattern.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,5 +1109,26 @@ def test_multple_runs():
11091109
)
11101110

11111111

1112+
@pytest.mark.skip_well_formed_check_before_transform
1113+
def test_error_on_repeated_variable_definitions():
1114+
"""Raise error for SSA violations
1115+
1116+
Internally, `FuseOpsByPattern` makes a mapping from relax
1117+
variables to the fused group containing that variable. If the
1118+
input module violates SSA, this map may be ill-formed.
1119+
1120+
While not strictly necessary for FuseOps to handle ill-formed
1121+
inputs, checking it at this level provides better error handling
1122+
than propagating it to downstream passes.
1123+
"""
1124+
mod = Conv2dReLU.clone()
1125+
mod["copy"] = mod["main"].with_attr("global_symbol", "copy")
1126+
1127+
patterns = [("dnnl.conv2d_relu", conv2d_relu_pat)]
1128+
1129+
with pytest.raises(ValueError):
1130+
relax.transform.FuseOpsByPattern(patterns)(mod)
1131+
1132+
11121133
if __name__ == "__main__":
11131134
pytest.main([__file__])

0 commit comments

Comments
 (0)