[Unity] Allow FLegalize to produce Relax operations #15842
Conversation
Prior to this commit, a `FLegalize` function needed to produce an implementation that can be used as input by `relax.transform.AnnotateTIROpPattern`, and could not lower to other relax operations. This commit allows Relax operations to be included in the output of `FLegalize`, with the result being further legalized if required.
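The fixed-point behavior described above can be sketched outside of TVM with a toy legalization table. Everything below (the `LEGALIZE` map, the `TIR:` prefix, the op names) is illustrative and is not the actual TVM API:

```python
# Toy model of iterative legalization: each "op" maps either to
# lower-level ops or to a terminal (TIR-like) implementation.
# All names here are invented for illustration, not the TVM API.

LEGALIZE = {
    "rms_norm": ["std", "divide"],  # lowers to other high-level ops
    "std": ["TIR:std"],             # lowers directly to a terminal form
    "divide": ["TIR:divide"],
}

def is_terminal(op):
    return op.startswith("TIR:")

def legalize_once(ops):
    """One pass: replace each non-terminal op by its lowering."""
    out = []
    for op in ops:
        out.extend([op] if is_terminal(op) else LEGALIZE[op])
    return out

def legalize(ops):
    """Repeat until no further legalization applies (fixed point)."""
    while True:
        lowered = legalize_once(ops)
        if lowered == ops:
            return lowered
        ops = lowered

print(legalize(["rms_norm"]))  # ['TIR:std', 'TIR:divide']
```

In this model, a legalization rule is free to emit other high-level ops, and the driver keeps re-running until the output stops changing.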
sunggg
left a comment
Thank you, @Lunderberg! A couple questions.
Referenced diff lines:

```python
    return relax.Call(custom_op, [A, Weight, Bias])
```

```python
AfterFirstIter = LegalizeOps()(Before)
```
Does the user need to perform multiple `LegalizeOps` passes depending on their custom ops? For example, would the user need to call it twice?
Nope. With this change, the LegalizeOps pass will continue until no additional legalization can be applied, so the user only needs to call the function once.
Ah, sorry. I missed that you do an equality check between `AfterFirstIter` and `AfterSecondIter`. Makes sense to me.
Referenced diff lines:

```python
    return add_sinfo
```

```python
def legalize(bb: relax.BlockBuilder, call: relax.Call):
```
We have a similar pass, although it does not support any recursion: https://github.com/apache/tvm/blob/unity/python/tvm/relax/transform/transform.py#L994
Is there any use-case for recursion? Or is it more of a future-proofing measure?
There's a couple of reasons I'd been thinking of, most of which fall somewhere between future-planning and user-friendliness. (Bit of a brain dump as follows.)
- User-friendliness, to make it easier to write legalization steps. For example, `R.nn.rms_norm` could be written in terms of `R.std` instead of requiring a direct lowering to a TIR implementation.
- Future-planning for user-defined custom intrinsics. If the legalization of these custom operators is defined in terms of standard relax operators, `LegalizeOps` would need to recursively expand them to allow `AnnotateTIROpPattern` to recognize the results.
- Future-planning for partial legalization. If each operator has a "composite_level", then we could selectively lower operators that are above some level of complexity. This would be a generalization of the `OpDecomposer`, to decompose any operator.
- Future-planning for defining the requirements of graph-level optimization passes. If an optimization pass handles all relax operators up to some `composite_level`, new operators could be added without impacting that optimization pass, so long as those operators define a partial legalization that decomposes them.
- Centralizing the definition of each operator. With composite operators defined in terms of lower-complexity operators, the `OpDecomposer` could be identical to the rules used by `LegalizeOps`, avoiding duplicate operator definitions.
- Future-planning to minimize the need for TIR pattern recognition. For example, `R.nn.attention` is implemented in terms of `topi.transpose` and `topi.reshape`, and would require pattern-matching similar to `RewriteDataflowReshape` to un-lower these back to Relax operations. If `R.nn.attention` were instead decomposed into `R.permute_dims` and `R.reshape`, we'd get this for free.
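As a rough illustration of the partial-legalization idea from the list above, a per-op level could gate which operators get decomposed. All names, levels, and the `partial_legalize` helper below are invented for this sketch; none of this is an existing TVM interface:

```python
# Hypothetical partial legalization driven by a per-op "composite_level".
# Ops above the requested level are decomposed; simpler ops are kept.
# All names and levels here are invented for illustration.

COMPOSITE_LEVEL = {"attention": 3, "softmax": 2, "permute_dims": 1, "reshape": 1}
DECOMPOSE = {"attention": ["permute_dims", "softmax", "reshape"]}

def partial_legalize(ops, max_level):
    """Decompose any op whose composite_level exceeds max_level."""
    out = []
    for op in ops:
        if COMPOSITE_LEVEL[op] > max_level and op in DECOMPOSE:
            # Recurse, since a replacement op may itself be too complex.
            out.extend(partial_legalize(DECOMPOSE[op], max_level))
        else:
            out.append(op)
    return out

print(partial_legalize(["attention"], 2))
# ['permute_dims', 'softmax', 'reshape']
```

A graph-level pass that only understands ops up to level 2 could then request exactly that much lowering, leaving simpler ops intact.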
Thank you, @Lunderberg, for the kind explanation. I like the idea of "composite_level" and centralizing the definitions. Can we check whether `DecomposeOpsForInference` and `DecomposeOpsForTraining` can be supported with this PR, to see if we can replace them? If so, we can discuss their deprecation as well.
Taking a look, the `DecomposeOpsFor*` passes are currently doing two distinct roles. The first role is to lower the `relax.nn.batch_norm`, `relax.nn.layer_norm`, and `relax.tensor_to_shape` operators into lower-level relax implementations. The second role is to mutate the `relax.nn.batch_norm` operator into a training-specific version.
I think the first role, lowering relax operators into less complex Relax operators, will be supported by the partial lowering intended for `LegalizeOps`. The second role is independent of the legalization, and would be best kept as a standalone pass. That second role would also become much simpler: the call `relax.nn.batch_norm(data, gamma, beta, prev_mean, prev_var)` could be updated to `relax.nn.batch_norm(data, gamma, beta, weighted_avg(mean(data), prev_mean), weighted_avg(var(data), prev_var))`, rather than needing a full redefinition of `relax.nn.batch_norm`.
Though, those are probably changes that would be best for a follow-up PR.
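The training-mode rewrite suggested above, feeding updated running statistics back into the otherwise unchanged operator, can be illustrated numerically. The `weighted_avg` convention and the `momentum` value below are assumptions for this sketch, not the exact form used in TVM:

```python
# Illustrative running-statistics update for training-mode batch norm.
# The training variant feeds weighted averages of the batch statistics
# and the previous running statistics back into the same batch_norm op.
# The momentum value and weighting convention are assumptions.

def mean(xs):
    return sum(xs) / len(xs)

def var(xs):
    m = mean(xs)
    return sum((x - m) ** 2 for x in xs) / len(xs)

def weighted_avg(batch_stat, running_stat, momentum=0.1):
    # New running stat: blend of the batch statistic and the prior value.
    return momentum * batch_stat + (1.0 - momentum) * running_stat

data = [1.0, 2.0, 3.0, 4.0]
prev_mean, prev_var = 0.0, 1.0
new_mean = weighted_avg(mean(data), prev_mean)  # 0.1*2.5 + 0.9*0.0
new_var = weighted_avg(var(data), prev_var)     # 0.1*1.25 + 0.9*1.0
# batch_norm(data, gamma, beta, new_mean, new_var) would then be
# evaluated exactly as in inference mode.
```

The point of the simplification is that only the statistics arguments change between training and inference; the operator itself needs no second definition.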
Interesting! I did not know there is a training-specific version of batch norm. SGTM. Let's discuss it in the follow-up PR.