Skip to content

Commit faa6628

Browse files
authored
[Relax] Additional unit tests for RemoveUnusedParameters (#16574)
* [Relax] Additional unit tests for RemoveUnusedParameters Verifying behavior for subroutines that receive `R.Prim` or `R.Shape` parameters, if the symbolic variables defined by those parameters are already defined by another parameter. * Typo fix
1 parent 33a6f75 commit faa6628

File tree

1 file changed

+106
-3
lines changed

1 file changed

+106
-3
lines changed

tests/python/relax/test_transform_remove_unused_parameters.py

Lines changed: 106 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,14 @@ class BaseCompare(tvm.testing.CompareBeforeAfter):
2424
transform = tvm.relax.transform.RemoveUnusedParameters()
2525

2626

27-
class TestSimple(BaseCompare):
27+
class TestRemoveUnusedRelaxParameter(BaseCompare):
28+
"""A relax parameter may be removed
29+
30+
This is only allowed for internal function calls, where all
31+
callsites can be updated. For externally-exposed functions, the
32+
signature may not be modified.
33+
"""
34+
2835
@I.ir_module
2936
class Before:
3037
@R.function
@@ -46,7 +53,15 @@ def func(A: R.Tensor) -> R.Tensor:
4653
return A
4754

4855

49-
class TestSymbolicVariables(BaseCompare):
56+
class TestReplaceSymbolicVariables(BaseCompare):
57+
"""If a parameter is only required for its symbolic variables, provide them directly
58+
59+
The relax parameter `A` isn't used by the subroutine. However,
60+
its shape defines the symbolic variables `m` and `n`. When
61+
removing the `R.Tensor` argument, we may need to provide
62+
additional parameters to define the symbolic variables.
63+
"""
64+
5065
@I.ir_module
5166
class Before:
5267
@R.function
@@ -78,7 +93,12 @@ def func(
7893

7994

8095
class TestNoExtraSymbolicVariables(BaseCompare):
81-
"""Don't add symbolic variables if they can be inferred."""
96+
"""Don't add symbolic variables if they can be inferred.
97+
98+
Even though some cases require adding new parameters to provide
99+
symbolic variables, not every symbolic variable requires a
100+
distinct parameter.
101+
"""
82102

83103
@I.ir_module
84104
class Before:
@@ -97,5 +117,88 @@ def func(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
97117
Expected = Before
98118

99119

120+
class TestRemoveExtraPrimVariables(BaseCompare):
121+
"""Remove parameters that only serve to define existing symbolic variables
122+
123+
If a `R.Prim` parameter provies a definition of a symbolic
124+
variable, but that symbolic variable can be determined from a
125+
different parameter, then the `R.Prim` parameter can be removed.
126+
"""
127+
128+
@I.ir_module
129+
class Before:
130+
@R.function
131+
def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
132+
m = T.int64()
133+
n = T.int64()
134+
return Before.func(A, R.prim_value(m), R.prim_value(n))
135+
136+
@R.function(private=True)
137+
def func(
138+
A: R.Tensor(["m", "n"], "float32"), _m: R.Prim(value="m"), _n: R.Prim(value="n")
139+
) -> R.Tensor(["m", "n"], "float32"):
140+
m = T.int64()
141+
n = T.int64()
142+
zeros = R.zeros(R.shape([m, n]), dtype="float32")
143+
out = R.add(A, zeros)
144+
return out
145+
146+
@I.ir_module
147+
class Expected:
148+
@R.function
149+
def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
150+
return Expected.func(A)
151+
152+
@R.function(private=True)
153+
def func(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
154+
m = T.int64()
155+
n = T.int64()
156+
zeros = R.zeros(R.shape([m, n]), dtype="float32")
157+
out = R.add(A, zeros)
158+
return out
159+
160+
161+
class TestRemoveExtraShapeVariables(BaseCompare):
162+
"""Remove parameters that only serve to define existing symbolic variables
163+
164+
If a `R.Shape` parameter provides a definition of a symbolic
165+
variable, but that symbolic variable can be determined from a
166+
different parameter, then the `R.Shape` parameter can be removed.
167+
"""
168+
169+
@I.ir_module
170+
class Before:
171+
@R.function
172+
def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
173+
m = T.int64()
174+
n = T.int64()
175+
return Before.func(A, R.shape([m, n]))
176+
177+
@R.function(private=True)
178+
def func(
179+
A: R.Tensor(["m", "n"], "float32"),
180+
_: R.Shape(["m", "n"]),
181+
) -> R.Tensor(["m", "n"], "float32"):
182+
m = T.int64()
183+
n = T.int64()
184+
zeros = R.zeros(R.shape([m, n]), dtype="float32")
185+
out = R.add(A, zeros)
186+
return out
187+
188+
@I.ir_module
189+
class Expected:
190+
@R.function
191+
def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
192+
return Expected.func(A)
193+
194+
@R.function(private=True)
195+
def func(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
196+
m = T.int64()
197+
n = T.int64()
198+
zeros = R.zeros(R.shape([m, n]), dtype="float32")
199+
out = R.add(A, zeros)
200+
return out
201+
202+
100203
if __name__ == "__main__":
101204
tvm.testing.main()

0 commit comments

Comments
 (0)