Skip to content

Commit 11f3c44

Browse files
committed
[Transform] Remove R.Object parameters after LazyTranformParams
Prior to this commit, the output of `relax.transform.LazyTransformParams` would include all parameters that are not `R.Tensor`, in case they defined symbolic variables. However, this added too many unnecessary parameters, such as `R.Object` which cannot define symbolic variables. This commit updates `relax.transform.LazyTransformParams` to only retain `R.Prim` and `R.Shape` parameters, which can define symbolic variables.
1 parent cae1af6 commit 11f3c44

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

python/tvm/relax/transform/lazy_transform_params.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,9 @@ def unpack_sinfo(sinfo):
216216
# direct iterate over the struct info annotation
217217
for param in func.params[num_input:]:
218218
for sinfo in unpack_sinfo(param.struct_info):
219-
if not isinstance(sinfo, relax.TensorStructInfo):
219+
if isinstance(sinfo, relax.PrimStructInfo) or isinstance(
220+
sinfo, relax.ShapeStructInfo
221+
):
220222
params.append(relax.Var("symbolic_var_holder", sinfo))
221223

222224
return relax.Function(

tests/python/relax/test_transform_lazy_transform_params.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -804,5 +804,25 @@ def transform_params(relax_rank: R.Prim(value="rank")):
804804
tvm.ir.assert_structural_equal(After, Expected)
805805

806806

807+
def test_params_without_tuple_with_symbolic_var():
808+
@I.ir_module
809+
class Before:
810+
@R.function
811+
def transform_params(A: R.Object):
812+
return (A,)
813+
814+
@I.ir_module
815+
class Expected:
816+
@R.function(pure=False)
817+
def transform_params():
818+
A = R.call_packed("get_item", R.prim_value(0), sinfo_args=[R.Object])
819+
A = R.match_cast(A, R.Object)
820+
821+
return (A,)
822+
823+
After = LazyTransformParams(fset_item=None)(Before)
824+
tvm.ir.assert_structural_equal(After, Expected)
825+
826+
807827
if __name__ == "__main__":
808828
tvm.testing.main()

0 commit comments

Comments
 (0)