Skip to content

Commit ec9e0a0

Browse files
authored
[Unity] Allow FLegalize to produce Relax operations (#15842)
* [Unity] Allow FLegalize to produce Relax operations Prior to this commit, a `FLegalize` function needed to produce an implementation that can be used as input by `relax.transform.AnnotateTIROpPattern`, and could not lower to other relax operations. This commit allows Relax operations to be included in the output of `FLegalize`, with the result being further legalized if required. * Maintain binding block type for nested SeqExpr * Avoid infinite recursion for strided slice on dynamic axis * Avoid duplicate variables when checking for re-legalization * Collect bindings to legalize during normalization
1 parent 8874f8b commit ec9e0a0

File tree

3 files changed

+141
-25
lines changed

3 files changed

+141
-25
lines changed

src/relax/ir/block_builder.cc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,17 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr&
612612
unchanged &= new_block.same_as(block);
613613
}
614614

615-
this->BeginBindingBlock();
615+
// Because the input may not be normalized, the SeqExpr may occur
616+
// nested within another SeqExpr. In that case, we want to use
617+
// whatever binding-block type the parent uses, so that we any
618+
// bindings collected into the prologue will be compatible with
619+
// the parent block.
620+
if (block_stack_.size() && CurrentBlockIsDataFlow()) {
621+
this->BeginDataflowBlock();
622+
} else {
623+
this->BeginBindingBlock();
624+
}
625+
616626
// the body may not be a leaf expression, so check for that
617627
Expr new_body = this->NormalizeArgument(op->body);
618628
unchanged &= new_body.same_as(op->body);

src/relax/transform/legalize_ops.cc

Lines changed: 56 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,11 @@ class LegalizeMutator : public ExprMutator {
5757
public:
5858
explicit LegalizeMutator(const IRModule& mod, const Optional<Map<String, PackedFunc>>& cmap,
5959
bool enable_warning)
60-
: ExprMutator(mod),
61-
mod_(std::move(mod)),
62-
cmap_(std::move(cmap)),
63-
enable_warning_(enable_warning) {}
60+
: ExprMutator(mod), mod_(std::move(mod)), enable_warning_(enable_warning) {
61+
if (cmap) {
62+
cmap_ = std::move(cmap.value());
63+
}
64+
}
6465

6566
IRModule Transform() {
6667
for (const auto& [gv, func] : mod_->functions) {
@@ -132,36 +133,67 @@ class LegalizeMutator : public ExprMutator {
132133
return visited_call;
133134
}
134135

135-
// Priority: customize > default.
136-
// Check if it has customize legalization registered.
137-
if (cmap_.defined() && cmap_.value().count(op->name)) {
138-
auto ret = cmap_.value()[op->name](this->builder_, visited_call);
139-
if (ret.IsObjectRef<Expr>() && WrapPureCondition(op, ret.AsObjectRef<Expr>())) {
140-
return WrapPureCall(Downcast<Call>(ret.AsObjectRef<Expr>()));
136+
FLegalize legalization_func;
137+
138+
if (auto opt_custom_legalize = cmap_.Get(op->name)) {
139+
// First choice, use a custom legalization function
140+
legalization_func = opt_custom_legalize.value();
141+
} else if (legalize_map.count(op)) {
142+
// Second choice, use a default legalization
143+
legalization_func = legalize_map[op];
144+
} else {
145+
// No legalization.
146+
if (enable_warning_ && op != call_tir_op && op != call_dps_packed_op &&
147+
op != call_pure_packed_op) {
148+
LOG(WARNING) << "No legalization func for " << op->name << " is found.";
141149
}
142-
return ret;
150+
return visited_call;
143151
}
144-
// Check if it has default legalization registered.
145-
if (legalize_map.count(op)) {
146-
auto ret = legalize_map[op](this->builder_, visited_call);
147-
if (WrapPureCondition(op, ret)) {
148-
return WrapPureCall(Downcast<Call>(ret));
149-
}
150-
return ret;
152+
153+
// The legalization function may call `builder_->Emit()` as part
154+
// of its implementation. In that case, any operations it emits
155+
// must be caught such that they be checked for recursive
156+
// legalization. This is done by wrapping the legalized value in
157+
// a SeqExpr, which can first be visited, then unwrapped by the
158+
// normalization.
159+
if (builder_->CurrentBlockIsDataFlow()) {
160+
builder_->BeginDataflowBlock();
161+
} else {
162+
builder_->BeginBindingBlock();
151163
}
164+
Expr legalized = legalization_func(builder_, visited_call);
165+
legalized = builder_->Normalize(legalized);
152166

153-
// No legalization.
154-
if (enable_warning_ && op != call_tir_op && op != call_dps_packed_op &&
155-
op != call_pure_packed_op) {
156-
LOG(WARNING) << "No legalization func for " << op->name << " is found.";
167+
BindingBlock prologue = builder_->EndBlock();
168+
for (const auto& binding : prologue->bindings) {
169+
VisitBinding(binding);
157170
}
158-
return visited_call;
171+
172+
if (WrapPureCondition(op, legalized)) {
173+
legalized = WrapPureCall(Downcast<Call>(legalized));
174+
}
175+
176+
// Legalization may have introduced additional operations that
177+
// must be legalized as well. For example, a user-custom
178+
// intrinsic whose legalization is implemented in terms of relax
179+
// intrinsics. The base case of the recursion occurs when no
180+
// additional legalization steps are found.
181+
//
182+
// Only perform recursive legalization when the legalization
183+
// function returned a modified expression, as some legalizations
184+
// return the original expression if they are unable to produce a
185+
// legalized version.
186+
if (!legalized.same_as(visited_call)) {
187+
legalized = VisitExpr(legalized);
188+
}
189+
190+
return legalized;
159191
}
160192

161193
/*! \brief The context IRModule. */
162194
IRModule mod_;
163195
/*! \brief The customized legalization function map. */
164-
Optional<Map<String, PackedFunc>> cmap_;
196+
Map<String, PackedFunc> cmap_;
165197
/*!
166198
* \brief A boolean value indicating if to print warnings for CallNode whose op's
167199
* legalization function is not registered.

tests/python/relax/test_transform_legalize_ops.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from tvm.script import relax as R, tir as T, ir as I
2525
import tvm.testing
2626

27+
import pytest
28+
2729

2830
def test_customize_legalize():
2931
# fmt: off
@@ -282,5 +284,77 @@ def main(A: R.Tensor([16, 32]), B: R.Tensor([32, 8])) -> R.Tensor([16, 8]):
282284
assert err_message.startswith("To legalize R.matmul")
283285

284286

287+
emit_legalization_through_builder = tvm.testing.parameter(
288+
by_dict={
289+
"return_relax_expr": False,
290+
"return_relax_var": True,
291+
}
292+
)
293+
294+
295+
@pytest.fixture
296+
def custom_op(emit_legalization_through_builder):
297+
op_name = "custom_op.matmul_bias_add"
298+
299+
def infer_struct_info(call: relax.Call, context):
300+
activations, weight, bias = call.args
301+
302+
matmul_call = relax.op.matmul(activations, weight)
303+
matmul_sinfo = tvm.ir.Op.get("relax.matmul").get_attr("FInferStructInfo")(
304+
matmul_call, context
305+
)
306+
307+
matmul_var = relax.Var("dummy_var", matmul_sinfo)
308+
add_call = matmul_var + bias
309+
add_sinfo = tvm.ir.Op.get("relax.add").get_attr("FInferStructInfo")(add_call, context)
310+
311+
return add_sinfo
312+
313+
def legalize(bb: relax.BlockBuilder, call: relax.Call):
314+
activations, weight, bias = call.args
315+
legalized = relax.op.matmul(activations, weight) + bias
316+
if emit_legalization_through_builder:
317+
legalized = bb.emit(legalized)
318+
return legalized
319+
320+
op_attrs = {
321+
"FInferStructInfo": infer_struct_info,
322+
"FLegalize": legalize,
323+
"FPurity": True,
324+
}
325+
326+
for key, value in op_attrs.items():
327+
tvm.ir.register_op_attr(op_name, key, value)
328+
329+
op = tvm.ir.Op.get(op_name)
330+
yield op
331+
332+
for key in op_attrs:
333+
op.reset_attr(key)
334+
335+
336+
def test_recursive_legalization(custom_op):
337+
"""Legalization of an operator may produce new operators requiring legalization"""
338+
339+
@I.ir_module
340+
class Before:
341+
@R.function
342+
def main(
343+
A: R.Tensor([16, 32, 64], "float32"),
344+
Weight: R.Tensor([64, 128], "float32"),
345+
Bias: R.Tensor([16, 32, 128], "float32"),
346+
):
347+
return relax.Call(custom_op, [A, Weight, Bias])
348+
349+
AfterFirstIter = LegalizeOps()(Before)
350+
AfterSecondIter = LegalizeOps()(AfterFirstIter)
351+
352+
# After LegalizeOps, the custom operation should be replaced by
353+
# `R.matmul` and `R.add`, which should in turn be replaced with
354+
# TIR implementations. Therefore, the second application of
355+
# LegalizeOps() should be a no-op.
356+
tvm.ir.assert_structural_equal(AfterFirstIter, AfterSecondIter)
357+
358+
285359
if __name__ == "__main__":
286360
tvm.testing.main()

0 commit comments

Comments
 (0)