diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index cc79d45323b5..5037161fcb90 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -612,7 +612,17 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorBeginBindingBlock(); + // Because the input may not be normalized, the SeqExpr may occur + // nested within another SeqExpr. In that case, we want to use + // whatever binding-block type the parent uses, so that we any + // bindings collected into the prologue will be compatible with + // the parent block. + if (block_stack_.size() && CurrentBlockIsDataFlow()) { + this->BeginDataflowBlock(); + } else { + this->BeginBindingBlock(); + } + // the body may not be a leaf expression, so check for that Expr new_body = this->NormalizeArgument(op->body); unchanged &= new_body.same_as(op->body); diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index 4469f3558593..170967d28281 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -57,10 +57,11 @@ class LegalizeMutator : public ExprMutator { public: explicit LegalizeMutator(const IRModule& mod, const Optional>& cmap, bool enable_warning) - : ExprMutator(mod), - mod_(std::move(mod)), - cmap_(std::move(cmap)), - enable_warning_(enable_warning) {} + : ExprMutator(mod), mod_(std::move(mod)), enable_warning_(enable_warning) { + if (cmap) { + cmap_ = std::move(cmap.value()); + } + } IRModule Transform() { for (const auto& [gv, func] : mod_->functions) { @@ -132,36 +133,67 @@ class LegalizeMutator : public ExprMutator { return visited_call; } - // Priority: customize > default. - // Check if it has customize legalization registered. - if (cmap_.defined() && cmap_.value().count(op->name)) { - auto ret = cmap_.value()[op->name](this->builder_, visited_call); - if (ret.IsObjectRef() && WrapPureCondition(op, ret.AsObjectRef())) { - return WrapPureCall(Downcast(ret.AsObjectRef())); + FLegalize legalization_func; + + if (auto opt_custom_legalize = cmap_.Get(op->name)) { + // First choice, use a custom legalization function + legalization_func = opt_custom_legalize.value(); + } else if (legalize_map.count(op)) { + // Second choice, use a default legalization + legalization_func = legalize_map[op]; + } else { + // No legalization. + if (enable_warning_ && op != call_tir_op && op != call_dps_packed_op && + op != call_pure_packed_op) { + LOG(WARNING) << "No legalization func for " << op->name << " is found."; } - return ret; + return visited_call; } - // Check if it has default legalization registered. - if (legalize_map.count(op)) { - auto ret = legalize_map[op](this->builder_, visited_call); - if (WrapPureCondition(op, ret)) { - return WrapPureCall(Downcast(ret)); - } - return ret; + + // The legalization function may call `builder_->Emit()` as part + // of its implementation. In that case, any operations it emits + // must be caught such that they be checked for recursive + // legalization. This is done by wrapping the legalized value in + // a SeqExpr, which can first be visited, then unwrapped by the + // normalization. + if (builder_->CurrentBlockIsDataFlow()) { + builder_->BeginDataflowBlock(); + } else { + builder_->BeginBindingBlock(); } + Expr legalized = legalization_func(builder_, visited_call); + legalized = builder_->Normalize(legalized); - // No legalization. - if (enable_warning_ && op != call_tir_op && op != call_dps_packed_op && - op != call_pure_packed_op) { - LOG(WARNING) << "No legalization func for " << op->name << " is found."; + BindingBlock prologue = builder_->EndBlock(); + for (const auto& binding : prologue->bindings) { + VisitBinding(binding); } - return visited_call; + + if (WrapPureCondition(op, legalized)) { + legalized = WrapPureCall(Downcast(legalized)); + } + + // Legalization may have introduced additional operations that + // must be legalized as well. For example, a user-custom + // intrinsic whose legalization is implemented in terms of relax + // intrinsics. The base case of the recursion occurs when no + // additional legalization steps are found. + // + // Only perform recursive legalization when the legalization + // function returned a modified expression, as some legalizations + // return the original expression if they are unable to produce a + // legalized version. + if (!legalized.same_as(visited_call)) { + legalized = VisitExpr(legalized); + } + + return legalized; } /*! \brief The context IRModule. */ IRModule mod_; /*! \brief The customized legalization function map. */ - Optional> cmap_; + Map cmap_; /*! * \brief A boolean value indicating if to print warnings for CallNode whose op's * legalization function is not registered. diff --git a/tests/python/relax/test_transform_legalize_ops.py b/tests/python/relax/test_transform_legalize_ops.py index af6004bd0af5..47eeb68341b3 100644 --- a/tests/python/relax/test_transform_legalize_ops.py +++ b/tests/python/relax/test_transform_legalize_ops.py @@ -24,6 +24,8 @@ from tvm.script import relax as R, tir as T, ir as I import tvm.testing +import pytest + def test_customize_legalize(): # fmt: off @@ -282,5 +284,77 @@ def main(A: R.Tensor([16, 32]), B: R.Tensor([32, 8])) -> R.Tensor([16, 8]): assert err_message.startswith("To legalize R.matmul") +emit_legalization_through_builder = tvm.testing.parameter( + by_dict={ + "return_relax_expr": False, + "return_relax_var": True, + } +) + + +@pytest.fixture +def custom_op(emit_legalization_through_builder): + op_name = "custom_op.matmul_bias_add" + + def infer_struct_info(call: relax.Call, context): + activations, weight, bias = call.args + + matmul_call = relax.op.matmul(activations, weight) + matmul_sinfo = tvm.ir.Op.get("relax.matmul").get_attr("FInferStructInfo")( + matmul_call, context + ) + + matmul_var = relax.Var("dummy_var", matmul_sinfo) + add_call = matmul_var + bias + add_sinfo = tvm.ir.Op.get("relax.add").get_attr("FInferStructInfo")(add_call, context) + + return add_sinfo + + def legalize(bb: relax.BlockBuilder, call: relax.Call): + activations, weight, bias = call.args + legalized = relax.op.matmul(activations, weight) + bias + if emit_legalization_through_builder: + legalized = bb.emit(legalized) + return legalized + + op_attrs = { + "FInferStructInfo": infer_struct_info, + "FLegalize": legalize, + "FPurity": True, + } + + for key, value in op_attrs.items(): + tvm.ir.register_op_attr(op_name, key, value) + + op = tvm.ir.Op.get(op_name) + yield op + + for key in op_attrs: + op.reset_attr(key) + + +def test_recursive_legalization(custom_op): + """Legalization of an operator may produce new operators requiring legalization""" + + @I.ir_module + class Before: + @R.function + def main( + A: R.Tensor([16, 32, 64], "float32"), + Weight: R.Tensor([64, 128], "float32"), + Bias: R.Tensor([16, 32, 128], "float32"), + ): + return relax.Call(custom_op, [A, Weight, Bias]) + + AfterFirstIter = LegalizeOps()(Before) + AfterSecondIter = LegalizeOps()(AfterFirstIter) + + # After LegalizeOps, the custom operation should be replaced by + # `R.matmul` and `R.add`, which should in turn be replaced with + # TIR implementations. Therefore, the second application of + # LegalizeOps() should be a no-op. + tvm.ir.assert_structural_equal(AfterFirstIter, AfterSecondIter) + + if __name__ == "__main__": tvm.testing.main()