Skip to content

Commit 7bd738a

Browse files
authored
[Relax] Implement Rewriter class for pattern-rewrite (#17149)
* [TVMScript][Bugfix] Normalize relax::If with function's TIR var Prior to this commit, the branches of `relax::If` were normalized using `EraseToWellDefinedInScope`, using a fresh variable scope. While this had the intended behavior of preventing variables defined in a single branch from being usable outside of the conditional, it also caused the conditional's branches to treat function-scope symbolic variables as if they were undefined. This commit updates the `tvm::relax::Normalizer` so that `relax::If` is normalized within an inherited scope. This preserves the previous behavior for symbolic variables defined within a branch, but allows shapes within a branch to use symbolic variables defined outside of the branch. * [Relax] Canonicalize known symbolic shapes in Relax expressions Prior to this commit, known constants in Relax functions would be inlined by the `CanonicalizeBindings` pass, but only if they appeared as Relax expressions (e.g. `R.const` or `R.prim_value`). Known constants that appeared as TIR variables (e.g. symbolic shapes) would be kept as dynamic parameters, even if they were known at compile time. This commit updates the `CanonicalizeBindings` pass to identify known values of symbolic shapes, and to use these known values in shape expressions. * [Relax][Refactor] Reorganize pattern-matching A follow-up to #16730. Now that the implementations for `rewrite_call` and `rewrite_bindings` are in separate classes, they can be further split out into separate files. * [Relax][Refactor] Implement Rewriter class for pattern-rewrite Prior to this commit, the pattern to be matched and the rewrite to be performed were provided as separate arguments. This commit introduces a new class `ExprRewriter`, which contains both parts. This abstraction will make it easier to combine multiple different rewrite rules, applying them in a single pass. * lint fixes * Remove unnecessary change which broke a unit test * lint fix for import order * Add docstrings * lint fix * Lint fix * lint fixes * lint fix * Update based on review comments * Add test case for matching against arbitrary dtype * Fix breakage in unit tests One unit test that had been relying on invalid shape propagation. Another unit test that required constructed an ill-formed output to test against. * Updated base class name from ExprRewriter to PatternMatchingRewriter * lint fix
1 parent cc8afdb commit 7bd738a

File tree

24 files changed

+4142
-751
lines changed

24 files changed

+4142
-751
lines changed

include/tvm/relax/block_builder.h

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,16 +133,47 @@ class BlockBuilderNode : public Object {
133133
* \brief Begin a new scope, with optional parameters that
134134
* are visible within the scope.
135135
*
136+
* Symbolic variables from the parent scope are not available.
137+
*
136138
* \param params Parameters that are visible within the scope.
137139
*
138140
* \note This function should be called when new scope is introduced
139-
* (function, seq) to properly track the variable availability
140-
* and help the best effort deduction.
141+
* (e.g. function bodies) to properly track the variable
142+
* availability and help the best effort deduction.
141143
*
142144
* \sa EndScope
143145
*/
144146
virtual void BeginScope(Optional<Array<Var>> params) = 0;
145147

148+
/*!
149+
* \brief Begin a new scope, which inherits visible parameters from
150+
* its parent scope.
151+
*
152+
* Symbolic variables from the parent scope are available.
153+
*
154+
* \note This function should be called when an inner scope is
155+
* introduced (e.g. conditional branches) to properly track
156+
* the variable availability and help the best effort
157+
* deduction.
158+
*
159+
* \sa EndScope
160+
*/
161+
virtual void BeginInnerScope() = 0;
162+
163+
/*!
164+
* \brief Append a definition to the current scope.
165+
*
166+
* \param var A variable within the current scope.
167+
*
168+
* \note This function should be called when a new variable is
169+
* defined that may impact struct inference (e.g. MatchCast)
170+
* to properly track the variable availability and help the
171+
* best effort deduction.
172+
*
173+
* \sa EndScope
174+
*/
175+
virtual void AddDefinitionToScope(Var var) = 0;
176+
146177
/*! \brief End the previously defined scope. */
147178
virtual void EndScope() = 0;
148179

include/tvm/relax/expr_functor.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,10 @@ class ExprMutator : public ExprMutatorBase {
494494
void ReEmitBinding(const VarBindingNode* binding, Expr new_value);
495495

496496
/*!
497-
* \brief Rewrite the expr with a new scope, used in a Function's body and the branches of If.
497+
* \brief Rewrite the expr with a new scope, used in a Function's body.
498+
*
499+
* Visit an expression that may neither access variables from the
500+
* current scope, nor may export definitions into the current scope.
498501
*
499502
* \param body_expr The body to be visited.
500503
* \param params Optional parameters that are visible within the scope.
@@ -504,6 +507,22 @@ class ExprMutator : public ExprMutatorBase {
504507
*/
505508
Expr VisitWithNewScope(const Expr& body_expr, Optional<Array<Var>> params = NullOpt);
506509

510+
/*!
511+
* \brief Rewrite the expr with a new scope, used in the branches of If.
512+
*
513+
* Visit an expression that may access variables from the current
514+
* scope, but may not export definitions into the current scope.
515+
*
516+
* \param body_expr The body to be visited.
517+
*
518+
* \return The expr after visiting.
519+
*
520+
* \sa VisitWithNewScope
521+
*
522+
* \note The body_expr must be an SeqExpr in the normal form.
523+
*/
524+
Expr VisitWithInnerScope(const Expr& body_expr);
525+
507526
/*!
508527
* \brief Look up the value bound to a variable.
509528
* \param var The var to be looked up.

include/tvm/script/ir_builder/relax/frame.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ class FunctionFrameNode : public SeqExprFrameNode {
122122
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionFrameNode, SeqExprFrameNode);
123123

124124
public:
125+
void EnterWithScope() final;
125126
void ExitWithScope() final;
126127
};
127128

python/tvm/relax/dpl/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,10 @@
1919

2020
from .pattern import *
2121
from .context import *
22-
from .rewrite import rewrite_call, rewrite_bindings
22+
from .rewrite import (
23+
rewrite_call,
24+
rewrite_bindings,
25+
PatternMatchingRewriter,
26+
ExprPatternRewriter,
27+
OrRewriter,
28+
)

python/tvm/relax/dpl/rewrite.py

Lines changed: 183 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,196 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""APIs for pattern-based rewriting."""
18-
from typing import Dict, Callable
18+
19+
from typing import Dict, Callable, Union
20+
21+
from tvm.ir import IRModule
22+
from tvm.runtime import Object
23+
from tvm._ffi import register_object
24+
1925
from .pattern import DFPattern
2026
from .context import PatternContext
21-
2227
from ..expr import Expr, Function, Var
2328
from . import _ffi as ffi
2429

2530

31+
@register_object("relax.dpl.PatternMatchingRewriter")
32+
class PatternMatchingRewriter(Object):
33+
"""A pattern-matching rewriter for Relax"""
34+
35+
@staticmethod
36+
def from_pattern(
37+
pattern: DFPattern,
38+
func: Callable[[Expr, Dict[DFPattern, Expr]], Expr],
39+
) -> "PatternMatchingRewriter":
40+
"""Construct from a pattern and rewriter-function
41+
42+
The replacements performed by the rewriter will be equivalent
43+
to using the `pattern` and `func` as arguments to
44+
`rewrite_call`.
45+
46+
Parameters
47+
----------
48+
pattern: DFPattern
49+
50+
The pattern to be matched against.
51+
52+
func: Callable[[Expr, Dict[DFPattern, Expr]], Expr]
53+
54+
A function that returns the rewritten expression. See
55+
`rewrite_call` for details and examples.
56+
57+
58+
Returns
59+
-------
60+
rewriter_obj: PatternMatchingRewriter
61+
62+
The rewriter object
63+
64+
"""
65+
return ffi.PatternMatchingRewriterFromPattern(
66+
pattern,
67+
func,
68+
) # type: ignore
69+
70+
@staticmethod
71+
def from_module(mod: IRModule) -> "PatternMatchingRewriter":
72+
"""Construct a rewriter from an IRModule
73+
74+
The IRModule must have two publicly-exposed functions,
75+
`pattern` and `replacement`, where `pattern` and `replacement`
76+
have the same function signature, as shown in the example
77+
below.
78+
79+
.. code-block:: python
80+
81+
@I.ir_module
82+
class RewriteAddIntoMultiply:
83+
@R.function
84+
def pattern(A: R.Tensor):
85+
B = A + A
86+
return B
87+
88+
@R.function
89+
def replacement(A: R.Tensor):
90+
B = A * 2
91+
return B
92+
93+
rewriter = PatternMatchingRewriter.from_module(RewriteAddIntoMultiply)
94+
rewritten_ir_module = rewriter(ir_module)
95+
96+
To support the common case of defining an IRModule with
97+
TVMScript, then immediately turning it into a rewriter, the
98+
`@R.rewriter` annotation can be used.
99+
100+
.. code-block:: python
101+
102+
@R.rewriter
103+
class RewriteAddIntoMultiply:
104+
@R.function
105+
def pattern(A: R.Tensor):
106+
B = A + A
107+
return B
108+
109+
@R.function
110+
def replacement(A: R.Tensor):
111+
B = A * 2
112+
return B
113+
114+
rewritten_ir_module = RewriteAddIntoMultiply(ir_module)
115+
116+
Parameters
117+
----------
118+
mod: IRModule
119+
120+
A module with `pattern` and `replacement` functions,
121+
defining a rewrite rule.
122+
123+
124+
Returns
125+
-------
126+
rewriter_obj: PatternMatchingRewriter
127+
128+
The rewriter object
129+
130+
"""
131+
return ffi.PatternMatchingRewriterFromModule(mod) # type: ignore
132+
133+
def __call__(self, obj: Union[Expr, IRModule]) -> Union[Expr, IRModule]:
134+
"""Apply the rewriter
135+
136+
Parameters
137+
----------
138+
obj: Union[Expr, IRModule])
139+
140+
The object to be rewritten. May be applied to either a
141+
relax expression, or an IRModule.
142+
143+
Returns
144+
-------
145+
updated: Union[Expr, IRModule]
146+
147+
The rewritten object
148+
149+
"""
150+
return ffi.PatternMatchingRewriterApply(self, obj)
151+
152+
def __or__(self, other: "PatternMatchingRewriter") -> "PatternMatchingRewriter":
153+
"""Compose two rewriters
154+
155+
Composing two rewrite rules together allows them to be applied
156+
in a single Relax-level transformation.
157+
158+
Parameters
159+
----------
160+
other: PatternMatchingRewriter
161+
162+
Another rewrite rule
163+
164+
Returns
165+
-------
166+
PatternMatchingRewriter
167+
168+
A rewriter that will apply either rewrite pattern
169+
170+
"""
171+
return OrRewriter(self, other)
172+
173+
174+
@register_object("relax.dpl.ExprPatternRewriter")
175+
class ExprPatternRewriter(PatternMatchingRewriter):
176+
def __init__(self, pattern, func):
177+
self.__init_handle_by_constructor__(
178+
ffi.PatternRewriter,
179+
pattern,
180+
func,
181+
) # type: ignore
182+
183+
184+
@register_object("relax.dpl.OrRewriter")
185+
class OrRewriter(PatternMatchingRewriter):
186+
def __init__(self, lhs, rhs):
187+
self.__init_handle_by_constructor__(
188+
ffi.OrRewriter,
189+
lhs,
190+
rhs,
191+
) # type: ignore
192+
193+
194+
@register_object("relax.dpl.TupleRewriter")
195+
class TupleRewriter(PatternMatchingRewriter):
196+
def __init__(self, patterns, func):
197+
self.__init_handle_by_constructor__(
198+
ffi.TupleRewriter,
199+
patterns,
200+
func,
201+
) # type: ignore
202+
203+
26204
def rewrite_call(
27-
pattern: DFPattern, rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], func: Function
205+
pattern: DFPattern,
206+
rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr],
207+
func: Function,
28208
) -> Function:
29209
"""
30210
Rewrite a function with the given pattern and the rewriter function.

python/tvm/script/ir_builder/relax/ir.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020
import builtins
2121
import functools
2222
import inspect
23-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
23+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Type
2424

2525
import tvm
2626
from tvm import DataType, relax
27-
from tvm.ir import PrimExpr, VDevice
27+
from tvm.ir import PrimExpr, VDevice, IRModule
2828
from tvm.relax import (
2929
Call,
3030
Expr,
@@ -35,6 +35,7 @@
3535
VarBinding,
3636
const,
3737
)
38+
from tvm.relax.dpl import PatternMatchingRewriter
3839

3940
############################### Operators ###############################
4041
from tvm.relax.op import (
@@ -306,6 +307,48 @@ def func_ret_value(value: Expr) -> None:
306307
return _ffi_api.FuncRetValue(value) # type: ignore[attr-defined] # pylint: disable=no-member
307308

308309

310+
def rewriter(rewriter_mod: Union[IRModule, Type]) -> PatternMatchingRewriter:
311+
"""Define a pattern-rewrite rule
312+
313+
The IRModule must have two publicly-exposed functions, `pattern`
314+
and `replacement`, where `pattern` and `replacement` have the same
315+
function signature.
316+
317+
.. code-block:: python
318+
319+
@R.rewriter
320+
class RewriteAddIntoMultiply:
321+
@R.function
322+
def pattern(A: R.Tensor):
323+
B = A + A
324+
return B
325+
326+
@R.function
327+
def replacement(A: R.Tensor):
328+
B = A * 2
329+
return B
330+
331+
Parameters
332+
----------
333+
rewriter_mod: Union[IRModule, Type]
334+
335+
Either an IRModule that defines a rewrite pattern, or a
336+
TVMScript class that can be parsed into an IRModule.
337+
338+
Returns
339+
-------
340+
rewriter: PatternMatchingRewriter
341+
342+
A rewriter object, which can be applied either to a Relax
343+
function or to an entire IRModule.
344+
345+
"""
346+
if not isinstance(rewriter_mod, IRModule):
347+
rewriter_mod = tvm.script.ir_module(rewriter_mod)
348+
349+
return PatternMatchingRewriter.from_module(rewriter_mod)
350+
351+
309352
############################# BindingBlock ##############################
310353

311354

@@ -765,6 +808,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
765808
"dequantize",
766809
"repeat",
767810
"reshape",
811+
"rewriter",
768812
"tensor_to_shape",
769813
"shape_to_tensor",
770814
"rocm",

0 commit comments

Comments
 (0)