
[AMD] Add basics to allow bypass LDS for dot RHS #4856

Closed
oplavsic wants to merge 6 commits into triton-lang:main from oplavsic:bypass_lds_upstream

Conversation

@oplavsic (Contributor) commented Oct 4, 2024

The AMDBypassLDSForDotOperandPass implements a strategy to bypass using the
Local Data Share (LDS) for one of the operands in an MFMA dot operation.

Under certain conditions, the dot layout of one of the operands allows loading
directly from HBM to VGPRs in the MFMA dot layout, without losing vectorization
of global loads or increasing the number of global loads due to data shared
between threads. The required conditions are:

  1. K-Major Tensor Layout:
    The operand we want to bypass LDS for must be K-major (i.e., row-major for
    operand 0 or column-major for operand 1). This enables vectorized global
    load instructions, since MFMA instructions require each thread to hold its
    operand elements along the K dimension.
  2. kWidth * sizeof(dataType) == 128 bits:
    Using the maximum kWidth for a given data type ensures optimal global
    load vectorization (e.g., 128-bit global_load_dwordx4 instructions).
  3. Single Warp per CTA Dimension:
    Either warpsPerCTA[nDim] == 1 for operand A bypass or warpsPerCTA[mDim] ==
    1 for operand B bypass. This guarantees that each tensor element is
    handled by exactly one thread, keeping the number of global loads the same
    as in the blocked layout (i.e., each element is loaded only once).
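Taken together, the three conditions above could be sketched as a standalone predicate. This is a minimal illustration only: the `DotOperand` fields and `can_bypass_lds` name are hypothetical, not the pass's actual API, and the element size is taken in bits so that 128 corresponds to a global_load_dwordx4.

```python
from dataclasses import dataclass

@dataclass
class DotOperand:
    op_idx: int             # 0 = A operand, 1 = B operand
    k_major: bool           # row-major for A, column-major for B
    k_width: int            # elements per thread along K
    elem_bits: int          # element type width in bits
    warps_per_cta: tuple    # (warps along M, warps along N)

def can_bypass_lds(op: DotOperand) -> bool:
    # 1. K-major layout enables vectorized global loads along K.
    if not op.k_major:
        return False
    # 2. kWidth * element size == 128 bits -> global_load_dwordx4.
    if op.k_width * op.elem_bits != 128:
        return False
    # 3. A single warp along the other dot dimension (nDim for A, mDim
    #    for B), so each element is loaded by exactly one thread.
    single_warp_dim = 1 if op.op_idx == 0 else 0
    return op.warps_per_cta[single_warp_dim] == 1

# e.g. an fp16 B operand with kWidth 8: 8 * 16 bits == 128 bits
b = DotOperand(op_idx=1, k_major=True, k_width=8, elem_bits=16,
               warps_per_cta=(1, 4))
```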

Comment thread lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

// Limit shared memory sharing to width >= 32 elements.
LDBG("Load " << *loadOp << " has width " << width);
if (width < 32) {
Contributor Author

StreamPipelineV2.cpp change in this PR enables pipelining in registers. This change was suggested to me by Simon.

Member

@sjw36 I think this change is impactful enough. Should we extract it out as a separate pull request and consider the implications over the broader cases instead of coupled with this pull request?

Contributor

Yes, it could be posted separately. @oplavsic there should be other cases that will exercise your pass, right?

Member

This is now in #5227. Let's drop it here.

@oplavsic oplavsic force-pushed the bypass_lds_upstream branch from a6839de to 66512b5 Compare October 7, 2024 12:07
@oplavsic oplavsic changed the title [WIP][AMD] Add AMDBypassLDSForDotOperandPass [AMD] Add AMDBypassLDSForDotOperandPass Oct 8, 2024
int getNVIDIAComputeCapability(Operation *module);

// Convert \param op operands and results to layout \param encoding.
void convertOpEncoding(Attribute encoding, Operation *op);
Contributor Author

I moved (and renamed) this function from Coalesce.cpp so I could use it in the BypassLDS pass, since I needed this exact functionality.

@oplavsic oplavsic force-pushed the bypass_lds_upstream branch from 9f85a9c to a30240e Compare October 9, 2024 13:43
Member

@antiagainst left a comment

Thanks! A couple of comments.

Comment thread third_party/amd/lib/TritonAMDGPUTransforms/AMDBypassLDSForDotOperand.cpp Outdated
Comment thread third_party/amd/lib/TritonAMDGPUTransforms/AMDBypassLDSForDotOperand.cpp Outdated
Comment thread test/TritonGPU/amd/bypass-lds.mlir
Comment thread test/TritonGPU/amd/bypass-lds.mlir Outdated
Comment thread third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp
Comment thread lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
Comment thread lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
Comment thread lib/Dialect/TritonGPU/IR/Dialect.cpp Outdated
@oplavsic oplavsic force-pushed the bypass_lds_upstream branch 3 times, most recently from 0f2b989 to 6240fb3 Compare October 22, 2024 23:21
Member

@antiagainst left a comment

Thanks! Could you merge in origin/main to trigger CI?

Comment thread test/TritonGPU/combine.mlir


@oplavsic oplavsic force-pushed the bypass_lds_upstream branch from 6240fb3 to b27350e Compare October 23, 2024 12:17
Contributor

@sjw36 left a comment

Please update the algorithm for efficiency.

ModuleOp module = getOperation();
auto convertOps = collectConvertOps(module);

module.dump();
Contributor

Remove debug module.dump().

ModuleOp &mod) {
SmallVector<triton::LoadOp> loadOpsVec;

mod.walk([&](triton::LoadOp loadOp) {
Contributor

This is expensive to do for every convert_layout. Just walk the use-def chain instead.
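The suggestion above can be illustrated outside of MLIR with a toy use-def walk: instead of re-scanning the whole module for loads on every convert op, follow each operand back to its producer. The `Op` class and `defining_load` helper are hypothetical stand-ins, not the repository's code.

```python
class Op:
    """Toy op node; operands point at producer ops (def-use edges)."""
    def __init__(self, name, operands=()):
        self.name = name
        self.operands = list(operands)

def defining_load(op):
    """Follow the use-def chain upward from an op to its load, if any.

    Handles the single-producer-chain case only, mirroring the cheap
    per-convert lookup suggested in the review.
    """
    while op is not None:
        if op.name == "tt.load":
            return op
        # Step to the producer of the first operand.
        op = op.operands[0] if op.operands else None
    return None

load = Op("tt.load")
cvt = Op("ttg.convert_layout", [Op("arith.extf", [load])])
```

Calling `defining_load(cvt)` returns `load` by visiting only the ops on the chain, whereas a module walk would visit every load for every convert.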

auto srcType = dyn_cast<RankedTensorType>(convertOp.getOperand().getType());
auto dstType = dyn_cast<RankedTensorType>(convertOp.getType());

if (!srcType || !dstType || srcType.getShape().size() != 2)
Contributor

Why is the shape restricted to 2 dimensions?



@antiagainst antiagainst changed the title [AMD] Add AMDBypassLDSForDotOperandPass [AMD] Add basics to allow bypass LDS for dot RHS Dec 3, 2024
@antiagainst (Member)

@oplavsic can you merge in changes from main and address remaining comments? Also please put the new functionality behind an env var so we can turn it on to evaluate before making it generally enabled.

@antiagainst (Member)

Closing this one given #5350 supersedes it.
