
[Backend] Fix incorrect shared layout for dot operands rank==3 #4944

Closed
AlexAUT wants to merge 1 commit into triton-lang:main from AlexAUT:fixInvalidSharedLayoutForDot3d

Conversation

@AlexAUT
Contributor

@AlexAUT AlexAUT commented Oct 17, 2024

#4904 moved the layout rewrite for WMMA dot operands from blocked->mma to blocked->shared->mma, moving it from ReduceDataDuplication to DecomposeUnsupportedConversions. However, DecomposeUnsupportedConversions was missing the special case for rank==3, which this PR copies over.

@antiagainst antiagainst changed the title from "[AMD][Backend] Fix incorrect shared layout for wmma dot operands for rank==3" to "[Backend] Fix incorrect shared layout for wmma dot operands for rank==3" on Oct 17, 2024
@antiagainst antiagainst changed the title from "[Backend] Fix incorrect shared layout for wmma dot operands for rank==3" to "[Backend] Fix incorrect shared layout for dot operands rank==3" on Oct 17, 2024
@AlexAUT AlexAUT force-pushed the fixInvalidSharedLayoutForDot3d branch 2 times, most recently from ff0d816 to 33941e2 on October 18, 2024 08:40
Contributor

@lezcano lezcano left a comment

The changes to WMMA were not intentional; it was a land race with #4538. Feel free to revert the change making ReduceDataDuplication's condition apply only to Ampere.

That being said, would it make sense here to simply use dstOrder as the order, similar to how it's done in

scratchConfig.order = outOrd;

That one already gives you the order you want for rank==3, while the current one could end up in a funny state where the order matches neither the input nor the output.
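
As a rough illustration of that suggestion, here is a minimal self-contained sketch (the helper pickSharedOrder and its shape are invented for this example and are not code from the PR or the Triton tree):

#include <cassert>
#include <vector>

// Order vectors list dimensions from fastest- to slowest-varying.
std::vector<unsigned> pickSharedOrder(const std::vector<unsigned> &srcOrder,
                                      const std::vector<unsigned> &dstOrder,
                                      bool alwaysUseDstOrder) {
  if (alwaysUseDstOrder)
    return dstOrder; // stay consistent with the destination layout for every rank
  // Behaviour of the patch under discussion: only rank == 3 follows the
  // destination order, all other ranks keep the source order.
  return srcOrder.size() == 3 ? dstOrder : srcOrder;
}

int main() {
  // Rank-3 dot operand [B, M, K]: the batch dim (0) should end up slowest,
  // i.e. last in the order vector.
  std::vector<unsigned> src = {0, 2, 1}, dst = {2, 1, 0};
  assert(pickSharedOrder(src, dst, /*alwaysUseDstOrder=*/true).back() == 0);
  return 0;
}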

@AlexAUT AlexAUT force-pushed the fixInvalidSharedLayoutForDot3d branch from 33941e2 to fab393d on October 18, 2024 16:52
@AlexAUT
Contributor Author

AlexAUT commented Oct 18, 2024

Thanks for the feedback. I reverted the change in ReduceDataDuplication.

I also switched to getThreadOrder(dstEncoding) and added it to DecomposeUnsupportedConversions to prevent a future regression.

@AlexAUT AlexAUT force-pushed the fixInvalidSharedLayoutForDot3d branch 3 times, most recently from d6b00d2 to 2d4751d on October 18, 2024 18:15
Comment on lines +105 to +109
if (rank == 3) {
sharedOrder = gpu::getThreadOrder(dstEncoding);
} else {
sharedOrder = srcOrder;
}
Contributor

If you are going this route, you probably want to do it for all ranks, otherwise this heuristic would be incredibly counterintuitive.

Contributor Author

But then we might change the shared layout order for rank != 3 as well? I can also revert all changes except the ReduceDataDuplication condition change that makes it Ampere-specific. Then we would have the same behavior as before #4904.

Contributor

I just find the current behaviour for rank==3 very weird. @Jokeren thoughts?

Contributor

I agree the special condition is weird. Can you attach a test case for us to take a look at?

Contributor Author

@AlexAUT AlexAUT Oct 21, 2024

Running python/test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-float16-float32] triggers this assert. It happens when WMMA is used for the dot: the shared -> dot layout conversion for WMMA also expects the batch dim to be the slowest dimension.
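
As a rough, self-contained illustration of that expectation (an invented check, not the actual assert from the backend):

#include <cassert>
#include <vector>

// For a rank-3 dot operand of shape [B, M, K], the WMMA shared -> dot-operand
// conversion described above assumes the batch dimension (dim 0) is the
// slowest-varying one. With orders listed from fastest- to slowest-varying,
// that means dim 0 must come last.
void checkBatchDimIsSlowest(const std::vector<unsigned> &sharedOrder) {
  assert(sharedOrder.size() == 3 && sharedOrder.back() == 0 &&
         "rank-3 shared layout must keep the batch dim as the slowest dim");
}

int main() {
  checkBatchDimIsSlowest({2, 1, 0}); // ok: dim 2 fastest, batch dim 0 slowest
  // An order such as {0, 2, 1} (batch dim fastest) would trip the check.
  return 0;
}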

@antiagainst antiagainst marked this pull request as ready for review October 18, 2024 22:28
@antiagainst antiagainst requested a review from ptillet as a code owner October 18, 2024 22:28
@AlexAUT AlexAUT force-pushed the fixInvalidSharedLayoutForDot3d branch from 2d4751d to 2e796e6 on October 19, 2024 11:03
@AlexAUT
Contributor Author

AlexAUT commented Oct 22, 2024

#4950 removed the condition, which restores the old behavior for WMMA, so this problem is fixed now. Thanks @lezcano.

@AlexAUT AlexAUT closed this Oct 22, 2024
@AlexAUT AlexAUT deleted the fixInvalidSharedLayoutForDot3d branch February 6, 2025 11:49