diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h index cb71f34319b8..a9e02b420844 100644 --- a/include/triton/Analysis/Allocation.h +++ b/include/triton/Analysis/Allocation.h @@ -160,6 +160,10 @@ class Allocation { size_t offset = 0) : kind(kind), id(nextId++), size(size), alignment(alignment), offset(offset) {} + + size_t setOffsetAligned(size_t newOffset) { + return offset = llvm::alignTo(newOffset, alignment); + } }; /// Op -> Scratch Buffer diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index bed37bbbf779..1e6e38749f4c 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -483,8 +483,7 @@ class AllocationAnalysis { buffers.emplace_back(bufferIter.first); } - DenseMap bufferStart; - calculateStarts(buffers, bufferStart); + calculateStarts(buffers); // NOTE: The original paper doesn't consider interference between // the bumped ranges. Buffers that previously do not interfere with @@ -494,16 +493,15 @@ class AllocationAnalysis { // increase the buffer offset and keep reducing conflicts, we will // eventually reach a fixed point. GraphT interference; - buildInterferenceGraph(buffers, bufferStart, interference); + buildInterferenceGraph(buffers, interference); do { - allocate(buffers, interference, bufferStart); - buildInterferenceGraph(buffers, bufferStart, interference); + allocate(buffers, interference); + buildInterferenceGraph(buffers, interference); } while (!interference.empty()); } /// Computes the initial shared memory offsets. - void calculateStarts(const SmallVector &buffers, - DenseMap &bufferStart) { + void calculateStarts(const SmallVector &buffers) { // v = values in shared memory // t = triplet of (size, start, end) // shared memory space @@ -527,7 +525,7 @@ class AllocationAnalysis { SmallVector xBuffers = buffers; while (!xBuffers.empty()) { auto tripleIt = tripleMap.begin(); - auto size = tripleIt->first; + auto offset = tripleIt->first; auto range = tripleIt->second; tripleMap.erase(tripleIt); auto bufferIt = @@ -545,10 +543,8 @@ class AllocationAnalysis { auto xRange = bufferRange.lookup(buffer); // TODO(Keren): A buffer's size shouldn't be determined here, have to // clean it up - size_t alignment = buffer->alignment; - size_t alignSize = ((size + alignment - 1) / alignment) * alignment; - bufferStart[buffer] = alignSize; - tripleMap.insert({alignSize + xSize, + size_t alignOffset = buffer->setOffsetAligned(offset); + tripleMap.insert({alignOffset + xSize, Interval{std::max(range.start(), xRange.start()), std::min(range.end(), xRange.end())}}); // We could either insert (range.start, xRange.start) or (range.start, @@ -556,9 +552,9 @@ class AllocationAnalysis { // offset, and the graph coloring algorithm will solve the interference, // if any if (range.start() < xRange.start()) - tripleMap.insert({size, Interval{range.start(), xRange.end()}}); + tripleMap.insert({offset, Interval{range.start(), xRange.end()}}); if (xRange.end() < range.end()) - tripleMap.insert({size, Interval{xRange.start(), range.end()}}); + tripleMap.insert({offset, Interval{xRange.start(), range.end()}}); xBuffers.erase(bufferIt); } } @@ -567,7 +563,6 @@ class AllocationAnalysis { /// Builds a graph of all shared memory values. Edges are created between /// shared memory values that are overlapping. void buildInterferenceGraph(const SmallVector &buffers, - const DenseMap &bufferStart, GraphT &interference) { // Reset interference graph interference.clear(); @@ -575,8 +570,8 @@ class AllocationAnalysis { for (auto y : buffers) { if (x == y) continue; - auto xStart = bufferStart.lookup(x); - auto yStart = bufferStart.lookup(y); + auto xStart = x->offset; + auto yStart = y->offset; auto xSize = x->size; auto ySize = y->size; Interval xSizeRange = {xStart, xStart + xSize}; @@ -593,8 +588,7 @@ class AllocationAnalysis { /// Finalizes shared memory offsets considering interference. void allocate(const SmallVector &buffers, - const GraphT &interference, - DenseMap &bufferStart) { + const GraphT &interference) { // Reset shared memory size allocation->sharedMemorySize = 0; // First-fit graph coloring @@ -625,12 +619,12 @@ class AllocationAnalysis { // TODO(Keren): We are wasting memory here. // Nodes with color2 can actually start with 24. for (auto x : buffers) { - size_t adj = 0; + size_t newOffset = 0; for (auto y : interference.lookup(x)) { - adj = std::max(adj, bufferStart.lookup(y) + y->size); + newOffset = std::max(newOffset, y->offset + y->size); } - x->offset = bufferStart.lookup(x) + colors.lookup(x) * adj; - bufferStart[x] = x->offset; + if (colors.lookup(x) != 0) + x->setOffsetAligned(newOffset); allocation->sharedMemorySize = std::max(allocation->sharedMemorySize, x->offset + x->size); } diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index 2107fc754a18..738ad11b344a 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -200,7 +200,7 @@ tt.func @multi_color(%A : !tt.ptr) { %5 = triton_gpu.local_load %cst_5 : !tt.memdesc<4x8xf16, #A_SHARED> -> tensor<4x8xf16, #AL> // CHECK-NEXT: offset = 1024, size = 512 %cst_6 = triton_gpu.local_alloc : () -> !tt.memdesc<8x32xf16, #A_SHARED> - // CHECK-NEXT: offset = 3104, size = 128 + // CHECK-NEXT: offset = 1792, size = 128 %cst_7 = triton_gpu.local_alloc : () -> !tt.memdesc<2x32xf16, #A_SHARED> %6 = triton_gpu.local_load %cst_0 : !tt.memdesc<4x4xf16, #A_SHARED> -> tensor<4x4xf16, #AL> // CHECK-NEXT: offset = 1024, size = 512 @@ -217,7 +217,7 @@ tt.func @multi_color(%A : !tt.ptr) { %10 = triton_gpu.local_load %cst_7 : !tt.memdesc<2x32xf16, #A_SHARED> -> tensor<2x32xf16, #AL> %cst_12 = arith.constant dense<0.000000e+00> : tensor<4x16xf16, #AL> %cst_13 = arith.constant dense<0.000000e+00> : tensor<8x32xf16, #AL> - // CHECK-NEXT: size = 3232 + // CHECK-NEXT: size = 1920 tt.return }