Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/triton/Analysis/Allocation.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ class Allocation {
size_t offset = 0)
: kind(kind), id(nextId++), size(size), alignment(alignment),
offset(offset) {}

size_t setOffsetAligned(size_t newOffset) {
Comment thread
Jokeren marked this conversation as resolved.
return offset = llvm::alignTo(newOffset, alignment);
}
};

/// Op -> Scratch Buffer
Expand Down
40 changes: 17 additions & 23 deletions lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,7 @@ class AllocationAnalysis {
buffers.emplace_back(bufferIter.first);
}

DenseMap<BufferT *, size_t> bufferStart;
calculateStarts(buffers, bufferStart);
calculateStarts(buffers);

// NOTE: The original paper doesn't consider interference between
// the bumped ranges. Buffers that previously do not interfere with
Expand All @@ -494,16 +493,15 @@ class AllocationAnalysis {
// increase the buffer offset and keep reducing conflicts, we will
// eventually reach a fixed point.
GraphT interference;
buildInterferenceGraph(buffers, bufferStart, interference);
buildInterferenceGraph(buffers, interference);
do {
allocate(buffers, interference, bufferStart);
buildInterferenceGraph(buffers, bufferStart, interference);
allocate(buffers, interference);
buildInterferenceGraph(buffers, interference);
} while (!interference.empty());
}

/// Computes the initial shared memory offsets.
void calculateStarts(const SmallVector<BufferT *> &buffers,
DenseMap<BufferT *, size_t> &bufferStart) {
void calculateStarts(const SmallVector<BufferT *> &buffers) {
// v = values in shared memory
// t = triplet of (size, start, end)
// shared memory space
Expand All @@ -527,7 +525,7 @@ class AllocationAnalysis {
SmallVector<BufferT *> xBuffers = buffers;
while (!xBuffers.empty()) {
auto tripleIt = tripleMap.begin();
auto size = tripleIt->first;
auto offset = tripleIt->first;
auto range = tripleIt->second;
tripleMap.erase(tripleIt);
auto bufferIt =
Expand All @@ -545,20 +543,18 @@ class AllocationAnalysis {
auto xRange = bufferRange.lookup(buffer);
// TODO(Keren): A buffer's size shouldn't be determined here, have to
// clean it up
size_t alignment = buffer->alignment;
size_t alignSize = ((size + alignment - 1) / alignment) * alignment;
bufferStart[buffer] = alignSize;
tripleMap.insert({alignSize + xSize,
size_t alignOffset = buffer->setOffsetAligned(offset);
tripleMap.insert({alignOffset + xSize,
Interval{std::max(range.start(), xRange.start()),
std::min(range.end(), xRange.end())}});
// We could either insert (range.start, xRange.start) or (range.start,
// xRange.end), both are correct and determine the potential buffer
// offset, and the graph coloring algorithm will solve the interference,
// if any
if (range.start() < xRange.start())
tripleMap.insert({size, Interval{range.start(), xRange.end()}});
tripleMap.insert({offset, Interval{range.start(), xRange.end()}});
if (xRange.end() < range.end())
tripleMap.insert({size, Interval{xRange.start(), range.end()}});
tripleMap.insert({offset, Interval{xRange.start(), range.end()}});
xBuffers.erase(bufferIt);
}
}
Expand All @@ -567,16 +563,15 @@ class AllocationAnalysis {
/// Builds a graph of all shared memory values. Edges are created between
/// shared memory values that are overlapping.
void buildInterferenceGraph(const SmallVector<BufferT *> &buffers,
const DenseMap<BufferT *, size_t> &bufferStart,
GraphT &interference) {
// Reset interference graph
interference.clear();
for (auto x : buffers) {
for (auto y : buffers) {
if (x == y)
continue;
auto xStart = bufferStart.lookup(x);
auto yStart = bufferStart.lookup(y);
auto xStart = x->offset;
auto yStart = y->offset;
auto xSize = x->size;
auto ySize = y->size;
Interval xSizeRange = {xStart, xStart + xSize};
Expand All @@ -593,8 +588,7 @@ class AllocationAnalysis {

/// Finalizes shared memory offsets considering interference.
void allocate(const SmallVector<BufferT *> &buffers,
const GraphT &interference,
DenseMap<BufferT *, size_t> &bufferStart) {
const GraphT &interference) {
// Reset shared memory size
allocation->sharedMemorySize = 0;
// First-fit graph coloring
Expand Down Expand Up @@ -625,12 +619,12 @@ class AllocationAnalysis {
// TODO(Keren): We are wasting memory here.
// Nodes with color2 can actually start with 24.
for (auto x : buffers) {
size_t adj = 0;
size_t newOffset = 0;
for (auto y : interference.lookup(x)) {
adj = std::max(adj, bufferStart.lookup(y) + y->size);
newOffset = std::max(newOffset, y->offset + y->size);
}
x->offset = bufferStart.lookup(x) + colors.lookup(x) * adj;
bufferStart[x] = x->offset;
if (colors.lookup(x) != 0)
x->setOffsetAligned(newOffset);
allocation->sharedMemorySize =
std::max(allocation->sharedMemorySize, x->offset + x->size);
}
Expand Down
4 changes: 2 additions & 2 deletions test/Analysis/test-allocation.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ tt.func @multi_color(%A : !tt.ptr<f16>) {
%5 = triton_gpu.local_load %cst_5 : !tt.memdesc<4x8xf16, #A_SHARED> -> tensor<4x8xf16, #AL>
// CHECK-NEXT: offset = 1024, size = 512
%cst_6 = triton_gpu.local_alloc : () -> !tt.memdesc<8x32xf16, #A_SHARED>
// CHECK-NEXT: offset = 3104, size = 128
// CHECK-NEXT: offset = 1792, size = 128
%cst_7 = triton_gpu.local_alloc : () -> !tt.memdesc<2x32xf16, #A_SHARED>
%6 = triton_gpu.local_load %cst_0 : !tt.memdesc<4x4xf16, #A_SHARED> -> tensor<4x4xf16, #AL>
// CHECK-NEXT: offset = 1024, size = 512
Expand All @@ -217,7 +217,7 @@ tt.func @multi_color(%A : !tt.ptr<f16>) {
%10 = triton_gpu.local_load %cst_7 : !tt.memdesc<2x32xf16, #A_SHARED> -> tensor<2x32xf16, #AL>
%cst_12 = arith.constant dense<0.000000e+00> : tensor<4x16xf16, #AL>
%cst_13 = arith.constant dense<0.000000e+00> : tensor<8x32xf16, #AL>
// CHECK-NEXT: size = 3232
// CHECK-NEXT: size = 1920
tt.return
}

Expand Down