[WarpSpec] add support for multiple channels sharing the same smem #9
Conversation
The implementation changes appendBufferIdxArgs/createNewLoop to add an argument on the outer loop for accumLoopCount, or to add a constant as a placeholder when there is no outer loop. It also changes specializeIfOp to create a result on the if to propagate the accumLoopCount. Phase 1:
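For illustration only, a rough sketch of what those two changes might look like in builder-style C++ (outerPersistentLoop, oldIf, and loc are hypothetical names; the real code lives in appendBufferIdxArgs/createNewLoop/specializeIfOp):

// Sketch, not the actual pass code: seed accumLoopCount with a constant 0
// when there is no enclosing loop, otherwise take it from the extra
// iter_arg added to the rebuilt outer loop.
Value accumLoopCount;
if (!outerPersistentLoop) {
  accumLoopCount = builder.createWithAsyncTaskIds<arith::ConstantIntOp>(
      loc, /*value=*/0, /*width=*/64);
} else {
  accumLoopCount = outerPersistentLoop.getRegionIterArgs().back();
}
// specializeIfOp-style rewrite: rebuild the scf.if with one extra i64 result
// (existing result types omitted for brevity) so the accumulated count can
// be yielded out of both branches.
auto newIf = builder.create<scf::IfOp>(
    oldIf.getLoc(), TypeRange{builder.getI64Type()}, oldIf.getCondition(),
    /*withElseRegion=*/true);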
This is great work, thanks! BTW, can you include a lit test to help understand what this PR does exactly?
if (kv.second.size() <= 1)
  continue;
bufferMap[kv.first].getDefiningOp()->setAttr(
    "allocation.shareGroup",
A dumb question: why is this needed if the same buffer is already used in the IR?
Looks great with the refactoring! Just left some nit comments so far.
@@ -1673,6 +2185,8 @@ class TritonGPUWSCodePartitionPass
      funcOp.dump();
    });

    // Assuming there are no changes to loops in loopWithBufferReuse.
    DenseMap<AsyncTaskId, Value> mapForAccumLoopVar;
This doesn't seem to be used anywhere.
SmallVector<Operation *> loopWithBufferReuse;
reuseBuffers(asyncTaskTopOps, channels, mapToRepresenting,
             loopWithBufferReuse);
unsigned loopsWithAccumLoopCount = loopWithBufferReuse.size();
loopsWithAccumLoopCount no longer used?
    loopWithBufferReuse);
unsigned loopsWithAccumLoopCount = loopWithBufferReuse.size();
// Use and update loopWithBufferReuse.
Value tmpAccumLoopCount =
The return value is unused?
}

// For ForOps in taskTopOps, create new ForOp for each by adding phase,
// bufferIdx to the arguments.
Update comment to reflect the accumulatedLoopCount arg?
builder.setInsertionPoint(taskTopOps[0]);
tmpAccumLoopCount = builder.createWithAsyncTaskIds<arith::ConstantIntOp>(
    oneFor->getLoc(), 0, 64);
// populateLoopSteps(loopWithBufferReuse, accumLoopCountsAfterLoop,
Remove the comment?
// numSteps = ((upperBound - lowerBound) + forOpStep - 1) / forOpStep
Value numSteps = getNumSteps(forOp, builder);

// TODO: use a global flattened iteration space index for multi-dim loops.
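For reference, getNumSteps presumably just materializes the ceil-division in the comment above with arith ops; a sketch under that assumption (i64 operands assumed, not necessarily the real helper):

// Sketch: numSteps = ((ub - lb) + step - 1) / step.
Value range = builder.create<arith::SubIOp>(loc, forOp.getUpperBound(),
                                            forOp.getLowerBound());
Value one = builder.create<arith::ConstantIntOp>(loc, 1, 64);
Value biased = builder.create<arith::AddIOp>(
    loc, range, builder.create<arith::SubIOp>(loc, forOp.getStep(), one));
Value numSteps = builder.create<arith::DivUIOp>(loc, biased, forOp.getStep());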
If we add accumulatedLoopCount to non-sharing loop nests too, we can just unify this path with the hasParallelReuse path?
Yes, that is right. We can use accumulatedLoopCount, which is more accurate for persistent kernels when the inner loop has varying numSteps. accumulatedLoopCount will be an argument of the outer persistent loop.
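A sketch of what that could look like (hypothetical names; the actual rewrite is done by createNewLoop): the persistent loop is rebuilt with one extra iter_arg that carries accumulatedLoopCount, and the body yields the updated count.

// Illustrative only: thread accumulatedLoopCount through the outer
// persistent loop as an extra iter_arg.
auto newInits = llvm::to_vector(oldFor.getInitArgs());
newInits.push_back(zeroI64); // initial accumulatedLoopCount
auto newFor = builder.create<scf::ForOp>(
    oldFor.getLoc(), oldFor.getLowerBound(), oldFor.getUpperBound(),
    oldFor.getStep(), newInits);
Value accumulatedLoopCount = newFor.getRegionIterArgs().back();
// Each inner loop adds its own numSteps to this value, and the final sum is
// yielded so the next persistent iteration continues from it.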
});
std::for_each(
    liveOperations.begin(), liveOperations.end(), [&](Operation *liveOp) {
      if (buffer->regionIds.size() > 1 || buffer->sharingGroup >= 0) {
Why do we need this extra check? Is that for buffer sharing in single-consumer mode?
When we have buffer sharing, we want to be conservative and mark its live range as the full live range. Otherwise, we will need to combine the buffers in the same sharing group and analyze the union of regions, and the union of live ranges.
Summary: We already have channelsGroupedByProducers and channelsGroupedByConsumers. For one-producer-multi-consumer mode, a single buffer is used; channelsGroupedByProducers covers this. channelsGroupedByConsumers exists to minimize the insertion of sync primitives: a single set of communication ops is inserted per consumer group.

For this patch, we want to share the same smem location for multiple channels that are live in different loop nests. We add allocation.shareGroup attributes to the local_allocs corresponding to channels that reuse the same smem location.

In order to reuse the same smem location, we update bufferIdx and phase through all the loop nests that share smem locations. We handle the following cases:

  for            # persistent loop
    for          # can be nested under if
    for          # can be nested under if

Or

  for            # can be nested under if
  for            # can be nested under if

Or

  for            # persistent loop
    for          # can be nested under if

The generated code will look like

  for (accumLoopCount)
    t1 = IfOp
      forOp              # loop A
      tmpIdx = accumLoopCount + numStepsA
      yield tmpIdx
    else
      yield accumLoopCount
    t2 = IfOp
      forOp              # loop B
      tmpIdx = t1 + numStepsB
      yield tmpIdx
    else
      yield t1
    yield t2             # for accumLoopCount
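The patch only states that bufferIdx and phase are updated from the shared count; as an illustration (an assumption, not code from this PR), a flattened accumLoopCount would typically map onto a buffer slot and a phase like this:

// Assumed mapping for illustration: with numBuffers slots in the shared
// pool, bufferIdx = accumLoopCount % numBuffers and
// phase = (accumLoopCount / numBuffers) & 1.
Value bufferIdx =
    builder.create<arith::RemUIOp>(loc, accumLoopCount, numBuffers);
Value wraps = builder.create<arith::DivUIOp>(loc, accumLoopCount, numBuffers);
Value one = builder.create<arith::ConstantIntOp>(loc, 1, 64);
Value phase = builder.create<arith::AndIOp>(loc, wraps, one);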
Summary: half done, buildable