[Redo][Unity] Split DecomposeOpsForTraining into two steps #16465
Conversation
This function should be used instead of `std::regex` within C++ call sites, to avoid ABI incompatibilities with pytorch. Currently, the pytorch wheels available through `pip install` use the pre-C++11 ABI by setting `-DUSE_CXX11_ABI=0` [0]. If TVM were to use the pre-C++11 ABI, this would cause breakages with dynamically-linked LLVM environments. Use of the `<regex>` header in TVM should therefore be avoided, as its implementation is not supported by gcc's dual ABI. This ABI incompatibility results in runtime errors either when `std::regex` is called from TVM, or when `std::regex` is called from pytorch, depending on which library was loaded first. This restriction can be removed once a version of pytorch compiled with `-DUSE_CXX11_ABI=1` is available from PyPI. [0] pytorch/pytorch#51039
This is a reapplication of apache#15954, after resolving the breakages that required reverting in apache#16442. The regex matching is now implemented without the `#include <regex>` from the C++ stdlib, to avoid ABI incompatibility with pytorch. Prior to this commit, the `DecomposeOpsForTraining` transform directly replaced `relax.nn.batch_norm` with more primitive relax operations. This required the decomposed form of `relax.nn.batch_norm` to be duplicated in `DecomposeOpsForInference`. This commit refactors the pass to occur in two steps: first applying training-specific mutations, then decomposing. A dedicated `DecomposeOps` pass also provides a single location for operator decomposition, which may be migrated into the operator definitions in the future, similar to `FLegalize`.
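For context on why the duplication arose, the arithmetic both passes ultimately need is the standard batch-norm decomposition. The pure-Python sketch below is my own illustration, not TVM code: the normalize-scale-shift core is identical in both modes, and only the training path adds anything (computing batch statistics and updating running averages), which is what makes a shared decomposition step attractive.

```python
def normalize(x, mean, var, gamma, beta, eps=1e-5):
    # The decomposed core shared by inference and training:
    # (x - mean) / sqrt(var + eps) * gamma + beta, elementwise.
    return [(xi - mean) / (var + eps) ** 0.5 * gamma + beta for xi in x]

def batch_norm_inference(x, running_mean, running_var, gamma, beta):
    # Inference normalizes with the stored running statistics only.
    return normalize(x, running_mean, running_var, gamma, beta)

def batch_norm_training(x, running_mean, running_var, gamma, beta,
                        momentum=0.1):
    # Training-specific mutation: compute batch statistics and update
    # the running averages, then reuse the same decomposed core.
    n = len(x)
    mean = sum(x) / n
    var = sum((xi - mean) ** 2 for xi in x) / n
    new_mean = (1 - momentum) * running_mean + momentum * mean
    new_var = (1 - momentum) * running_var + momentum * var
    return normalize(x, mean, var, gamma, beta), new_mean, new_var
```

In this framing, a training-specific pass only needs to introduce the statistics computation and running-average updates; the shared `normalize` step can live in a single decomposition pass instead of being copied into both transforms.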
Is having the two separate passes necessary for reducing code duplication for batch norm? It does come at the cost of an extra traversal.
It isn't strictly necessary, but I'd like to move in that direction as a first step in removing the duplication. Currently, there are two distinct transforms, `DecomposeOpsForInference` and `DecomposeOpsForTraining`, which each carry a copy of the `relax.nn.batch_norm` decomposition. Since the decomposition itself is shared, only the training-specific mutation needs to differ.
Ah, that seems like a good reason then. 👍 Some bigger simplifications in the works.
slyubomirsky
left a comment
These changes seem reasonable and, per your comment, set us up for further simplifications down the line.
Sounds good. Re-running CI, as I let the results get more stale than I'd like, then (assuming no new failures arise) merging in.