@@ -24,7 +24,14 @@ class BaseCompare(tvm.testing.CompareBeforeAfter):
2424 transform = tvm .relax .transform .RemoveUnusedParameters ()
2525
2626
27- class TestSimple (BaseCompare ):
27+ class TestRemoveUnusedRelaxParameter (BaseCompare ):
28+ """A relax parameter may be removed
29+
30+ This is only allowed for internal function calls, where all
31+ callsites can be updated. For externally-exposed functions, the
32+ signature may not be modified.
33+ """
34+
2835 @I .ir_module
2936 class Before :
3037 @R .function
@@ -46,7 +53,15 @@ def func(A: R.Tensor) -> R.Tensor:
4653 return A
4754
4855
49- class TestSymbolicVariables (BaseCompare ):
56+ class TestReplaceSymbolicVariables (BaseCompare ):
57+ """If a parameter is only required for its symbolic variables, provide them directly
58+
59+ The relax parameter `A` isn't used by the subroutine. However,
60+ its shape defines the symbolic variables `m` and `n`. When
61+ removing the `R.Tensor` argument, we may need to provide
62+ additional parameters to define the symbolic variables.
63+ """
64+
5065 @I .ir_module
5166 class Before :
5267 @R .function
@@ -78,7 +93,12 @@ def func(
7893
7994
8095class TestNoExtraSymbolicVariables (BaseCompare ):
81- """Don't add symbolic variables if they can be inferred."""
96+ """Don't add symbolic variables if they can be inferred.
97+
98+ Even though some cases require adding new parameters to provide
99+ symbolic variables, not every symbolic variable requires a
100+ distinct parameter.
101+ """
82102
83103 @I .ir_module
84104 class Before :
@@ -97,5 +117,88 @@ def func(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
97117 Expected = Before
98118
99119
120+ class TestRemoveExtraPrimVariables (BaseCompare ):
121+ """Remove parameters that only serve to define existing symbolic variables
122+
123+ If a `R.Prim` parameter provies a definition of a symbolic
124+ variable, but that symbolic variable can be determined from a
125+ different parameter, then the `R.Prim` parameter can be removed.
126+ """
127+
128+ @I .ir_module
129+ class Before :
130+ @R .function
131+ def main (A : R .Tensor (["m" , "n" ], "float32" )) -> R .Tensor (["m" , "n" ], "float32" ):
132+ m = T .int64 ()
133+ n = T .int64 ()
134+ return Before .func (A , R .prim_value (m ), R .prim_value (n ))
135+
136+ @R .function (private = True )
137+ def func (
138+ A : R .Tensor (["m" , "n" ], "float32" ), _m : R .Prim (value = "m" ), _n : R .Prim (value = "n" )
139+ ) -> R .Tensor (["m" , "n" ], "float32" ):
140+ m = T .int64 ()
141+ n = T .int64 ()
142+ zeros = R .zeros (R .shape ([m , n ]), dtype = "float32" )
143+ out = R .add (A , zeros )
144+ return out
145+
146+ @I .ir_module
147+ class Expected :
148+ @R .function
149+ def main (A : R .Tensor (["m" , "n" ], "float32" )) -> R .Tensor (["m" , "n" ], "float32" ):
150+ return Expected .func (A )
151+
152+ @R .function (private = True )
153+ def func (A : R .Tensor (["m" , "n" ], "float32" )) -> R .Tensor (["m" , "n" ], "float32" ):
154+ m = T .int64 ()
155+ n = T .int64 ()
156+ zeros = R .zeros (R .shape ([m , n ]), dtype = "float32" )
157+ out = R .add (A , zeros )
158+ return out
159+
160+
161+ class TestRemoveExtraShapeVariables (BaseCompare ):
162+ """Remove parameters that only serve to define existing symbolic variables
163+
164+ If a `R.Shape` parameter provides a definition of a symbolic
165+ variable, but that symbolic variable can be determined from a
166+ different parameter, then the `R.Shape` parameter can be removed.
167+ """
168+
169+ @I .ir_module
170+ class Before :
171+ @R .function
172+ def main (A : R .Tensor (["m" , "n" ], "float32" )) -> R .Tensor (["m" , "n" ], "float32" ):
173+ m = T .int64 ()
174+ n = T .int64 ()
175+ return Before .func (A , R .shape ([m , n ]))
176+
177+ @R .function (private = True )
178+ def func (
179+ A : R .Tensor (["m" , "n" ], "float32" ),
180+ _ : R .Shape (["m" , "n" ]),
181+ ) -> R .Tensor (["m" , "n" ], "float32" ):
182+ m = T .int64 ()
183+ n = T .int64 ()
184+ zeros = R .zeros (R .shape ([m , n ]), dtype = "float32" )
185+ out = R .add (A , zeros )
186+ return out
187+
188+ @I .ir_module
189+ class Expected :
190+ @R .function
191+ def main (A : R .Tensor (["m" , "n" ], "float32" )) -> R .Tensor (["m" , "n" ], "float32" ):
192+ return Expected .func (A )
193+
194+ @R .function (private = True )
195+ def func (A : R .Tensor (["m" , "n" ], "float32" )) -> R .Tensor (["m" , "n" ], "float32" ):
196+ m = T .int64 ()
197+ n = T .int64 ()
198+ zeros = R .zeros (R .shape ([m , n ]), dtype = "float32" )
199+ out = R .add (A , zeros )
200+ return out
201+
202+
100203if __name__ == "__main__" :
101204 tvm .testing .main ()
0 commit comments