Skip to content

Commit 94e83f2

Browse files
authored
[Unity][VM] Recursively visit match bindings in VMShapeLowerMutator (#16583)
Prior to this commit, the `MatchBinding` visitor in `VMShapeLowerMutator`. If the RHS of the `MatchBinding` is a `ShapeExpr` that uses symbolic variables, that RHS must be visited in order to have the symbolic variable updated.
1 parent 9ed9f7a commit 94e83f2

File tree

2 files changed

+79
-1
lines changed

2 files changed

+79
-1
lines changed

src/relax/backend/vm/vm_shape_lower.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ class VMShapeLowerMutator
419419

420420
// These checks are emitted as extra, in codegen
421421
// match-cast is simply ignored and treated as a normal binding.
422-
builder_->EmitNormalized(GetRef<MatchCast>(binding));
422+
ExprMutator::VisitBinding_(binding);
423423
}
424424

425425
// Do not override shape in struct info fields

tests/python/relax/test_backend_transform_shape_lower.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,5 +731,83 @@ def main(
731731
assert_structural_equal(after, expected)
732732

733733

734+
def test_update_symbolic_vars_in_match_cast_rhs():
735+
"""Symbolic variables may be used on the RHS of match_cast"""
736+
737+
@I.ir_module
738+
class Before:
739+
@R.function
740+
def main(
741+
arg_prim_value: R.Prim(value="n"),
742+
):
743+
R.func_attr({"relax.force_pure": 1})
744+
n = T.int64()
745+
shape = R.shape([n])
746+
m = T.int64()
747+
_ = R.match_cast(shape, R.Shape([m]))
748+
return R.prim_value(m)
749+
750+
@I.ir_module
751+
class Expected:
752+
@R.function
753+
def main(arg_prim_value: R.Prim(value="n")) -> R.Prim("int64"):
754+
R.func_attr({"relax.force_pure": 1})
755+
n = T.int64()
756+
757+
shape_heap = R.call_builtin_with_ctx(
758+
"vm.builtin.alloc_shape_heap",
759+
[2],
760+
sinfo_args=(R.Tensor(dtype="int64", ndim=1),),
761+
)
762+
_ = R.call_packed(
763+
"vm.builtin.check_prim_value_info",
764+
arg_prim_value,
765+
R.dtype("int64"),
766+
"",
767+
sinfo_args=[R.Tuple],
768+
)
769+
_ = R.call_packed(
770+
"vm.builtin.match_prim_value",
771+
arg_prim_value,
772+
shape_heap,
773+
MatchShapeCode.STORE_TO_HEAP,
774+
0,
775+
"",
776+
sinfo_args=[R.Tuple],
777+
)
778+
shape = R.call_packed(
779+
"vm.builtin.make_shape",
780+
shape_heap,
781+
1,
782+
MakeShapeCode.LOAD_SHAPE,
783+
0,
784+
sinfo_args=[R.Shape(ndim=1)],
785+
)
786+
_ = R.call_packed(
787+
"vm.builtin.match_shape",
788+
shape,
789+
shape_heap,
790+
1,
791+
MatchShapeCode.STORE_TO_HEAP,
792+
1,
793+
"",
794+
sinfo_args=[R.Tuple],
795+
)
796+
797+
m = T.int64()
798+
_ = R.match_cast(shape, R.Shape([m]))
799+
gv = R.call_packed(
800+
"vm.builtin.make_prim_value",
801+
shape_heap,
802+
MakeShapeCode.LOAD_SHAPE,
803+
1,
804+
sinfo_args=[R.Prim(value=m)],
805+
)
806+
return gv
807+
808+
After = relax.transform.VMShapeLower(emit_err_ctx=False)(Before)
809+
assert_structural_equal(Expected, After)
810+
811+
734812
if __name__ == "__main__":
735813
tvm.testing.main()

0 commit comments

Comments
 (0)