Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/relax/ir/block_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,17 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr&
unchanged &= new_block.same_as(block);
}

this->BeginBindingBlock();
// Because the input may not be normalized, the SeqExpr may occur
// nested within another SeqExpr. In that case, we want to use
// whatever binding-block type the parent uses, so that any
// bindings collected into the prologue will be compatible with
// the parent block.
if (block_stack_.size() && CurrentBlockIsDataFlow()) {
this->BeginDataflowBlock();
} else {
this->BeginBindingBlock();
}

// the body may not be a leaf expression, so check for that
Expr new_body = this->NormalizeArgument(op->body);
unchanged &= new_body.same_as(op->body);
Expand Down
80 changes: 56 additions & 24 deletions src/relax/transform/legalize_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,11 @@ class LegalizeMutator : public ExprMutator {
public:
explicit LegalizeMutator(const IRModule& mod, const Optional<Map<String, PackedFunc>>& cmap,
bool enable_warning)
: ExprMutator(mod),
mod_(std::move(mod)),
cmap_(std::move(cmap)),
enable_warning_(enable_warning) {}
: ExprMutator(mod), mod_(std::move(mod)), enable_warning_(enable_warning) {
if (cmap) {
cmap_ = std::move(cmap.value());
}
}

IRModule Transform() {
for (const auto& [gv, func] : mod_->functions) {
Expand Down Expand Up @@ -132,36 +133,67 @@ class LegalizeMutator : public ExprMutator {
return visited_call;
}

// Priority: customize > default.
// Check if it has customize legalization registered.
if (cmap_.defined() && cmap_.value().count(op->name)) {
auto ret = cmap_.value()[op->name](this->builder_, visited_call);
if (ret.IsObjectRef<Expr>() && WrapPureCondition(op, ret.AsObjectRef<Expr>())) {
return WrapPureCall(Downcast<Call>(ret.AsObjectRef<Expr>()));
FLegalize legalization_func;

if (auto opt_custom_legalize = cmap_.Get(op->name)) {
// First choice, use a custom legalization function
legalization_func = opt_custom_legalize.value();
} else if (legalize_map.count(op)) {
// Second choice, use a default legalization
legalization_func = legalize_map[op];
} else {
// No legalization.
if (enable_warning_ && op != call_tir_op && op != call_dps_packed_op &&
op != call_pure_packed_op) {
LOG(WARNING) << "No legalization func for " << op->name << " is found.";
}
return ret;
return visited_call;
}
// Check if it has default legalization registered.
if (legalize_map.count(op)) {
auto ret = legalize_map[op](this->builder_, visited_call);
if (WrapPureCondition(op, ret)) {
return WrapPureCall(Downcast<Call>(ret));
}
return ret;

// The legalization function may call `builder_->Emit()` as part
// of its implementation. In that case, any operations it emits
// must be caught so that they can be checked for recursive
// legalization. This is done by wrapping the legalized value in
// a SeqExpr, which can first be visited, then unwrapped by the
// normalization.
if (builder_->CurrentBlockIsDataFlow()) {
builder_->BeginDataflowBlock();
} else {
builder_->BeginBindingBlock();
}
Expr legalized = legalization_func(builder_, visited_call);
legalized = builder_->Normalize(legalized);

// No legalization.
if (enable_warning_ && op != call_tir_op && op != call_dps_packed_op &&
op != call_pure_packed_op) {
LOG(WARNING) << "No legalization func for " << op->name << " is found.";
BindingBlock prologue = builder_->EndBlock();
for (const auto& binding : prologue->bindings) {
VisitBinding(binding);
}
return visited_call;

if (WrapPureCondition(op, legalized)) {
legalized = WrapPureCall(Downcast<Call>(legalized));
}

// Legalization may have introduced additional operations that
// must be legalized as well. For example, a user-custom
// intrinsic whose legalization is implemented in terms of relax
// intrinsics. The base case of the recursion occurs when no
// additional legalization steps are found.
//
// Only perform recursive legalization when the legalization
// function returned a modified expression, as some legalizations
// return the original expression if they are unable to produce a
// legalized version.
if (!legalized.same_as(visited_call)) {
legalized = VisitExpr(legalized);
}

return legalized;
}

/*! \brief The context IRModule. */
IRModule mod_;
/*! \brief The customized legalization function map. */
Optional<Map<String, PackedFunc>> cmap_;
Map<String, PackedFunc> cmap_;
/*!
* \brief A boolean value indicating if to print warnings for CallNode whose op's
* legalization function is not registered.
Expand Down
74 changes: 74 additions & 0 deletions tests/python/relax/test_transform_legalize_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from tvm.script import relax as R, tir as T, ir as I
import tvm.testing

import pytest


def test_customize_legalize():
# fmt: off
Expand Down Expand Up @@ -282,5 +284,77 @@ def main(A: R.Tensor([16, 32]), B: R.Tensor([32, 8])) -> R.Tensor([16, 8]):
assert err_message.startswith("To legalize R.matmul")


# Parametrize whether the custom FLegalize implementation returns a plain
# Relax expression, or a relax.Var emitted through the BlockBuilder.  The
# two styles exercise different normalization paths in LegalizeOps.
emit_legalization_through_builder = tvm.testing.parameter(
    by_dict={
        "return_relax_expr": False,
        "return_relax_var": True,
    }
)


@pytest.fixture
def custom_op(emit_legalization_through_builder):
    """Register a temporary custom operator for legalization tests.

    The operator computes ``matmul(activations, weight) + bias``.  Its
    ``FLegalize`` implementation expands it into the builtin ``R.matmul``
    and ``R.add`` operators, returning either the raw Relax expression or
    a variable emitted through the BlockBuilder, depending on the
    ``emit_legalization_through_builder`` parameter.

    The attribute registrations are removed on fixture teardown so other
    tests see a clean operator registry.
    """
    # NOTE(review): the original scraped source had a review-comment thread
    # injected mid-function; this is the reconstructed code.
    op_name = "custom_op.matmul_bias_add"

    def infer_struct_info(call: relax.Call, context):
        # Derive the output struct info by chaining the builtin matmul and
        # add struct-info inference rules, using a dummy Var to carry the
        # intermediate matmul result.
        activations, weight, bias = call.args

        matmul_call = relax.op.matmul(activations, weight)
        matmul_sinfo = tvm.ir.Op.get("relax.matmul").get_attr("FInferStructInfo")(
            matmul_call, context
        )

        matmul_var = relax.Var("dummy_var", matmul_sinfo)
        add_call = matmul_var + bias
        add_sinfo = tvm.ir.Op.get("relax.add").get_attr("FInferStructInfo")(add_call, context)

        return add_sinfo

    def legalize(bb: relax.BlockBuilder, call: relax.Call):
        # Expand the custom op into builtin Relax operators.  These must
        # themselves be legalized, which is what the recursion test checks.
        activations, weight, bias = call.args
        legalized = relax.op.matmul(activations, weight) + bias
        if emit_legalization_through_builder:
            legalized = bb.emit(legalized)
        return legalized

    op_attrs = {
        "FInferStructInfo": infer_struct_info,
        "FLegalize": legalize,
        "FPurity": True,
    }

    for key, value in op_attrs.items():
        tvm.ir.register_op_attr(op_name, key, value)

    op = tvm.ir.Op.get(op_name)
    yield op

    # Teardown: undo the registrations made above.
    for key in op_attrs:
        op.reset_attr(key)


def test_recursive_legalization(custom_op):
    """Legalization of an operator may produce new operators requiring legalization.

    The custom op legalizes into ``R.matmul``/``R.add``, which must in turn
    be legalized to TIR within a single LegalizeOps() invocation.
    """
    # NOTE(review): the original scraped source had a review-comment thread
    # injected mid-function; this is the reconstructed code.

    @I.ir_module
    class Before:
        @R.function
        def main(
            A: R.Tensor([16, 32, 64], "float32"),
            Weight: R.Tensor([64, 128], "float32"),
            Bias: R.Tensor([16, 32, 128], "float32"),
        ):
            return relax.Call(custom_op, [A, Weight, Bias])

    AfterFirstIter = LegalizeOps()(Before)
    AfterSecondIter = LegalizeOps()(AfterFirstIter)

    # After LegalizeOps, the custom operation should be replaced by
    # `R.matmul` and `R.add`, which should in turn be replaced with
    # TIR implementations.  Therefore, the second application of
    # LegalizeOps() should be a no-op.
    tvm.ir.assert_structural_equal(AfterFirstIter, AfterSecondIter)


if __name__ == "__main__":
    # Dispatch to tvm.testing's pytest-based runner for all tests in this file.
    tvm.testing.main()