[WarpSpec] add support for multiple channels sharing the same smem #9

Merged: 10 commits into ws on Jan 14, 2025

Conversation

@manman-ren (Contributor) commented on Dec 14, 2024

Summary: We already have channelsGroupedByProducers and channelsGroupedByConsumers. For one-producer-multi-consumer mode, a single buffer is used, and channelsGroupedByProducers drives that decision. channelsGroupedByConsumers is used to minimize the insertion of sync primitives: a single set of communication ops is inserted per consumer group.
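
To make the grouping concrete, here is a minimal sketch of what the two maps capture; the ChannelSketch type and its srcOp/dstOp fields are placeholders for illustration, not the pass's actual Channel definition.

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Operation.h"

// Hypothetical, simplified channel type; the real pass carries more state.
struct ChannelSketch {
  mlir::Operation *srcOp; // producer side of the channel
  mlir::Operation *dstOp; // consumer side of the channel
};

// One buffer is allocated per producer group; one set of communication/sync
// ops is emitted per consumer group.
void groupChannelsSketch(
    llvm::ArrayRef<ChannelSketch *> channels,
    llvm::DenseMap<mlir::Operation *, llvm::SmallVector<ChannelSketch *>>
        &channelsGroupedByProducers,
    llvm::DenseMap<mlir::Operation *, llvm::SmallVector<ChannelSketch *>>
        &channelsGroupedByConsumers) {
  for (ChannelSketch *c : channels) {
    channelsGroupedByProducers[c->srcOp].push_back(c);
    channelsGroupedByConsumers[c->dstOp].push_back(c);
  }
}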

For this patch, we want to share the same smem location for multiple channels that are live in different loop nests. We add allocation.shareGroup attributes to the local_allocs corresponding to channels that reuse the same smem location.

In order to reuse the same smem location, we update bufferIdx and phase through all the loop nests that share smem locations. We handle the following cases (a sketch of how bufferIdx and phase can be derived from the accumulated loop count follows the generated-code outline below):

for # persistent loop
  for # can be nested under if
  for # can be nested under if
Or
for # can be nested under if
for # can be nested under if
Or
for # persistent loop
  for # can be nested under if

The generated code will look like

for(accumLoopCount)
  t1 = IfOp
    forOp # loop A
    tmpIdx = accumLoopCount + numStepsA
    yield tmpIdx
    else yield accumLoopCount
  t2 = IfOp
    forOp # loop B
    tmpIdx = t1 + numStepsB
    yield tmpIdx
    else yield t1
  yield t2 for accumLoopCount
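
As referenced above, here is a minimal sketch of one common way bufferIdx and phase could be derived from the accumulated loop count, assuming an i64 accumLoopCount and a compile-time buffer count; the pass's actual arithmetic and builder helpers may differ.

#include <cstdint>
#include <utility>

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"

// bufferIdx cycles through the buffers; phase flips every full pass over them.
std::pair<mlir::Value, mlir::Value>
bufferIdxAndPhaseSketch(mlir::OpBuilder &b, mlir::Location loc,
                        mlir::Value accumLoopCount, int64_t numBuffers) {
  using namespace mlir;
  Value bufs = b.create<arith::ConstantIntOp>(loc, numBuffers, /*width=*/64);
  Value one = b.create<arith::ConstantIntOp>(loc, 1, /*width=*/64);
  Value bufferIdx = b.create<arith::RemUIOp>(loc, accumLoopCount, bufs);
  Value fullPasses = b.create<arith::DivUIOp>(loc, accumLoopCount, bufs);
  Value phase = b.create<arith::AndIOp>(loc, fullPasses, one);
  return {bufferIdx, phase};
}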

@facebook-github-bot added the CLA Signed label on Dec 14, 2024
@manman-ren (Contributor, Author) commented on Dec 14, 2024

The implementation changes appendBufferIdxArgs/createNewLoop to add an argument on the outer loop for accumLoopCount, or to add a constant as a placeholder when there is no outer loop. It also changes specializeIfOp to create a result on the IfOp that propagates accumLoopCount.
We then use a helper function, updateAccumLoopCount, to correctly link up the values.

Phase 1:

ForOp with accumLoopCount as an argument
  If
    use accumLoopCount to set initialBufferIdx
    ForOp
    generate numSteps and create an add op for accumLoopCount + numSteps
  Yield for ForOp with accumLoopCount (this will be updated later in updateAccumLoopCount)
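
A rough sketch of the linking step, under the assumption that each guarded IfOp yields either accum + numSteps (then) or the incoming accum (else), and that accumLoopCount sits at a known iter_arg/yield position; the function name, signature, and accumArgIdx parameter are illustrative, not the pass's actual helper.

#include "llvm/ADT/ArrayRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

// Thread the accumulated loop count through a chain of guarded loop nests and
// wire the final value into the persistent loop's yield.
mlir::Value updateAccumLoopCountSketch(mlir::scf::ForOp persistentLoop,
                                       llvm::ArrayRef<mlir::scf::IfOp> guards,
                                       unsigned accumArgIdx) {
  using namespace mlir;
  // Start from the iter_arg that carries accumLoopCount into the loop body.
  Value accum = persistentLoop.getRegionIterArgs()[accumArgIdx];
  for (scf::IfOp ifOp : guards) {
    // Each specialized IfOp returns accum + numSteps in its then-region and
    // the unchanged accum in its else-region; chain through its last result.
    accum = ifOp->getResult(ifOp->getNumResults() - 1);
  }
  // The last chained value becomes the yielded accumLoopCount.
  auto yield =
      llvm::cast<scf::YieldOp>(persistentLoop.getBody()->getTerminator());
  yield->setOperand(accumArgIdx, accum);
  return accum;
}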

@htyu (Contributor) commented on Dec 18, 2024

This is great work, thanks!

BTW, can you include a lit test to help understand what this PR does exactly?

if (kv.second.size() <= 1)
  continue;
bufferMap[kv.first].getDefiningOp()->setAttr(
    "allocation.shareGroup",

Contributor:
A dumb question: why is this needed if the same buffer is already used in the IR?

@htyu (Contributor) left a comment:
Looks great with the refactoring! Just left some nit comments so far.

@@ -1673,6 +2185,8 @@ class TritonGPUWSCodePartitionPass
funcOp.dump();
});

// Assuming there are no changes to loops in loopWithBufferReuse.
DenseMap<AsyncTaskId, Value> mapForAccumLoopVar;

Contributor:
This doesn't seem to be used anywhere.

SmallVector<Operation *> loopWithBufferReuse;
reuseBuffers(asyncTaskTopOps, channels, mapToRepresenting,
             loopWithBufferReuse);
unsigned loopsWithAccumLoopCount = loopWithBufferReuse.size();

Contributor:
loopsWithAccumLoopCount no longer used?

             loopWithBufferReuse);
unsigned loopsWithAccumLoopCount = loopWithBufferReuse.size();
// Use and update loopWithBufferReuse.
Value tmpAccumLoopCount =

Contributor:
The return value is unused?

}

// For ForOps in taskTopOps, create new ForOp for each by adding phase,
// bufferIdx to the arguments.

Contributor:
Update comment to reflect the accumulatedLoopCount arg?

builder.setInsertionPoint(taskTopOps[0]);
tmpAccumLoopCount = builder.createWithAsyncTaskIds<arith::ConstantIntOp>(
    oneFor->getLoc(), 0, 64);
// populateLoopSteps(loopWithBufferReuse, accumLoopCountsAfterLoop,

Contributor:
Remove the comment?

// numSteps = ((upperBound - lowerBound) + forOpStep - 1) / forOpStep
Value numSteps = getNumSteps(forOp, builder);

// TODO: use a global flattened iteration space index for multi-dim loops.
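
For reference, a plausible shape for the quoted numSteps computation, assuming index-typed loop bounds; the real getNumSteps may differ, for example by producing an i64 and using the async-task-aware builder.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

// numSteps = ((upperBound - lowerBound) + step - 1) / step, i.e. a ceiling
// division of the trip distance by the step.
mlir::Value getNumStepsSketch(mlir::scf::ForOp forOp, mlir::OpBuilder &b) {
  using namespace mlir;
  Location loc = forOp.getLoc();
  Value lb = forOp.getLowerBound();
  Value ub = forOp.getUpperBound();
  Value step = forOp.getStep();
  Value diff = b.create<arith::SubIOp>(loc, ub, lb);
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  Value biased = b.create<arith::AddIOp>(
      loc, diff, b.create<arith::SubIOp>(loc, step, one));
  return b.create<arith::DivUIOp>(loc, biased, step);
}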

Contributor:
If we add accumulatedLoopCount to non-sharing loop nests too, can we just unify this path with the hasParallelReuse path?

Contributor (Author):
Yes, that is right. We can use accumulatedLoopCount, which is more accurate for persistent kernels when the inner loop has varying numSteps. accumulatedLoopCount will be an argument of the outer persistent loop.

});
std::for_each(
    liveOperations.begin(), liveOperations.end(), [&](Operation *liveOp) {
      if (buffer->regionIds.size() > 1 || buffer->sharingGroup >= 0) {

Contributor:
Why do we need this extra check? Is that for buffer sharing in single-consumer mode?

Contributor (Author):
When we have buffer sharing, we want to be conservative and mark its live range as the full live range. Otherwise, we will need to combine the buffers in the same sharing group and analyze the union of regions, and the union of live ranges.
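
A minimal sketch of that conservative choice; BufferSketch and the live-range pair are placeholders, not the allocation analysis's actual types.

#include <cstddef>
#include <utility>

#include "llvm/ADT/SmallVector.h"

// Placeholder buffer descriptor for the sketch.
struct BufferSketch {
  llvm::SmallVector<int, 4> regionIds; // loop nests the buffer is used in
  int sharingGroup = -1;               // >= 0 when the smem slot is shared
};

// When a buffer participates in sharing (multiple regions or an explicit
// share group), conservatively report the full program order [0, numOps) as
// its live range instead of the locally computed one.
std::pair<size_t, size_t>
liveRangeSketch(const BufferSketch &buffer, size_t numOps,
                std::pair<size_t, size_t> localRange) {
  if (buffer.regionIds.size() > 1 || buffer.sharingGroup >= 0)
    return {0, numOps};
  return localRange;
}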

@manman-ren force-pushed the mren/ws-reuse-buffer branch from 0d062fe to 40a5741 on January 14, 2025.
@manman-ren merged commit c286564 into ws on Jan 14, 2025.
2 checks passed
@manman-ren deleted the mren/ws-reuse-buffer branch on January 16, 2025.