Layout conversion bypass for blocked to dotOperand #4538

Draft
wants to merge 2 commits into
base: main

@binarman binarman commented Aug 19, 2024

This PR extends the shared memory bypass to blocked->dot operand conversions and adds a bypass check in DecomposeUnsupportedConversions and ReduceDataDuplication.

This PR is part of a series of PRs. The final goal is to improve the efficiency of small dot operations by bypassing as many shared memory accesses as possible.

Rough list of PRs:

@antiagainst antiagainst left a comment

Cool! Overall looks good; just a few small issues.

int kDim = dotOperandLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
int nonKDim = dotOperandLayout.getOpIdx() == 0 ? rank - 2 : rank - 1;
auto ctaLayout = blockedLayout.getCTALayout();

Collaborator

One issue we have in the codebase is the amount of mysterious layout/indexing logic; it's not easy for others reading the code to pick up the intent. The following might not be that tricky, but can we still add a comment to explain what the following checks are doing at a high level?

Collaborator

Thanks for adding the comments! But the wording is quite confusing to me right now. What about something like:

The following logic checks that a source blocked layout B matches a destination dot operand layout with blocked layout parent P. It's considered a match if 1) each thread holds a whole copy of all elements along the K dimension for B, and 2) the distribution along all other non-K dimensions matches between B and P. This is to guarantee that each thread has all the data needed for reduction without exchanging data with other threads. (And/or whatever other reasons why we want this kind of match.)
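To make the two conditions in the proposed comment concrete, here is a minimal standalone sketch. It is not the code from this PR: `BlockedSketch` and `matchesDotOperandSketch` are hypothetical names, the struct is a simplified stand-in for the real blocked layout attribute (no `order` or CTA layout), and only the two conditions stated above are encoded.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical, simplified stand-in for a blocked layout. The real Triton
// attribute carries more fields (order, CTA layout) that are ignored here.
struct BlockedSketch {
  std::vector<unsigned> sizePerThread;
  std::vector<unsigned> threadsPerWarp;
  std::vector<unsigned> warpsPerCTA;
};

// Assumed helper (not part of the PR): returns true if source layout `src`
// (B) matches a dot operand with index `opIdx` whose parent blocked layout
// is `parent` (P), for a tensor of the given shape.
bool matchesDotOperandSketch(const BlockedSketch &src,
                             const BlockedSketch &parent, unsigned opIdx,
                             const std::vector<unsigned> &shape) {
  const std::size_t rank = shape.size();
  auto hasRank = [rank](const BlockedSketch &l) {
    return l.sizePerThread.size() == rank &&
           l.threadsPerWarp.size() == rank && l.warpsPerCTA.size() == rank;
  };
  if (rank < 2 || opIdx > 1 || !hasRank(src) || !hasRank(parent))
    return false;

  // Same K-dimension choice as in the snippet under review.
  const std::size_t kDim = (opIdx == 0) ? rank - 1 : rank - 2;

  // Condition 1: a single thread holds every element along K, so K is not
  // split across threads or warps.
  if (src.threadsPerWarp[kDim] != 1 || src.warpsPerCTA[kDim] != 1 ||
      src.sizePerThread[kDim] < shape[kDim])
    return false;

  // Condition 2: every non-K dimension is distributed the same way in the
  // source layout and in the parent layout.
  for (std::size_t d = 0; d < rank; ++d) {
    if (d == kDim)
      continue;
    if (src.sizePerThread[d] != parent.sizePerThread[d] ||
        src.threadsPerWarp[d] != parent.threadsPerWarp[d] ||
        src.warpsPerCTA[d] != parent.warpsPerCTA[d])
      return false;
  }
  return true;
}
```

Under these assumptions, a 64x32 operand with `opIdx = 0` and a source layout whose threads each hold all 32 elements along K would match a parent with the same distribution along the first dimension, while a source that splits K across threads would not. The actual check in the PR may impose further constraints (for example on `order` or the CTA layout).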

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
test/TritonGPU/reduce-data-duplication.mlir
// i.e. tensor<64x32xf16, #dot_op<{opIdx=0, parent=#blocked}>> will have sizePerThread = [<depends on #blocked>, 32]
// and tensor<64x32xf16, #dot_op<{opIdx=1, parent=#blocked}>> will have sizePerThread = [64, <depends on #blocked>]
//
// For example tensor<64x32xf16, #dot_op<{opIdx=0, parent=#blocked<{sizePerThread = [2, 8], threadsPerWarp = [32, 1]}>>>
Collaborator

This is going from dot operand to blocked layout? Isn't it the reverse of what we are doing in this function? I'm also not sure the distribution is correct. Doesn't this contradict the check at L571?

Contributor Author

Yes, this is misleading; let me change this comment.

What I meant is that these dot and blocked layouts are equal. I should not use "converted" here.
