[mlir-tensorrt] Transpose Reshape Elimination pass #686
Conversation
@christopherbate @shelkesagar29 This is the big PR to eliminate shuffles from the program. Ideally, this should include deleting the existing `TransposeElimination` and `ReshapeElimination` passes.

If the pass can subsume the existing two passes, that's fine. I'd say keep the existing test files in place and just update the RUN command and FileCheck directives -- this way we can see what changed with respect to tests. A follow-on commit can merge the test files if required.
```cpp
    return success();
  }
};
} // namespace
```
Note: everything above line 436 of `TransposeReshapeElimination.cpp` came from the existing `TransposeElimination.cpp` file.
Force-pushed f173be8 to b6d32fd.
```diff
@@ -1,4 +1,4 @@
-// RUN: tensorrt-opt %s -split-input-file -tensorrt-reshape-elimination | FileCheck %s
+// RUN: tensorrt-opt %s -split-input-file -tensorrt-transpose-reshape-elimination | FileCheck %s
```
@christopherbate I have updated the lit tests for reshape and transpose to use the new `transpose-reshape-elimination` pass.
We shouldn't be copy-pasting code and having two versions of the same pattern... either delete the old pattern passes or use a single shared implementation.

I'm doing manual copy/paste right now to understand what new code was actually added, which is a bit tedious. In the future, if you could have separate commits for pure code movement vs. changes/new code, that would be great.
christopherbate left a comment:
Regarding Matmul-to-Einsum and Einsum-to-Matmul:

We should avoid the question of which one is better entirely. I really want to avoid making the TRT dialect transforms reason about which one is better, since it's an impossible question to encode into a heuristic. It's not easy or straightforward for us to figure out a general rule for where einsum might be better, since it depends on the internal details of TensorRT, the TensorRT version, etc., and it is not easy to make robust to future changes.

For that reason, I would recommend that users make this decision much closer to the frontend. For example, in the stablehlo conversions we explicitly have an option for whether or not to prefer use of `tensorrt.einsum` or `tensorrt.matrix_multiply`.

The biggest issue I can see for Einsum vs. MatMul is pattern recognition of special fusions internal to TRT (e.g. MHA), or some other niche feature or optimization that a user may be expecting to see based on TRT documentation or some other communication. To my knowledge, there's no special pattern that requires einsum, so I would be hesitant to change all matmuls to einsums in the main TRT dialect compilation pipeline unless it's guarded by an option. There are other users who are using the TRT dialect directly at the frontend, and this would be a surprising change for them IMO.

That said, from my personal experience, it's OK to "prefer einsum" as long as you have a mechanism for side-stepping any potential issues related to fusion optimizations that might be critical to your workloads. In the case of MHA or other critical fusions for popular models right now, we are adding a `tensorrt.attention` operation, which should obviate the issue with respect to einsum vs. matmul.
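For illustration, here is a minimal sketch of the kind of option guard described above. The helper name, the `preferEinsum` flag, and the template parameter are hypothetical; this is not the actual pass plumbing, just the shape of "only register the matmul-to-einsum rewrite when the user opts in":

```cpp
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical sketch: the matmul->einsum rewrite is only registered when
// the caller opts in, so pipelines that rely on tensorrt.matrix_multiply
// pattern recognition are unaffected by default. The pattern class is taken
// as a template parameter so this skeleton compiles stand-alone.
template <typename MatmulToEinsumPattern>
static void populateEliminationPatterns(RewritePatternSet &patterns,
                                        bool preferEinsum) {
  if (preferEinsum)
    patterns.add<MatmulToEinsumPattern>(patterns.getContext());
  // ... unconditional transpose/reshape rules would be added here ...
}
```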
[Resolved review threads on TensorRTOps.td, Passes.td (2, outdated), and TransposeReshapeElimination.cpp (7, mostly outdated).]
Thanks for the thorough review @christopherbate. I will start working to address the comments on the PR.

The original motivation for eliminating shuffles/transposes/reshapes was that they broke pattern matching in TensorRT in the first place. See this bug for additional context: https://partners.nvidia.com/Bug/ViewBug/5381960

Based on observations I have made of TensorRT, there seems to be no issue using einsum to represent matrix multiplies, in that it generates the same kernels (but I can't be 100% sure of this, given that TRT is closed source and I don't have access to its source).
Force-pushed c960e12 to 21f60d1.
Hi @christopherbate, I have addressed the comments you added to the PR and updated the PR.
Force-pushed 21f60d1 to 5486d22.
@matthewfl sorry for the delay; I'll re-review tomorrow and check that it passes internal tests.
shelkesagar29 left a comment:
Currently there are two issues because of which e2e models like Whisper-JAX and GPT are failing:

- In the `MoveReshapeBeforeTranspose` pattern. I have commented there with a minimal failing test case.
- In the `EinsumPushDownTranspose` pattern. I have commented with a failing test and how to fix the issue.

Once these two are fixed, we are good to merge this one.

This is really good work. Thanks.
[Resolved review threads on TransposeReshapeElimination.cpp, Passes.td (2, outdated), TensorRTOps.td (outdated), and transpose-reshape-elimination.mlir (outdated).]
```cpp
int64_t cost2 = memoryCost(consumer.getType()) + memoryCost(op1.getType());
LLVM_DEBUG(DBGS() << "cost1=" << cost1 << ", cost2=" << cost2 << "\n");
if (cost1 == 0 && cost2 == 0)
  return {};
```
Shouldn't we return one valid op here?
At the call site, we are not checking for a null op.
This fixes ping-ponging (a rewrite loop that never terminates) on the following test:
c3b4d85#diff-e275a939421ea167dbf564d5fe9b866e458b4ee87f70c39e08cccbd333d7b44eR474
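For context, the actual helper returns an op, with `{}` meaning "don't swap". Below is a minimal sketch of the same tie-breaking rule expressed with `LogicalResult`, under assumed names (`memoryCost` here is a stand-in for the real cost model, not the PR's implementation):

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

// Assumed cost model: number of elements moved, or 0 when the shape is
// dynamic (i.e. the heuristic has no information).
static int64_t memoryCost(RankedTensorType type) {
  return type.hasStaticShape() ? type.getNumElements() : 0;
}

// Decide whether swapping the two ops is profitable. Declining on a tie
// (cost1 == 0 && cost2 == 0) is what breaks the cycle: if ties allowed the
// swap, the inverse pattern could immediately swap back and the greedy
// rewrite driver would never terminate.
static LogicalResult shouldSwap(int64_t cost1, int64_t cost2) {
  if (cost1 == 0 && cost2 == 0)
    return failure();
  return success(cost2 < cost1);
}
```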
```cpp
// einsum is more flexible with its inputs, broadcasting dimensions, and
// transposing. Hence, we can easily implement rewrites that merge a transpose
// into an einsum and push a reshape through an einsum.
class MatmulToEinsum : public OpRewritePattern<tensorrt::MatrixMultiplyOp> {
```
As Chris commented before, we don't want to generally convert MatMul to Einsum.
Is the einsum created by this pattern converted back to matmul?
The einsums are converted back to matmul if they match the matmul pattern after all of the transposes have been eliminated.
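For illustration, a sketch of how a matmul-equivalent einsum equation can be built mechanically, assuming equal operand ranks, matrix dimensions last, and no per-operand transpose modes (the real pattern would also have to handle those); the same string shape is what a later einsum-to-matmul match would look for:

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Build an einsum equation equivalent to a batched matrix multiply of two
// rank-`rank` tensors, e.g. rank 2 -> "xy,yz->xz", rank 3 -> "axy,ayz->axz".
// Leading dimensions become batch labels shared by both operands.
static std::string matmulEquation(int64_t rank) {
  assert(rank >= 2 && "matrix multiply needs at least two dimensions");
  std::string batch;
  for (int64_t i = 0; i < rank - 2; ++i)
    batch.push_back(static_cast<char>('a' + i));
  return batch + "xy," + batch + "yz->" + batch + "xz";
}
```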
[Two resolved review threads on TransposeReshapeElimination.cpp (one outdated).]
```cpp
// Create a new transpose op from an einsum. Rearrange the output axes to
// match the ordering of the input axes. This should enable converting the
// einsum back to a matmul: einsum(x1, x2, ...) -> transpose(einsum(x1, x2, ...))
class EinsumPushDownTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
```
[Nit]
The name of the pattern is confusing. It reads as if we are pushing the transpose above the einsum down below it.
Perhaps we can use another name like `CanonicalizeEinsumForMatmul`.
There are a few patterns that help canonicalize einsum to match matmul: `EinsumPushDownTranspose`, `EinsumPushUpTranspose`, and `EinsumPushUpMultipleMulitipliedAxes`.
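As an illustration of the push-down idea (assumed helper, not the code from this PR): reorder the einsum's output subscripts to follow the first-appearance order of labels across the inputs, and record the permutation for the compensating transpose that restores the original output layout.

```cpp
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Given the einsum output subscripts and the labels in first-appearance
// order across the inputs, return the reordered output subscripts plus the
// permutation `perm` such that transpose(perm) applied to the reordered
// einsum result reproduces the original layout. Assumes every output label
// appears among the input labels. E.g. for "ij,jk->ki": inputLabelOrder =
// "ijk", output = "ki" -> reordered = "ik" (a matmul shape), perm = {1, 0}.
static std::pair<std::string, std::vector<int64_t>>
reorderEinsumOutput(const std::string &output,
                    const std::string &inputLabelOrder) {
  std::string reordered;
  for (char c : inputLabelOrder)
    if (output.find(c) != std::string::npos)
      reordered.push_back(c);
  std::vector<int64_t> perm;
  for (char c : output)
    perm.push_back(static_cast<int64_t>(reordered.find(c)));
  return {reordered, perm};
}
```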
@shelkesagar29 Hi Sagar, I pushed the code that makes sure the multi-head attention fusion still works. You should be able to test it on your models now. Let me know if you want me to rebase this on the main branch or squash the commits. I only added new commits, so it should be easier to merge with the copy you were testing.
Hi Matthew, |
@shelkesagar29 Thanks for the report. I just pushed a fix for the issue that you reported.
Force-pushed a1c0d1e to 95f586c.
@shelkesagar29 I just pushed a fix for the bug you sent via email.
shelkesagar29 left a comment:
All tests pass.

We can merge the PR after the following are done:

- Can you please check one more time for adherence to the LLVM code style? https://llvm.org/docs/CodingStandards.html#don-t-use-braces-on-simple-single-statement-bodies-of-if-else-loop-statements I could find many places where this doesn't hold.
- Address the current comments.
- Rebase into a single commit with a canonical and clear message. This is a big change.

Thank you so much for your time and patience.
[Resolved review threads on transpose-reshape-elimination.mlir (outdated) and TransposeReshapeElimination.cpp (7, outdated).]
Force-pushed 574e157 to 50e9d57.
This is a new pass that is designed to replace the Transpose
and Reshape Elimination passes. This pass adds many new rewrite
rules which enable pushing the transposes and reshapes around so that
they can be combined and then eliminated.

The motivation for this pass is that there are some cases where shuffles
can get inserted around matrix multiplications and element-wise ops,
which break various fusions inside of TensorRT.

To accomplish this, the pass uses several rewrite rules that push transposes
and reshapes around to combine them into identity transposes and reshapes,
which can then be eliminated from the program. The rewrite rules are as
follows (a structural sketch of the phased application follows the list):
1. "canonicalize" the network into simpler ops
- `shuffle(x)` -> `reshape(transpose(reshape(x)))`
- `matrix_multiply(x, y)` -> `einsum("ij,jk->ik", x, y)`
- `expand_rank(x)` -> `reshape(x)`
- `collapse_rank(x)` -> `reshape(x)`
2. Push down `reshape` and `transpose` ops as much as possible. Merging and eliminating when possible
- `einsum(transpose(x), ...)` -> `einsum(x, ...)` Merge transpose into einsum
- `einsum(...)` -> `transpose(einsum(...))` Pull transpose out of einsum (to try to match matrix multiply pattern)
- `einsum(reshape(x), y, ...)` -> `transpose(reshape(einsum(x, reshape(transpose(y)), ...)))` Push reshape down. Possibly add reshape and transposes to other inputs as needed. Conditioned on heuristic checking if "better"
- `unary(transpose(x))` -> `transpose(unary(x))`
- `activation(transpose(x))` -> `transpose(activation(x))`
- `identity_op(transpose(x))` -> `transpose(identity_op(x))`
- `activation(reshape(x))` -> `reshape(activation(x))`
- `unary(reshape(x))` -> `reshape(unary(x))`
- `identity_op(reshape(x))` -> `reshape(identity_op(x))`
- `reshape(transpose(x))` -> `transpose(reshape(x))` if possible put reshape before transpose
- `qdq(transpose(x))` -> `transpose(qdq(x))` if the scale is 0-dim
- `qdq(reshape(x))` -> `reshape(qdq(x))` if the scale is 0-dim
- `reshape(reshape(x))` -> `reshape(x)`
- `transpose(transpose(x))` -> `transpose(x)`
- `reshape(x)` -> `x` if `reshape` is identity
- `transpose(x)` -> `x` if `transpose` is identity
- `elementwise(reshape(a), b)` -> `reshape(elementwise(a, reshape(b)))` conditioned on heuristic
- `elementwise(transpose(a), b)` -> `transpose(elementwise(a, transpose(b)))`
- `softmax(transpose(x))` -> `transpose(softmax(x))`
- `softmax(reshape(x))` -> `reshape(softmax(x))`
3. Push up `reshape` and `transpose` ops as much as possible. Merging and eliminating when possible
- `transpose(einsum(...))` -> `einsum(...)`. Merge transpose into einsum
- `einsum(...)` -> `einsum(transpose(x), ...)`. Pull transposes out of einsum (to try to match matrix multiply pattern)
- `reshape(einsum(...))` -> `einsum(reshape(transpose(x)), ...)` Push reshapes up through einsum. Adding transposes as needed
- `transpose(activation(x))` -> `activation(transpose(x))`
- `transpose(unary(x))` -> `unary(transpose(x))`
- `transpose(identity_op(x))` -> `identity_op(transpose(x))`
- `reshape(activation(x))` -> `activation(reshape(x))`
- `reshape(unary(x))` -> `unary(reshape(x))`
- `reshape(identity_op(x))` -> `identity_op(reshape(x))`
- `reshape(reshape(x))` -> `reshape(x)`
- `transpose(transpose(x))` -> `transpose(x)`
- `reshape(x)` -> `x` if `reshape` is identity
- `transpose(x)` -> `x` if `transpose` is identity
- `transpose(reshape(x))` -> `reshape(transpose(x))` if possible put transpose before reshape
- `transpose(qdq(x))` -> `qdq(transpose(x))` if the scale is 0-dim
- `reshape(qdq(x))` -> `qdq(reshape(x))` if the scale is 0-dim
- `reshape(elementwise(a, b))` -> `elementwise(reshape(a), reshape(b))`
- `transpose(elementwise(a, b))` -> `elementwise(transpose(a), transpose(b))`
- `transpose(softmax(x))` -> `softmax(transpose(x))`
- `reshape(softmax(x))` -> `softmax(reshape(x))`
4. Convert back to matrix multiplication form to assist with TRT's pattern matching
- `einsum(x, y)` -> `matrix_multiply(x, y)` if einsum matches a matrix multiply pattern
- `matrix_multiply(transpose(x), y)` -> `matrix_multiply(x, y)` merge transpose if possible
5. Final clean ups, additional merging of transpose/reshapes into leftover einsums
- `einsum(x, y)` -> `matrix_multiply(x, y)` if einsum matches a matrix multiply pattern
- `matrix_multiply(transpose(x), y)` -> `matrix_multiply(x, y)` merge transpose if possible
- `transpose(einsum(...))` -> `einsum(...)`
- `einsum(transpose(x), ...)` -> `einsum(...)`
- `einsum(collapse_rank(x), ...)` -> `einsum(...)`
- `expand_rank(einsum(...))` -> `einsum(...)`
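A structural sketch of how such a phased pass can be organized (pattern-set names are hypothetical, and the driver entry point is the long-standing `applyPatternsAndFoldGreedily`, which newer MLIR versions rename): each numbered step above gets its own pattern set, applied greedily to a fixed point before the next phase runs.

```cpp
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Run the five phases in order. Each RewritePatternSet would be populated
// with the corresponding rules from the list above (omitted here); applying
// an empty set is a no-op, so this skeleton compiles and runs as-is.
static LogicalResult runTransposeReshapeElimination(Operation *root) {
  MLIRContext *ctx = root->getContext();
  RewritePatternSet canonicalize(ctx); // step 1: shuffle/matmul -> reshape/transpose/einsum
  RewritePatternSet pushDown(ctx);     // step 2: sink reshapes/transposes
  RewritePatternSet pushUp(ctx);       // step 3: hoist reshapes/transposes
  RewritePatternSet toMatmul(ctx);     // step 4: einsum -> matrix_multiply
  RewritePatternSet cleanup(ctx);      // step 5: final merges
  for (RewritePatternSet *phase :
       {&canonicalize, &pushDown, &pushUp, &toMatmul, &cleanup})
    if (failed(applyPatternsAndFoldGreedily(root, std::move(*phase))))
      return failure();
  return success();
}
```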
Force-pushed 50e9d57 to a7d0a6b.
@shelkesagar29 I updated the PR. Cleaned up the formatting. Squashed the commits down to a single commit and rebased on the current `main`. You can go ahead and click merge on the PR.

Done. Thanks for all the hard work and patience.
The `TransposeReshapeElimination` pass is designed to subsume the existing Transpose and Reshape Elimination passes. This pass adds many new rewrite rules which enable pushing the transposes and reshapes around so that they can be combined and then eliminated. The rules from the existing `TransposeElimination` are copied into the `TransposeReshapeElimination.cpp` file. The rules from the `ReshapeElimination` pass should be subsumed by the rules added to `TransposeReshapeElimination`.

The motivation for this pass is that there are some cases where shuffles can get inserted around matrix multiplications and element-wise ops, which break various fusions inside of TensorRT.
The process is the same five-step sequence of rewrite rules listed in the commit message above:

1. "Canonicalize" the network into simpler ops: `shuffle`, `matrix_multiply`, `expand_rank`, and `collapse_rank` become `reshape`/`transpose`/`einsum` forms.
2. Push `reshape` and `transpose` ops down as much as possible, merging and eliminating when possible.
3. Push `reshape` and `transpose` ops up as much as possible, merging and eliminating when possible.
4. Convert back to matrix-multiplication form to assist with TRT's pattern matching.
5. Final cleanups: additional merging of transposes/reshapes into leftover einsums.

NOTE: The overarching goal of this PR is to improve the pattern matching inside of TensorRT (and therefore the quality of the kernels that TensorRT can generate, and the fusions that TensorRT will perform).
I have some empirical evidence that the generated MLIR seems to be an improvement; however, I am still not 100% sure of the best way to generate MLIR in some of these edge cases when it comes to getting the fastest model out of TensorRT.