[AMD] Add basics to allow bypass LDS for dot RHS #4856
oplavsic wants to merge 6 commits into triton-lang:main
Conversation
// Limit shared memory sharing to width >= 32 elements.
LDBG("Load " << *loadOp << " has width " << width);
if (width < 32) {
StreamPipelineV2.cpp change in this PR enables pipelining in registers. This change was suggested to me by Simon.
@sjw36 I think this change is impactful enough. Should we extract it out as a separate pull request and consider the implications over the broader cases, instead of coupling it with this pull request?
Yes, it could be posted separately. @oplavsic there should be other cases that will exercise your pass right?
Force-pushed a6839de to 66512b5
int getNVIDIAComputeCapability(Operation *module);

// Convert \param op operands and results to layout \param encoding.
void convertOpEncoding(Attribute encoding, Operation *op);
I moved (and renamed) this function from Coalesce.cpp so I could use it in the BypassLDS pass, since I needed exactly this functionality.
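For readers outside the MLIR codebase, here is a toy, self-contained sketch of what `convertOpEncoding` does conceptually: it rebuilds an op's operand and result tensor types so they carry a new layout encoding, while shape and element type stay unchanged. The `TensorType`/`Op` structs below are illustrative stand-ins, not the real MLIR classes.

```cpp
#include <string>
#include <vector>

// Toy stand-ins for RankedTensorType / Operation (not the MLIR classes).
struct TensorType {
  std::vector<long> shape;
  std::string elemTy;
  std::string encoding; // layout attribute, e.g. "blocked" or "mfma"
};
struct Op {
  std::vector<TensorType> operands;
  std::vector<TensorType> results;
};

// Retarget every operand and result type of `op` to `encoding`, keeping
// shape and element type unchanged.
void convertOpEncoding(const std::string &encoding, Op &op) {
  for (TensorType &t : op.operands)
    t.encoding = encoding;
  for (TensorType &t : op.results)
    t.encoding = encoding;
}
```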
Force-pushed 9f85a9c to a30240e
antiagainst left a comment:
Thanks! A couple of comments.
Force-pushed 0f2b989 to 6240fb3
antiagainst left a comment:
Thanks! Could you merge in original/main to trigger CI?
Force-pushed 6240fb3 to b27350e
sjw36 left a comment:
Please update the algorithm for efficiency.
ModuleOp module = getOperation();
auto convertOps = collectConvertOps(module);

module.dump();
ModuleOp &mod) {
  SmallVector<triton::LoadOp> loadOpsVec;

  mod.walk([&](triton::LoadOp loadOp) {
This is expensive to do for every convert_layout. Just walk the use-def chain instead.
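The suggestion above trades a whole-module walk per convert_layout for a linear walk up the convert's operand chain. A minimal, self-contained sketch (hypothetical `Node` type standing in for MLIR ops, not the actual pass code):

```cpp
#include <string>

// Toy stand-in for an op in SSA form: each op records the defining op of
// its (single) operand, forming a use-def chain back to the producer.
struct Node {
  std::string name;    // op name, e.g. "tt.load" or "ttg.convert_layout"
  Node *def = nullptr; // defining op of this op's operand
};

// Follow the use-def chain upward from `op` until a load is found;
// returns nullptr if the chain ends without one.
Node *findDefiningLoad(Node *op) {
  for (Node *cur = op->def; cur; cur = cur->def)
    if (cur->name == "tt.load")
      return cur;
  return nullptr;
}
```

This is O(chain length) per convert_layout, instead of O(ops in module) for a `mod.walk` that scans every load each time.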
auto srcType = dyn_cast<RankedTensorType>(convertOp.getOperand().getType());
auto dstType = dyn_cast<RankedTensorType>(convertOp.getType());

if (!srcType || !dstType || srcType.getShape().size() != 2)
Why is the shape restricted to 2 dimensions?
@oplavsic can you merge in changes from
Closing this one given #5350 supersedes it.
The AMDBypassLDSForDotOperandPass implements a strategy to bypass the
Local Data Share (LDS) for one of the operands of an MFMA dot operation.
Under certain conditions, the dot layout of one of the operands allows loading
directly from HBM into VGPRs in the MFMA dot layout, without losing
vectorization of global loads and without increasing the number of global
loads due to data shared between threads.

The required conditions are:

1. The operand we want to bypass LDS for must be K-major (i.e., row-major for
   operand 0 or column-major for operand 1). This supports vectorized global
   load instructions, as MFMA instructions require each thread to hold B
   operand elements along the K dimension.
2. The operand must use the maximum kWidth for its data type; this ensures
   optimal global load vectorization (e.g., using global_load_dwordx4
   instructions).
3. Either warpsPerCTA[nDim] == 1 for an operand A bypass, or
   warpsPerCTA[mDim] == 1 for an operand B bypass. This guarantees that each
   tensor element is handled by exactly one thread, maintaining the same
   number of global loads as in the blocked layout (i.e., each element is
   loaded only once).
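The three conditions above can be sketched as a single predicate. This is an illustrative, self-contained sketch, not the pass's actual API: parameter names (`opIdx`, `kMajor`, `kWidth`, `maxKWidth`, `warpsM`, `warpsN`) are hypothetical stand-ins for the layout attributes the pass inspects.

```cpp
// opIdx: 0 = operand A, 1 = operand B (the dot RHS this PR targets).
// Returns true when, per the conditions above, the operand's global loads
// can feed VGPRs directly in the MFMA dot layout without going through LDS.
bool canBypassLDS(int opIdx, bool kMajor, int kWidth, int maxKWidth,
                  int warpsM, int warpsN) {
  // Condition 1: K-major operand, so global loads stay K-contiguous
  // and vectorizable.
  if (!kMajor)
    return false;
  // Condition 2: maximal kWidth for the element type, e.g. the width
  // served by a global_load_dwordx4.
  if (kWidth != maxKWidth)
    return false;
  // Condition 3: no warp-level replication of the operand. A is shared
  // across warps along N, B across warps along M, so that dimension of
  // the warp grid must be 1 to keep each element loaded exactly once.
  return opIdx == 0 ? warpsN == 1 : warpsM == 1;
}
```

For example, a K-major B operand with maximal kWidth qualifies only when warpsPerCTA along M is 1; adding a second warp along M would make both warps load the same B elements.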