Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/triton/Analysis/Allocation.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ class Allocation {
size_t size;
size_t alignment;
size_t offset;
SetVector<int> regionIds;

bool operator==(const BufferT &other) const { return id == other.id; }
bool operator<(const BufferT &other) const { return id < other.id; }
Expand Down
79 changes: 68 additions & 11 deletions lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;

// Name of this file's debug category; LLVM_DEBUG keys its output on
// DEBUG_TYPE (per LLVM's debug-output convention).
#define DEBUG_TYPE "allocation-analysis"
// Yields llvm::dbgs() with the "[allocation-analysis]: " prefix already
// written, so callers can keep streaming into it.
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// Streams X to DBGS(), followed by a newline, wrapped in LLVM_DEBUG.
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

namespace mlir {

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -330,23 +334,25 @@ class AllocationAnalysis {
/// Computes the liveness range of every explicitly allocated value.
/// Each buffer is allocated only once.
///
/// Walks `allocation->valueBuffer` and, for each (value, buffer) entry,
/// stores the interval returned by `getLiveness` into `bufferRange[buffer]`.
///
/// NOTE(review): this span comes from a diff view, so the changed lines
/// appear in two forms. Presumably the one-argument `getLiveness` is the old
/// form and the `(value, buffer)` form is the new one; confirm against the
/// merged file.
void resolveExplicitBufferLiveness(
function_ref<Interval<size_t>(Value value)> getLiveness) {
function_ref<Interval<size_t>(Value value, BufferT *buffer)>
getLiveness) {
for (auto valueBufferIter : allocation->valueBuffer) {
auto value = valueBufferIter.first;
auto *buffer = valueBufferIter.second;
bufferRange[buffer] = getLiveness(value);
bufferRange[buffer] = getLiveness(value, buffer);
}
}

/// Extends the liveness range by unionizing the liveness range of the aliased
/// values because each allocated buffer could be an alias of others, if block
/// arguments are involved.
void resolveAliasBufferLiveness(
function_ref<Interval<size_t>(Value value)> getLiveness) {
function_ref<Interval<size_t>(Value value, BufferT *buffer)>
getLiveness) {
for (auto aliasBufferIter : allocation->aliasBuffer) {
auto value = aliasBufferIter.first;
auto buffers = aliasBufferIter.second;
auto range = getLiveness(value);
auto range = getLiveness(value, buffers.front());
for (auto *buffer : buffers) {
auto minId = range.start();
auto maxId = range.end();
Expand Down Expand Up @@ -378,11 +384,14 @@ class AllocationAnalysis {
bufferRange.insert({buffer, Interval(operationId.lookup(op),
operationId.lookup(op) + 1)});
} else {
// FIXME: This range makes scratch buffers used in warp-specialized
// regions conflict with everything else in the program, which is
// too conservative, but safe. A better approach would make them
// conflict with buffers live in other warp-specialized regions.
bufferRange.insert({buffer, Interval<size_t>(0, operationId.size())});
for (auto tId : getAsyncTaskIds(op))
buffer->regionIds.insert(tId);
// For warp-specialized code, we can assume each region has its own
// copy of a scratch buffer, i.e each region is for a single taskId.
// In that case, we don't need to extend the liveness of scratch
// buffers.
bufferRange.insert({buffer, Interval(operationId.lookup(op),
operationId.lookup(op) + 1)});
}
}
};
Expand Down Expand Up @@ -414,13 +423,27 @@ class AllocationAnalysis {

// Analyze liveness of explicit buffers
Liveness liveness(operation);
auto getValueLivenessRange = [&](Value value) {
auto getValueLivenessRange = [&](Value value, BufferT *buffer) {
auto liveOperations = liveness.resolveLiveness(value);
// Update regions for buffer.
std::for_each(liveOperations.begin(), liveOperations.end(),
[&](Operation *liveOp) {
for (auto rId : getAsyncTaskIds(liveOp)) {
buffer->regionIds.insert(rId);
}
});
auto minId = std::numeric_limits<size_t>::max();
auto maxId = std::numeric_limits<size_t>::min();
std::for_each(liveOperations.begin(), liveOperations.end(),
[&](Operation *liveOp) {
if (!getAsyncTaskIds(liveOp).empty()) {
if (buffer->regionIds.size() > 1) {
// For a buffer that is associated with warp
// specialization, due to producer-consumer channel, it
// should have at least two regions, and it will be live
// throughout. For a buffer that is local to a consumer:
// we need to make sure not to overlap with local
// buffers from another consumer. This will be handled
// when building the interference graph.
minId = 0;
maxId = operationId.size();
return;
Expand All @@ -440,6 +463,22 @@ class AllocationAnalysis {
resolveScratchBufferLiveness(operationId);
}

void dumpBuffers() {
LDBG("Dump bufferRange: id size offset ---------");
for (auto bufferIter : bufferRange) {
LLVM_DEBUG({
llvm::dbgs() << "-- " << bufferIter.first->id << " "
<< bufferIter.first->size << " "
<< bufferIter.first->offset << " regions [";
for (auto tId : bufferIter.first->regionIds) {
llvm::dbgs() << tId << " ";
}
llvm::dbgs() << "] interval " << bufferIter.second.start() << " "
<< bufferIter.second.end() << "\n";
});
}
}

/// Computes the shared memory offsets for all related values.
/// Paper: Algorithms for Compile-Time Memory Optimization
/// (https://dl.acm.org/doi/pdf/10.5555/314500.315082)
Expand All @@ -450,6 +489,7 @@ class AllocationAnalysis {
}

calculateStarts(buffers);
dumpBuffers();

// NOTE: The original paper doesn't consider interference between
// the bumped ranges. Buffers that previously do not interfere with
Expand All @@ -464,6 +504,7 @@ class AllocationAnalysis {
allocate(buffers, interference);
buildInterferenceGraph(buffers, interference);
} while (!interference.empty());
dumpBuffers();
}

/// Computes the initial shared memory offsets.
Expand Down Expand Up @@ -531,6 +572,19 @@ class AllocationAnalysis {
void buildInterferenceGraph(const SmallVector<BufferT *> &buffers,
GraphT &interference) {
// Reset interference graph
auto inDifferentRegion = [&](BufferT *A, BufferT *B) {
auto tA = A->regionIds;
auto tB = B->regionIds;
if (tA.empty() || tB.empty())
return true;
for (auto t1 : tA) {
for (auto t2 : tB) {
if (t1 != t2)
return true;
}
}
return false;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can one buffer have a region id while the other doesn't, and if so, should they be treated as being in different regions?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we can be conservative. I am currently trying to handle the private buffers associated with channels; the check for "!= 0" (i.e., ignoring the producer warp group) is kind of hacky.

};
interference.clear();
for (auto x : buffers) {
for (auto y : buffers) {
Expand All @@ -548,6 +602,9 @@ class AllocationAnalysis {
xSizeRange.intersects(ySizeRange)) {
interference[x].insert(y);
}
// if x and y belong to different regions (ignore producer region).
if (inDifferentRegion(x, y) && xSizeRange.intersects(ySizeRange))
interference[x].insert(y);
}
}
}
Expand Down