[Reland][NVIDIA] Support swizzle 0 TMA + MMA for Hopper and Blackwell #10148
Merged
72 commits
8317097
Add swizzle=0 TCGen5 operand-view memdesc rewrite and lit test
masahi 1939857
cmake fix
masahi 7d1e42c
works
masahi a86d083
make it work for other dot ops
masahi d2955e7
fix
masahi 28d35fa
fix
masahi 638c3b0
[TritonGPU] Match swizzle0 operand-view rewrite from local_load sourc…
masahi 3375a12
[TritonGPU] Use source shared encoding for swizzle0 operand-view rewrite
masahi 9f559e9
fix
masahi 390b118
clean
masahi 3782068
simplify
masahi 8707f6d
remove pattern matching against desc load
masahi 5ea9724
upd lit test
masahi 12cb8e0
fix
masahi 07119d3
fix for bw
masahi 746c28a
update bw lit
masahi 1d02e00
update for hop
masahi be6eb93
upd
masahi 0fa2e71
upd
masahi 5e45dac
clean test
masahi e7d54f8
refactoring operand update
masahi 3291122
wip
masahi 6637c0d
more
masahi 9dcce40
refactor
masahi 9144860
wip
masahi da8d60c
fix
masahi a41052a
more clean
masahi d3eee96
add comment
masahi b9b6eb4
remove stale include
masahi 0699532
Merge branch 'main' into tma-mma-swizzle-0
masahi 2cda92b
add comment describing the rewrite pattern
masahi dcf62c0
minor
masahi 6163ab9
Merge branch 'main' into tma-mma-swizzle-0
masahi 8aec72f
revert cmake change
masahi fbae09b
update comment to make it more accurate
masahi 4b986f3
Merge branch 'main' into tma-mma-swizzle-0
masahi e01ce66
Make swizzle0 operand view rewrite sink-driven
masahi c388478
Clean up sink-driven dot operand rewrite
masahi b9bb708
Refine sink-driven operand rewrite checks
masahi 1133abd
Generalize dot operand view rewrite naming
masahi ae6782c
Remove stale swizzle0 host descriptor test
masahi ffa4f6f
revert unnecessary test change
masahi 9679359
Restore template dispatch for dot operand updates
masahi e315bf2
Use inferSrcEncoding in dot operand rewrite
masahi 02dcdba
Simplify dot operand rewiring after rewrite
masahi 68fe5ac
Move MMA operand view rewrite into NVIDIA pass
masahi df2f6f9
Simplify MMA operand view rewrite
masahi 52f2848
precommit
masahi a77c439
Revert to the old backward inference impl, run the pass before ODE
masahi 6e07bb6
pre commit
masahi 9766e51
Merge branch 'main' into tma-mma-swizzle-0
masahi e093192
Update descriptor rewrite for new tensordesc type
masahi 4f97dc1
Keep descriptor layouts non-transposed
masahi 3dee2de
Simplify MMA operand view replay steps
masahi f70af5b
Use DotOpInterface in MMA view rewrite
masahi 87eb143
Move MMA operand view rewrite into ODO
masahi 72859e0
precommit
masahi 3130b82
inline helpers
masahi 6879828
Merge branch 'tma-mma-swizzle-0' into swizzle-0-fix
masahi a3e56e0
[TritonNvidiaGPU] Avoid fatal ODE layout probes
masahi 102193e
[TritonGPU] Restrict operand-view rewrite to shared_linear
masahi 30f12e1
Merge branch 'main' into swizzle-0-fix
masahi 62b3a08
format
masahi de8af7a
[TritonGPU] Simplify TMA block shape diagnostics
masahi 2bb6ccd
[TritonGPU] Simplify TMA block shape error helper
masahi 6963198
[TritonGPU] Drop stale TMA helper suffixes
masahi ed11c36
minor change in LinearLayoutConversions.cpp
masahi ff09a8a
inline error emit
masahi b7f1acd
more inline error msg
masahi 0171037
remove tryGetTMABlockShape
masahi 58c3b95
removed tryNvmmaSharedToLinearLayout by adding a safe version of nvmm…
masahi 157f580
Merge branch 'main' into swizzle-0-fix
masahi
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -193,16 +193,15 @@ LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared, | |
| return LinearLayout({{S("offset"), bases2D}}, outDimNames); | ||
| } | ||
|
|
||
| LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape, | ||
| NVMMASharedEncodingAttr shared, | ||
| TMAMode mode, bool disableSwizzle) { | ||
| static FailureOr<LinearLayout> buildNvmmaSharedLinearLayout( | ||
| ArrayRef<int64_t> shape, NVMMASharedEncodingAttr shared, | ||
| ArrayRef<int64_t> tmaShape, bool disableSwizzle, bool emitErrors) { | ||
| if (!llvm::all_of(tmaShape, llvm::isPowerOf2_64)) | ||
| return failure(); | ||
| MLIRContext *ctx = shared.getContext(); | ||
| int rank = shape.size(); | ||
| auto shapePerCTA = getShapePerCTA(shared, shape); | ||
| auto kOffset = S("offset"); | ||
| auto tmaShape = | ||
| triton::nvidia_gpu::getTMABlockShape(shared, shapePerCTA, | ||
| /*packedSize=*/true, mode); | ||
| if (shared.getSwizzlingByteWidth() == 0) { | ||
| auto outDimNames = standardOutDimNames(ctx, rank); | ||
| LinearLayout layout = LinearLayout::identity1D(tmaShape[rank - 1], kOffset, | ||
|
|
@@ -234,20 +233,23 @@ LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape, | |
| int packingFactor = shared.getFp4Padded() ? 2 : 1; | ||
| if (collapsedTmaShape[1] * packingFactor < tileCols || | ||
| collapsedTmaShape[0] < tileRows) { | ||
| llvm::errs() << "Illegal shared layout; expected collapsed shapePerCTA to " | ||
| "be at least [" | ||
| << tileRows << ", " << (tileCols / packingFactor) | ||
| << "], collapsedTmaShape: [" << collapsedTmaShape[0] << ", " | ||
| << collapsedTmaShape[1] << "]\n"; | ||
| llvm::report_fatal_error("Illegal shared layout"); | ||
| if (emitErrors) { | ||
| llvm::errs() << "Illegal shared layout; expected collapsed shapePerCTA " | ||
| "to be at least [" | ||
| << tileRows << ", " << (tileCols / packingFactor) | ||
| << "], collapsedTmaShape: [" << collapsedTmaShape[0] << ", " | ||
| << collapsedTmaShape[1] << "]\n"; | ||
| } | ||
| return failure(); | ||
| } | ||
|
|
||
| // Distribute the remaining rows and cols. | ||
| auto layout = | ||
| ensureLayoutNotSmallerThan(tileLayout, outDimNames, collapsedTmaShape); | ||
|
|
||
| // Reshape the layout to the N-D pre-transposed shape per CTA. | ||
| SmallVector<int64_t> maybeTransposedTmaShape = tmaShape; | ||
| SmallVector<int64_t> maybeTransposedTmaShape(tmaShape.begin(), | ||
| tmaShape.end()); | ||
| if (shared.getTransposed()) { | ||
| // Move the outer dim to the inner position. | ||
| // TODO: we should move back to using `order` instead of transposed to make | ||
|
|
@@ -256,6 +258,10 @@ LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape, | |
| maybeTransposedTmaShape.begin() + 1, | ||
| maybeTransposedTmaShape.end()); | ||
| } | ||
| // This condition can fail if a layout is speculatively constructed for | ||
| // equivalence checking. | ||
| if (layout.getTotalOutDimSize() != product(maybeTransposedTmaShape)) | ||
| return failure(); | ||
| auto reshapedLayout = reshapeLayout(ctx, layout, maybeTransposedTmaShape); | ||
|
|
||
| if (shared.getTransposed()) { | ||
|
|
@@ -272,6 +278,42 @@ LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape, | |
| return combineCtaCgaWithShape(reshapedLayout, shared.getCGALayout(), shape); | ||
| } | ||
|
|
||
| LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape, | ||
| NVMMASharedEncodingAttr shared, | ||
| TMAMode mode, bool disableSwizzle) { | ||
| auto layout = nvmmaSharedToLinearLayout(shape, shared, mode, disableSwizzle, | ||
| /*emitErrors=*/true); | ||
| if (failed(layout)) | ||
| llvm::report_fatal_error("Illegal shared layout"); | ||
|
Comment on lines +286 to +287
Collaborator: I think it is fine to keep that along with the emitError
Collaborator (Author): I hope the current code after 58c3b95 has addressed this comment |
||
| return *layout; | ||
| } | ||
|
|
||
| FailureOr<LinearLayout> | ||
| nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape, | ||
| NVMMASharedEncodingAttr shared, TMAMode mode, | ||
| bool disableSwizzle, bool emitErrors) { | ||
| auto shapePerCTA = getShapePerCTA(shared, shape); | ||
| SmallVector<int64_t> tmaShape; | ||
| if (emitErrors) { | ||
| tmaShape = | ||
| getTMABlockShape(shapePerCTA, shared.getElementBitWidth(), | ||
| shared.getSwizzlingByteWidth(), shared.getFp4Padded(), | ||
| shared.getTransposed(), /*packedSize=*/true, mode); | ||
| } else { | ||
| auto maybeTmaShape = | ||
| getTMABlockShape(shapePerCTA, shared.getElementBitWidth(), | ||
| shared.getSwizzlingByteWidth(), shared.getFp4Padded(), | ||
| shared.getTransposed(), /*packedSize=*/true, | ||
| /*emitError=*/nullptr, mode); | ||
| if (failed(maybeTmaShape)) | ||
| return failure(); | ||
| tmaShape = *maybeTmaShape; | ||
| } | ||
|
|
||
| return buildNvmmaSharedLinearLayout(shape, shared, tmaShape, disableSwizzle, | ||
| emitErrors); | ||
| } | ||
|
|
||
| /// Function to generate lane and warp layout for dot operands. | ||
| static LinearLayout broadcastedDotOperandLayout(MLIRContext *ctx, | ||
| ArrayRef<unsigned> shape, | ||
|
|
||
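The diff above splits layout construction into a fallible core that returns `FailureOr<LinearLayout>` and a thin wrapper that keeps the old fatal behavior. The shape of that pattern can be sketched in plain C++ as below. This is only an illustration under stated assumptions: `ToyLayout`, `tryBuildToyLayout`, and `buildToyLayoutOrDie` are hypothetical names, `std::optional` stands in for MLIR's `FailureOr`, and `std::abort` stands in for `llvm::report_fatal_error`. The toy also prints a diagnostic for the power-of-two check, whereas the PR code returns `failure()` silently at that point.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <optional>
#include <vector>

// Hypothetical stand-in for LinearLayout: it only records how many
// elements the layout covers.
struct ToyLayout {
  int64_t numElements;
};

// True when v is a positive power of two.
bool isPowerOfTwo(int64_t v) { return v > 0 && (v & (v - 1)) == 0; }

// Fallible core: returns std::nullopt instead of aborting, so callers
// that only probe candidate layouts can recover. Diagnostics are printed
// only when emitErrors is set.
std::optional<ToyLayout> tryBuildToyLayout(const std::vector<int64_t> &blockShape,
                                           bool emitErrors) {
  int64_t total = 1;
  for (int64_t dim : blockShape) {
    if (!isPowerOfTwo(dim)) {
      if (emitErrors)
        std::cerr << "Illegal block shape: dimension " << dim
                  << " is not a power of two\n";
      return std::nullopt;
    }
    total *= dim;
  }
  return ToyLayout{total};
}

// Fatal wrapper: keeps the original "abort on illegal layout" contract
// for callers that cannot handle failure.
ToyLayout buildToyLayoutOrDie(const std::vector<int64_t> &blockShape) {
  std::optional<ToyLayout> layout =
      tryBuildToyLayout(blockShape, /*emitErrors=*/true);
  if (!layout) {
    std::cerr << "Illegal shared layout\n";
    std::abort();
  }
  return *layout;
}
```

A speculative caller would use `tryBuildToyLayout(shape, /*emitErrors=*/false)` and skip candidates that come back empty, while existing call sites keep using the fatal wrapper unchanged.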
where can this happen exactly? it feels like quite a big issue.
A concrete test case that fails this condition is this one: https://github.com/masahi/triton/blob/58c3b956958f572e1f6bfa3ddbd865c9cac40763/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir#L180-L188
We call `buildNvmmaSharedLinearLayout` with shape [1, 16, 1, 16] and various candidate nvmma_shared encodings. For some candidates, it seems `ensureLayoutNotSmallerThan` can return a layout that covers more than [1, 16, 1, 16] output elements.
I still get the feeling that there's a better place to catch this one than this late, but sure.
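The guard discussed in this thread compares the number of output elements a layout covers against the product of the target shape before reshaping. A minimal sketch of that check follows, assuming a layout can be summarized by its element count alone. `shapeProduct` and `tryReshape` are hypothetical helpers, not Triton APIs, and the element count of 512 used in the usage note is an invented value chosen only to show the failing case.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <optional>
#include <vector>

// Product of all dimensions of a shape, e.g. [1, 16, 1, 16] -> 256.
int64_t shapeProduct(const std::vector<int64_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

// Hypothetical reshape guard: a reshape is only well defined when the
// layout covers exactly as many elements as the target shape holds.
// Returns the target shape on success and std::nullopt on a mismatch.
std::optional<std::vector<int64_t>>
tryReshape(int64_t layoutOutElements, const std::vector<int64_t> &targetShape) {
  if (layoutOutElements != shapeProduct(targetShape))
    return std::nullopt;
  return targetShape;
}
```

With the shape from the test case, `tryReshape(256, {1, 16, 1, 16})` succeeds, while a candidate that was grown to cover 512 elements is rejected instead of reaching the reshape.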