[Backend] Fix incorrect shared layout for dot operands rank==3 #4944
AlexAUT wants to merge 1 commit into triton-lang:main
Conversation
ff0d816 to 33941e2
lezcano left a comment
The changes to wmma were not intentional; it was a land race with #4538. Feel free to revert the change, making ReduceDataDuplication's condition apply only to Ampere.
That being said, would it make sense here to simply use dstOrder as the order, similar to how it's done in triton/lib/Analysis/Allocation.cpp (line 124 in 692143c)? This one already gives you the order you want for rank=3, while the current one could end up in a funny state where the order does not match that of the input or the output.
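To make that concrete, here is a minimal sketch of what using the destination order could look like, reusing the gpu::getThreadOrder helper and the dstEncoding/sharedOrder names from the snippet discussed below; this is an illustration, not the actual Allocation.cpp code:

// Illustrative sketch (not the actual patch): take the order from the
// destination (dot operand) encoding for every rank, so the shared layout
// always matches the output order instead of mixing source order (rank != 3)
// and destination order (rank == 3).
auto sharedOrder = gpu::getThreadOrder(dstEncoding);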
33941e2 to fab393d
Thanks for the feedback. I reverted the change in ReduceDataDuplication. I also switched to gpu::getThreadOrder(dstEncoding).
d6b00d2 to 2d4751d
if (rank == 3) {
  sharedOrder = gpu::getThreadOrder(dstEncoding);
} else {
  sharedOrder = srcOrder;
}
If you are going this route, you probably want to do it for all ranks; otherwise this heuristic would be incredibly counterintuitive.
But then we might change the shared layout order for rank != 3 with this change? I can also revert all changes except for the condition in ReduceDataDuplication to make it Ampere-specific. Then we would have the same behavior as before #4904.
I just find the current behaviour for rank==3 very weird. @Jokeren thoughts?
Actually, block to dot decomposition can be deprecated soon:
main...keren/dot-mma#diff-30fb59df648f6c4ec5db24c51ef728a8d56104e151bd52e8d482b6951c207aa7R95
I agree the special condition is weird. Can you attach a test case for us to take a look at?
Running python/test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-float16-float32] triggers this assert. It happens when wmma is used for the dot. The shared -> dot layout conversion for wmma also expects the batch dim to be the slowest dimension.
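As a side note, here is a small self-contained sketch of what that expectation means, assuming Triton's convention that an order vector lists dimensions from fastest- to slowest-varying and that the batch dim of a 3D dot is dim 0 (both are assumptions spelled out here, not taken from this PR):

// Illustrative only: for a rank-3 dot operand the batch dim (dim 0) must be
// the slowest-varying dimension, i.e. the last entry of the order vector.
#include <cassert>
#include <vector>

int main() {
  std::vector<unsigned> batchSlowest = {2, 1, 0}; // satisfies the wmma expectation
  std::vector<unsigned> batchFastest = {0, 2, 1}; // would trip the assert described above
  assert(batchSlowest.back() == 0);
  assert(batchFastest.back() != 0);
  return 0;
}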
2d4751d to 2e796e6
#4904 moved the layout rewrite for wmma dot operands from blocked->mma to blocked->shared->mma, i.e. from ReduceDataDuplication to DecomposeUnsupportedConversion. However, DecomposeUnsupportedConversion was missing the special case for rank==3, which is copied over in this PR.
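For reference, the carried-over special case looks roughly like the snippet discussed in the review above (a sketch only; rank, srcOrder, and dstEncoding are assumed to be the surrounding variables in DecomposeUnsupportedConversion):

// Sketch of the rank==3 special case copied into the pass: take the order
// from the destination (dot operand) encoding so the batch dim stays the
// slowest dimension; keep the source order for all other ranks.
SmallVector<unsigned> sharedOrder;
if (rank == 3) {
  sharedOrder = gpu::getThreadOrder(dstEncoding);
} else {
  sharedOrder = srcOrder;
}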