From 1f25ad12d564662d41ad705ab570fbff936d115e Mon Sep 17 00:00:00 2001 From: Manman Ren Date: Mon, 9 Dec 2024 11:01:44 -0800 Subject: [PATCH 1/4] [WarpSpec] improve allocation for smem Summary: Attempt to teach Allocation analysis to be aware of warpspec regions. Add a list of regions to each buffer, also teach interference graph to be aware of regions. Currently it allows convert_layout within one consumer to overlap and in the non-persistent case, convert_layout can share with private global buffer. For persistent, we need to make sure producer doesn't reload the private global buffer for the outer loop (i.e., the persistent loop) before convert_layout happens in the consumer. Test Plan: Run JFA bwd Reviewers: Subscribers: Tasks: Tags: --- include/triton/Analysis/Allocation.h | 1 + lib/Analysis/Allocation.cpp | 129 ++++++++++++++++++++++----- 2 files changed, 108 insertions(+), 22 deletions(-) diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h index 3a488e65ed..fcc05d6197 100644 --- a/include/triton/Analysis/Allocation.h +++ b/include/triton/Analysis/Allocation.h @@ -186,6 +186,7 @@ class Allocation { size_t size; size_t alignment; size_t offset; + SetVector regionIds; bool operator==(const BufferT &other) const { return id == other.id; } bool operator<(const BufferT &other) const { return id < other.id; } diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index e53f70175f..99c3658e0e 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -30,6 +30,10 @@ using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; using ::mlir::triton::gpu::SharedEncodingAttr; using ::mlir::triton::gpu::SliceEncodingAttr; +#define DEBUG_TYPE "allocation-analysis" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + namespace mlir { //===----------------------------------------------------------------------===// @@ -330,11 +334,12 @@ class AllocationAnalysis
{ /// Computes the liveness range of the allocated value. /// Each buffer is allocated only once. void resolveExplicitBufferLiveness( - function_ref(Value value)> getLiveness) { + function_ref(Value value, BufferT *buffer)> + getLiveness) { for (auto valueBufferIter : allocation->valueBuffer) { auto value = valueBufferIter.first; auto *buffer = valueBufferIter.second; - bufferRange[buffer] = getLiveness(value); + bufferRange[buffer] = getLiveness(value, buffer); } } @@ -342,11 +347,12 @@ class AllocationAnalysis { /// values because each allocated buffer could be an alias of others, if block /// arguments are involved. void resolveAliasBufferLiveness( - function_ref(Value value)> getLiveness) { + function_ref(Value value, BufferT *buffer)> + getLiveness) { for (auto aliasBufferIter : allocation->aliasBuffer) { auto value = aliasBufferIter.first; auto buffers = aliasBufferIter.second; - auto range = getLiveness(value); + auto range = getLiveness(value, buffers.front()); for (auto *buffer : buffers) { auto minId = range.start(); auto maxId = range.end(); @@ -378,11 +384,14 @@ class AllocationAnalysis { bufferRange.insert({buffer, Interval(operationId.lookup(op), operationId.lookup(op) + 1)}); } else { - // FIXME: This range makes scratch buffers used in warp-specialized - // regions conflict with everything else in the program, which is - // too conservative, but safe. A better approach would make them - // conflict with buffers live in other warp-specialized regions. - bufferRange.insert({buffer, Interval(0, operationId.size())}); + for (auto tId : getAsyncTaskIds(op)) + buffer->regionIds.insert(tId); + // For warp-specialized code, we can assume each region has its own + // copy of a scratch buffer, i.e each region is for a single taskId. + // In that case, we don't need to extend the liveness of scratch + // buffers. 
+ bufferRange.insert({buffer, Interval(operationId.lookup(op), + operationId.lookup(op) + 1)}); } } }; @@ -414,24 +423,59 @@ class AllocationAnalysis { // Analyze liveness of explicit buffers Liveness liveness(operation); - auto getValueLivenessRange = [&](Value value) { + auto getValueLivenessRange = [&](Value value, BufferT *buffer) { auto liveOperations = liveness.resolveLiveness(value); - auto minId = std::numeric_limits::max(); - auto maxId = std::numeric_limits::min(); + // Update regions for buffer. std::for_each(liveOperations.begin(), liveOperations.end(), [&](Operation *liveOp) { - if (!getAsyncTaskIds(liveOp).empty()) { - minId = 0; - maxId = operationId.size(); - return; - } - if (operationId[liveOp] < minId) { - minId = operationId[liveOp]; - } - if ((operationId[liveOp] + 1) > maxId) { - maxId = operationId[liveOp] + 1; + for (auto rId : getAsyncTaskIds(liveOp)) { + buffer->regionIds.insert(rId); } }); + bool isSharedGlobalForWS = false, isPrivateGlobalForWS = false, + isLocalForWS = false; + // Check regions on buffer. + if (buffer->regionIds.size() == 1) + isLocalForWS = true; + if (buffer->regionIds.size() > 1) { + // Assume region 0 is producer. + if (buffer->regionIds.count(0) && buffer->regionIds.size() == 2) + isPrivateGlobalForWS = true; + else + isSharedGlobalForWS = true; + } + auto minId = std::numeric_limits::max(); + auto maxId = std::numeric_limits::min(); + std::for_each( + liveOperations.begin(), liveOperations.end(), [&](Operation *liveOp) { + if (isSharedGlobalForWS) { //! getAsyncTaskIds(liveOp).empty()) { + // For a buffer that is associated with warp specialization, due + // to producer-consumer channel: + // We differentiate the case of buffers that are shared with + // multiple consumers vs. buffers that are private to one + // consumer. For the latter, we can start from 0 (due to producer + // in a different region) and end at the top-level op within the + // region. 
For the former, we need to cover the whole range of + // [0, operationId.size()), since we don't know execution of the + // other consumer. + // For a buffer that is local to a consumer: we need to make sure + // not to overlap with local buffers from another consumer. + minId = 0; + maxId = operationId.size(); + return; + } + if (isPrivateGlobalForWS) { + minId = 0; + maxId = operationId[liveOp] + 1 > maxId ? operationId[liveOp] + 1 + : maxId; + } + if (operationId[liveOp] < minId) { + minId = operationId[liveOp]; + } + if ((operationId[liveOp] + 1) > maxId) { + maxId = operationId[liveOp] + 1; + } + }); return Interval(minId, maxId); }; @@ -440,6 +484,31 @@ class AllocationAnalysis { resolveScratchBufferLiveness(operationId); } + void dumpBuffers() { + LDBG("Dump bufferRange ---------"); + for (auto bufferIter : bufferRange) { + LLVM_DEBUG({ + llvm::dbgs() << "-- " << bufferIter.first->size << " " << bufferIter.first->offset < " regions "; + for (auto tId : bufferIter.first->regionIds) { + llvm::dbgs() << tId << " "; + } + llvm::dbgs() << " interval " << bufferIter.second.start() << " " + << bufferIter.second.end() << "\n"; + }); + } + } + void printBuffers() { + llvm::errs() << "Dump bufferRange ---------" << "\n"; + for (auto bufferIter : bufferRange) { + llvm::errs() << "-- " << bufferIter.first->size << " " << bufferIter.first->offset << " regions "; + for (auto tId : bufferIter.first->regionIds) { + llvm::errs() << tId << " "; + } + llvm::errs() << " interval " << bufferIter.second.start() << " " + << bufferIter.second.end() << "\n"; + } + } + /// Computes the shared memory offsets for all related values. /// Paper: Algorithms for Compile-Time Memory Optimization /// (https://dl.acm.org/doi/pdf/10.5555/314500.315082) @@ -450,6 +519,7 @@ class AllocationAnalysis { } calculateStarts(buffers); + dumpBuffers(); // NOTE: The original paper doesn't consider interference between // the bumped ranges. 
Buffers that previously do not interfere with @@ -462,6 +532,7 @@ class AllocationAnalysis { buildInterferenceGraph(buffers, interference); do { allocate(buffers, interference); + dumpBuffers(); buildInterferenceGraph(buffers, interference); } while (!interference.empty()); } @@ -531,6 +602,17 @@ class AllocationAnalysis { void buildInterferenceGraph(const SmallVector &buffers, GraphT &interference) { // Reset interference graph + auto inDifferentRegion = [&](BufferT *A, BufferT *B) { + auto tA = A->regionIds; + auto tB = B->regionIds; + for (auto t1 : tA) { + for (auto t2 : tA) { + if (t1 != 0 && t2 != 0 && t1 != t2) + return true; + } + } + return false; + }; interference.clear(); for (auto x : buffers) { for (auto y : buffers) { @@ -548,6 +630,9 @@ class AllocationAnalysis { xSizeRange.intersects(ySizeRange)) { interference[x].insert(y); } + // if x and y belong to different regions (ignore producer region). + if (inDifferentRegion(x, y) && xSizeRange.intersects(yOpRange)) + interference[x].insert(y); } } } From 50c6be383299b8bd446d5e14906285f65a145703 Mon Sep 17 00:00:00 2001 From: Manman Ren Date: Mon, 9 Dec 2024 14:58:26 -0800 Subject: [PATCH 2/4] fix Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- lib/Analysis/Allocation.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 99c3658e0e..9717951a0b 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -488,7 +488,7 @@ class AllocationAnalysis { LDBG("Dump bufferRange ---------"); for (auto bufferIter : bufferRange) { LLVM_DEBUG({ - llvm::dbgs() << "-- " << bufferIter.first->size << " " << bufferIter.first->offset < " regions "; + llvm::dbgs() << "-- " << bufferIter.first->size << " " << bufferIter.first->offset << " regions "; for (auto tId : bufferIter.first->regionIds) { llvm::dbgs() << tId << " "; } From 74c714f1917d9a822ab9971064ca80ae6219835c Mon Sep 17 00:00:00 2001 From: Manman Ren Date: 
Wed, 11 Dec 2024 16:24:18 -0800 Subject: [PATCH 3/4] fix typos etc Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- lib/Analysis/Allocation.cpp | 90 +++++++++++++------------------------ 1 file changed, 30 insertions(+), 60 deletions(-) diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 9717951a0b..8892444928 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -432,50 +432,29 @@ class AllocationAnalysis { buffer->regionIds.insert(rId); } }); - bool isSharedGlobalForWS = false, isPrivateGlobalForWS = false, - isLocalForWS = false; - // Check regions on buffer. - if (buffer->regionIds.size() == 1) - isLocalForWS = true; - if (buffer->regionIds.size() > 1) { - // Assume region 0 is producer. - if (buffer->regionIds.count(0) && buffer->regionIds.size() == 2) - isPrivateGlobalForWS = true; - else - isSharedGlobalForWS = true; - } auto minId = std::numeric_limits::max(); auto maxId = std::numeric_limits::min(); - std::for_each( - liveOperations.begin(), liveOperations.end(), [&](Operation *liveOp) { - if (isSharedGlobalForWS) { //! getAsyncTaskIds(liveOp).empty()) { - // For a buffer that is associated with warp specialization, due - // to producer-consumer channel: - // We differentiate the case of buffers that are shared with - // multiple consumers vs. buffers that are private to one - // consumer. For the latter, we can start from 0 (due to producer - // in a different region) and end at the top-level op within the - // region. For the former, we need to cover the whole range of - // [0, operationId.size()), since we don't know execution of the - // other consumer. - // For a buffer that is local to a consumer: we need to make sure - // not to overlap with local buffers from another consumer. - minId = 0; - maxId = operationId.size(); - return; - } - if (isPrivateGlobalForWS) { - minId = 0; - maxId = operationId[liveOp] + 1 > maxId ? 
operationId[liveOp] + 1 - : maxId; - } - if (operationId[liveOp] < minId) { - minId = operationId[liveOp]; - } - if ((operationId[liveOp] + 1) > maxId) { - maxId = operationId[liveOp] + 1; - } - }); + std::for_each(liveOperations.begin(), liveOperations.end(), + [&](Operation *liveOp) { + if (buffer->regionIds.size() > 1) { + // For a buffer that is associated with warp + // specialization, due to producer-consumer channel, it + // should have at least two regions, and it will be live + // throughout. For a buffer that is local to a consumer: + // we need to make sure not to overlap with local + // buffers from another consumer. This will be handled + // when building the interference graph. + minId = 0; + maxId = operationId.size(); + return; + } + if (operationId[liveOp] < minId) { + minId = operationId[liveOp]; + } + if ((operationId[liveOp] + 1) > maxId) { + maxId = operationId[liveOp] + 1; + } + }); return Interval(minId, maxId); }; @@ -485,29 +464,20 @@ class AllocationAnalysis { } void dumpBuffers() { - LDBG("Dump bufferRange ---------"); + LDBG("Dump bufferRange: id size offset ---------"); for (auto bufferIter : bufferRange) { LLVM_DEBUG({ - llvm::dbgs() << "-- " << bufferIter.first->size << " " << bufferIter.first->offset << " regions "; + llvm::dbgs() << "-- " << bufferIter.first->id << " " + << bufferIter.first->size << " " + << bufferIter.first->offset << " regions ["; for (auto tId : bufferIter.first->regionIds) { llvm::dbgs() << tId << " "; } - llvm::dbgs() << " interval " << bufferIter.second.start() << " " + llvm::dbgs() << "] interval " << bufferIter.second.start() << " " << bufferIter.second.end() << "\n"; }); } } - void printBuffers() { - llvm::errs() << "Dump bufferRange ---------" << "\n"; - for (auto bufferIter : bufferRange) { - llvm::errs() << "-- " << bufferIter.first->size << " " << bufferIter.first->offset << " regions "; - for (auto tId : bufferIter.first->regionIds) { - llvm::errs() << tId << " "; - } - llvm::errs() << " interval " << 
bufferIter.second.start() << " " - << bufferIter.second.end() << "\n"; - } - } /// Computes the shared memory offsets for all related values. /// Paper: Algorithms for Compile-Time Memory Optimization @@ -532,9 +502,9 @@ class AllocationAnalysis { buildInterferenceGraph(buffers, interference); do { allocate(buffers, interference); - dumpBuffers(); buildInterferenceGraph(buffers, interference); } while (!interference.empty()); + dumpBuffers(); } /// Computes the initial shared memory offsets. @@ -606,8 +576,8 @@ class AllocationAnalysis { auto tA = A->regionIds; auto tB = B->regionIds; for (auto t1 : tA) { - for (auto t2 : tA) { - if (t1 != 0 && t2 != 0 && t1 != t2) + for (auto t2 : tB) { + if (t1 != t2) return true; } } @@ -631,7 +601,7 @@ class AllocationAnalysis { interference[x].insert(y); } // if x and y belong to different regions (ignore producer region). - if (inDifferentRegion(x, y) && xSizeRange.intersects(yOpRange)) + if (inDifferentRegion(x, y) && xSizeRange.intersects(ySizeRange)) interference[x].insert(y); } } From 14e46c83da94bce0cc552ac24a73d89b54f96872 Mon Sep 17 00:00:00 2001 From: Manman Ren Date: Wed, 11 Dec 2024 16:29:50 -0800 Subject: [PATCH 4/4] be conservative Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- lib/Analysis/Allocation.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 8892444928..62f315f46e 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -575,6 +575,8 @@ class AllocationAnalysis { auto inDifferentRegion = [&](BufferT *A, BufferT *B) { auto tA = A->regionIds; auto tB = B->regionIds; + if (tA.empty() || tB.empty()) + return true; for (auto t1 : tA) { for (auto t2 : tB) { if (t1 != t2)