Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ struct ReduceOpConversion
}
}

// Calculates the write index in the shared memory where we would be writing
// the within-thread accumulations before we start doing across-threads
// accumulations. `index` is the index of the within-thread accumulations in
// the full tensor, whereas `writeIdx` is the mapped-to index in the shared
// memory
void getWriteIndexBasic(ConversionPatternRewriter &rewriter, Location loc,
Attribute layout, SmallVector<Value> &index,
SmallVector<Value> &writeIdx,
Expand All @@ -141,6 +146,12 @@ struct ReduceOpConversion
Value _8 = ints[8];
Value _16 = ints[16];
if (layout.isa<BlockedEncodingAttr>()) {
// A single thread owns axisSizePerThread contiguous values
// on the reduction axis. After within thread reduction,
// we would have a single accumulation every `axisSizePerThread`
// contiguous values in the original tensor, so we would need
// to map every `axisSizePerThread` to 1 value in smem as:
// writeIdx[axis] = index[axis] / axisSizePerThread
writeIdx[axis] = udiv(index[axis], axisSizePerThread);
}
auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
Expand All @@ -152,9 +163,7 @@ struct ReduceOpConversion
writeIdx[axis] =
add(mul(udiv(index[axis], _16), _8), urem(index[axis], _8));
} else {
// A single thread owns axisSizePerThread contiguous values
// on the reduction axis, so after within thread reduction,
// writeIdx[axis] = index[axis] / axisSizePerThread
// Same as BlockedEncodingAttr case
writeIdx[axis] = udiv(index[axis], axisSizePerThread);
}
}
Expand All @@ -170,13 +179,16 @@ struct ReduceOpConversion
ReduceOpHelper helper(op);
Location loc = op->getLoc();
unsigned axis = op.getAxis();
// Specifies whether the reduce operation returns an index
// rather than a value, e.g. argmax, argmin, .. etc
bool withIndex = triton::ReduceOp::withIndex(op.getRedOp());

auto srcTy = op.getOperand().getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding();
if (!helper.isSupportedLayout()) {
assert(false && "Unexpected srcLayout in ReduceOpConversion");
}
// The order of the axes for the the threads within the warp
auto srcOrd = triton::gpu::getOrder(srcLayout);
auto sizePerThread = triton::gpu::getSizePerThread(srcLayout);
auto srcShape = srcTy.getShape();
Expand All @@ -185,6 +197,7 @@ struct ReduceOpConversion
auto llvmIndexTy = getTypeConverter()->getIndexType();
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);

Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
smemBase = bitcast(smemBase, elemPtrTy);

Expand All @@ -194,13 +207,16 @@ struct ReduceOpConversion
indexSmemBase = bitcast(indexSmemBase, indexPtrTy);

unsigned srcElems = getElemsPerThread(srcTy);
// Emits indices of the original tensor that each thread
// would own
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy);
auto srcValues = getTypeConverter()->unpackLLElements(
loc, adaptor.getOperand(), rewriter, srcTy);

// Emits offsets (the offset from the base index)
// of the original tensor that each thread would own
SmallVector<SmallVector<unsigned>> offset =
emitOffsetForLayout(srcLayout, srcTy);

// Keep track of accumulations and their indices
std::map<SmallVector<unsigned>, Value> accs;
std::map<SmallVector<unsigned>, Value> accIndices;
std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
Expand Down Expand Up @@ -238,29 +254,49 @@ struct ReduceOpConversion
Value accIndex;
if (withIndex)
accIndex = accIndices[key];
// get the writeIdx at which to write in smem
SmallVector<Value> writeIdx;
getWriteIndexBasic(rewriter, loc, srcLayout, indices[key], writeIdx, ints,
axis);
// calculate the offset in smem for that writeIdx
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, srcOrd);
// Get element pointers for the value and index
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset);
// Store the within-thread accumulated value at writePtr
store(acc, writePtr);
// Store the index of within-thread accumulation at indexWritePtr
if (withIndex)
store(accIndex, indexWritePtr);

SmallVector<Value> readIdx(writeIdx.size(), ints[0]);
// Perform parallel reduction with sequential addressing
// E.g. We reduce `smemShape[axis]` elements into `smemShape[axis]/2`
// elements using `smemShape[axis]/2` threads where each thread
// would accumalte values that are `smemShape[axis]/2` apart
// to avoid bank conflicts. Then we repeat with `smemShape[axis]/4`
// threads, .. etc.
for (int N = smemShape[axis] / 2; N > 0; N >>= 1) {
// The readIdx will be N elements away on the reduction axis
readIdx[axis] = ints[N];
// If the writeIdx is greater or equal to N, do nothing
Value readMask = icmp_slt(writeIdx[axis], ints[N]);
// Calculate the readOffset, if readMask is False, readOffset=0
// meaning we reduce the value at writeIdx with itself
Value readOffset = select(
readMask, linearize(rewriter, loc, readIdx, smemShape, srcOrd),
ints[0]);
// The readPtr is readOffset away from writePtr
Value readPtr = gep(elemPtrTy, writePtr, readOffset);
barrier();
// If we do not care about the index, i.e. this is not an argmax,
// argmin, .. etc
if (!withIndex) {
// The value at the readPtr, whereas acc is the value at writePtr
Value cur = load(readPtr);
accumulate(rewriter, loc, op.getRedOp(), acc, cur, false);
barrier();
// Update writePtr value
store(acc, writePtr);
} else {
Value cur = load(readPtr);
Expand Down