[Unity] Allow FLegalize to produce Relax operations #15842
```diff
@@ -24,6 +24,8 @@
 from tvm.script import relax as R, tir as T, ir as I
 import tvm.testing
 
+import pytest
+
 
 def test_customize_legalize():
     # fmt: off
```

```diff
@@ -282,5 +284,77 @@ def main(A: R.Tensor([16, 32]), B: R.Tensor([32, 8])) -> R.Tensor([16, 8]):
     assert err_message.startswith("To legalize R.matmul")
 
 
+emit_legalization_through_builder = tvm.testing.parameter(
+    by_dict={
+        "return_relax_expr": False,
+        "return_relax_var": True,
+    }
+)
+
+
+@pytest.fixture
+def custom_op(emit_legalization_through_builder):
+    op_name = "custom_op.matmul_bias_add"
+
+    def infer_struct_info(call: relax.Call, context):
+        activations, weight, bias = call.args
+
+        matmul_call = relax.op.matmul(activations, weight)
+        matmul_sinfo = tvm.ir.Op.get("relax.matmul").get_attr("FInferStructInfo")(
+            matmul_call, context
+        )
+
+        matmul_var = relax.Var("dummy_var", matmul_sinfo)
+        add_call = matmul_var + bias
+        add_sinfo = tvm.ir.Op.get("relax.add").get_attr("FInferStructInfo")(add_call, context)
+
+        return add_sinfo
+
+    def legalize(bb: relax.BlockBuilder, call: relax.Call):
+        activations, weight, bias = call.args
+        legalized = relax.op.matmul(activations, weight) + bias
+        if emit_legalization_through_builder:
+            legalized = bb.emit(legalized)
+        return legalized
+
+    op_attrs = {
+        "FInferStructInfo": infer_struct_info,
+        "FLegalize": legalize,
+        "FPurity": True,
+    }
+
+    for key, value in op_attrs.items():
+        tvm.ir.register_op_attr(op_name, key, value)
+
+    op = tvm.ir.Op.get(op_name)
+    yield op
+
+    for key in op_attrs:
+        op.reset_attr(key)
+
+
+def test_recursive_legalization(custom_op):
+    """Legalization of an operator may produce new operators requiring legalization"""
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            A: R.Tensor([16, 32, 64], "float32"),
+            Weight: R.Tensor([64, 128], "float32"),
+            Bias: R.Tensor([16, 32, 128], "float32"),
+        ):
+            return relax.Call(custom_op, [A, Weight, Bias])
+
+    AfterFirstIter = LegalizeOps()(Before)
```
Contributor

Does the user need to perform `LegalizeOps` twice?

Contributor Author

Nope. With this change, the recursive legalization happens within a single application of `LegalizeOps`; the second application in this test only verifies that it is a no-op.

Contributor

Ah, sorry. I missed that you do an equality check between `AfterFirstIter` and `AfterSecondIter`.
```diff
+    AfterSecondIter = LegalizeOps()(AfterFirstIter)
+
+    # After LegalizeOps, the custom operation should be replaced by
+    # `R.matmul` and `R.add`, which should in turn be replaced with
+    # TIR implementations.  Therefore, the second application of
+    # LegalizeOps() should be a no-op.
+    tvm.ir.assert_structural_equal(AfterFirstIter, AfterSecondIter)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
```
Contributor

We have a similar pass, although it does not support any recursion: https://github.com/apache/tvm/blob/unity/python/tvm/relax/transform/transform.py#L994

Is there any use-case for the recursion, or is it more for future-proofing?
Contributor Author

There are a couple of reasons I'd been thinking of, most of which fall somewhere between future-planning and user-friendliness. (Bit of a brain dump follows.)

- `R.nn.rms_norm` could be written in terms of `R.std` instead of requiring a direct lowering to a TIR implementation. `LegalizeOps` would need to recursively expand such definitions to allow `AnnotateTIROpPattern` to recognize the results.
- With a pass similar to `OpDecomposer` to decompose any `composite_level` operator, new operators could be added without impacting that optimization pass, so long as those operators define a partial legalization that decomposes them.
- The rules used by `OpDecomposer` could be identical to the rules used by `LegalizeOps`, avoiding duplicate operator definitions.
- `R.nn.attention` is implemented in terms of `topi.transpose` and `topi.reshape`, and would require pattern-matching similar to `RewriteDataflowReshape` to un-lower these back to Relax operations. If `R.nn.attention` were instead decomposed into `R.permute_dims` and `R.reshape`, we'd get this for free.
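To make the partial-legalization idea above concrete, here is a minimal sketch that follows the same registration pattern as the test added in this PR. The operator name `custom_op.rms_norm`, its argument layout, the reduction axis, and the `epsilon` constant are all hypothetical, chosen only to illustrate an `FLegalize` rule that returns Relax operations for `LegalizeOps` to expand recursively; it is not part of the PR.

```python
import tvm
from tvm import relax


# Hedged sketch only: decompose a hypothetical "custom_op.rms_norm" into
# existing Relax operators instead of lowering it directly to TIR.  With
# this PR, LegalizeOps would then recursively legalize the resulting
# R.mean / R.sqrt / R.multiply calls into their TIR implementations.
def legalize_rms_norm_like(bb: relax.BlockBuilder, call: relax.Call):
    data, weight = call.args
    epsilon = relax.const(1e-5, "float32")  # assumed dtype and epsilon
    mean_sq = relax.op.mean(relax.op.multiply(data, data), axis=-1, keepdims=True)
    scale = relax.op.sqrt(relax.op.add(mean_sq, epsilon))
    # Returning a plain Relax expression (not a TIR call) is what the
    # recursive legalization in this PR makes possible.
    return relax.op.multiply(relax.op.divide(data, scale), weight)


tvm.ir.register_op_attr("custom_op.rms_norm", "FLegalize", legalize_rms_norm_like)
```

As in the test in this diff, the rule could equally pass the expression through `bb.emit(...)` before returning it; both variants are exercised there via the `emit_legalization_through_builder` parameter.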
Contributor

Thank you, @Lunderberg, for the kind explanation. I like the idea of a "composite level" and of centralizing the definitions. Can we check whether `DecomposeOpsForInference` and `DecomposeOpsForTraining` can be supported with this PR, to see if we can replace them? If so, we can discuss their deprecation as well.
Contributor Author

Taking a look, the `DecomposeOpsFor*` passes currently play two distinct roles. The first role is to lower the `relax.nn.batch_norm`, `relax.nn.layer_norm`, and `relax.tensor_to_shape` operators into lower-level Relax implementations. The second role is to mutate the `relax.nn.batch_norm` operator into a training-specific version.

I think the first role, lowering Relax operators into less complex Relax operators, will be supported by the partial lowering intended for `LegalizeOps`. The second role is independent of legalization and would be best kept as a standalone pass. That pass would become much simpler, as `relax.nn.batch_norm(data, gamma, beta, prev_mean, prev_var)` could be updated to `relax.nn.batch_norm(data, gamma, beta, weighted_avg(mean(data), prev_mean), weighted_avg(var(data), prev_var))`, rather than needing a full definition of `relax.nn.batch_norm`.

Though, those are probably changes that would be best left for a follow-up PR.
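For concreteness, here is a sketch of what that training-specific rewrite could look like as a pure argument substitution. It assumes an NCHW layout (channel axis 1) and a fixed momentum, and the `weighted_avg` helper simply mirrors the pseudo-code in the comment above; none of this is part of the PR itself.

```python
from tvm import relax


def to_training_batch_norm(call: relax.Call, momentum: float = 0.9) -> relax.Call:
    """Hedged sketch: replace the stored statistics with weighted averages of
    the stored and batch statistics, re-using the original op and attributes."""
    data, gamma, beta, prev_mean, prev_var = call.args
    reduce_axes = [0, 2, 3]  # assumed: every axis except the channel axis

    def weighted_avg(batch_stat, prev_stat):
        k = relax.const(momentum, "float32")
        return relax.op.add(
            relax.op.multiply(k, prev_stat),
            relax.op.multiply(relax.const(1.0 - momentum, "float32"), batch_stat),
        )

    new_mean = weighted_avg(relax.op.mean(data, axis=reduce_axes), prev_mean)
    new_var = weighted_avg(relax.op.variance(data, axis=reduce_axes), prev_var)

    # Only the arguments change; the relax.nn.batch_norm op and attrs are kept.
    return relax.Call(call.op, [data, gamma, beta, new_mean, new_var], call.attrs)
```

A standalone pass would then only need to visit `relax.nn.batch_norm` calls and apply this substitution, rather than re-deriving the full batch-norm computation.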
Contributor

Interesting! I did not know there is a training-specific version of batch norm. SGTM. Let's discuss it in the follow-up PR.