diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index 30de13f6a88c..16c9991a17b0 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -1,10 +1,7 @@ #include "ReduceScanCommon.h" -#include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Support/LLVM.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/TritonGPU/Transforms/Utility.h" -#include using namespace mlir; using namespace mlir::triton; @@ -80,36 +77,16 @@ struct ReduceOpConversion private: const TargetInfoBase &targetInfo; - void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp, - SmallVector &acc, ValueRange cur, bool isFirst) const { - if (isFirst) { - acc = SmallVector(cur.begin(), cur.end()); - return; + void accumulate(Location loc, ConversionPatternRewriter &rewriter, + Region &combineOp, SmallVector &acc, ValueRange cur, + Value pred = {}) const { + auto results = applyCombineOp(loc, rewriter, combineOp, acc, cur, pred); + if (acc.size() < results.size()) { + acc.resize(results.size()); } - - // Create a new copy of the reduce block, and inline it - Block *currentBlock = rewriter.getBlock(); - Region &parent = *currentBlock->getParent(); - rewriter.cloneRegionBefore(combineOp, &parent.front()); - auto &newReduce = parent.front(); - auto returnOp = dyn_cast(newReduce.getTerminator()); - - llvm::SmallVector combineArgs(2 * acc.size()); - for (unsigned i = 0; i < acc.size(); ++i) { - combineArgs[i] = acc[i]; - combineArgs[acc.size() + i] = cur[i]; - } - - rewriter.inlineBlockBefore(&newReduce, &*rewriter.getInsertionPoint(), - combineArgs); - - auto results = returnOp.getResult(); for (unsigned i = 0; i < acc.size(); ++i) { acc[i] = results[i]; } - - // Delete the terminator, which is no longer used - rewriter.eraseOp(returnOp); } SmallVector> @@ 
-165,7 +142,7 @@ struct ReduceOpConversion SmallVector key = offsets[i]; key[op.getAxis()] = 0; bool isFirst = accs.find(key) == accs.end(); - accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst); + accumulate(op.getLoc(), rewriter, *combineOp, accs[key], srcValues[i]); if (isFirst) indices[key] = srcIndices[i]; } @@ -175,17 +152,29 @@ struct ReduceOpConversion // region and the accumulator values as source. void warpReduce(ConversionPatternRewriter &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, - unsigned numLaneToReduce, unsigned interleave) const { + unsigned numLaneToReduce, unsigned interleave, + Value pred = {}) const { auto success = targetInfo.warpReduce(rewriter, loc, acc, op, numLaneToReduce, interleave); if (success) return; + + auto mod = op->getParentOfType(); + unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + if (iWarpSize > numLaneToReduce) { + Value threadId = getThreadId(rewriter, loc); + Value warpSize = i32_val(iWarpSize); + Value laneId = urem(threadId, warpSize); + Value lanePred = icmp_slt(laneId, i32_val(numLaneToReduce)); + pred = pred ? 
and_(pred, lanePred) : lanePred; + } + for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) { SmallVector shfl(acc.size()); for (unsigned i = 0; i < acc.size(); ++i) { shfl[i] = targetInfo.shuffleXor(rewriter, loc, acc[i], N * interleave); } - accumulate(rewriter, op.getCombineOp(), acc, shfl, false); + accumulate(op.getLoc(), rewriter, op.getCombineOp(), acc, shfl, pred); } } @@ -344,7 +333,8 @@ struct ReduceOpConversion acc[i] = targetInfo.loadShared(rewriter, loc, readPtr, elemTy, threadIsNeeded); } - warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */); + warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */, + threadIsNeeded); // only the first thread in each sizeInterWarps is writing Value writeOffset = readOffset; SmallVector writePtrs(op.getNumOperands()); diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h b/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h index 3130001cc5d4..9f823e2e135a 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h +++ b/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h @@ -4,15 +4,14 @@ // TODO: refactor so that it doesn't fail if Allocation.h // is included after utility.h (due to conflict in `store` macro // and -#include "triton/Analysis/Allocation.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Transforms/DialectConversion.h" -#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" // #include "mlir/IR/TypeUtilities.h" -#include "triton/Analysis/AxisInfo.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" -#include +#include #include #define DEBUG_TYPE "ttgpu_to_llvm" @@ -32,6 +31,91 @@ namespace ttng = ::mlir::triton::nvidia_gpu; namespace mlir::triton { class ReduceOp; class ScanOp; + +inline SmallVector +inlineCombineBlock(ConversionPatternRewriter &rewriter, Block &combineBlock, + Block *insertionBlock, Block::iterator 
insertionPoint,
+                   ValueRange combineArgs) {
+  auto returnOp = combineBlock.getTerminator();
+  rewriter.inlineBlockBefore(&combineBlock, insertionBlock, insertionPoint,
+                             combineArgs);
+
+  auto results = SmallVector<Value>(returnOp->getOperands());
+
+  // Delete the terminator, which is no longer used
+  rewriter.eraseOp(returnOp);
+  return results;
+}
+
+inline SmallVector<Value> applyCombineOp(Location loc,
+                                         ConversionPatternRewriter &rewriter,
+                                         Region &combineOp, ValueRange acc,
+                                         ValueRange cur, Value pred = {}) {
+  // Allows for passing an uninitialized acc and use cur as the neutral element
+  if (acc.size() == 0) {
+    return cur;
+  }
+  assert(cur.size() == acc.size());
+
+  // Create a new copy of the combine block, and try to speculatively inline it
+  Block *currentBlock = rewriter.getBlock();
+  Region &parent = *currentBlock->getParent();
+
+  rewriter.cloneRegionBefore(combineOp, parent,
+                             std::next(currentBlock->getIterator()));
+  Block &newCombine = *currentBlock->getNextNode();
+
+  llvm::SmallVector<Value> combineArgs(2 * acc.size());
+  for (unsigned i = 0; i < acc.size(); ++i) {
+    combineArgs[i] = acc[i];
+    combineArgs[acc.size() + i] = cur[i];
+  }
+
+  auto isRegionSpeculatable =
+      std::all_of(newCombine.begin(), newCombine.end(),
+                  [](auto &op) { return isSpeculatable(&op); });
+
+  if (!pred || isRegionSpeculatable) {
+    // Fast path, region has no side effects so we can unconditionally execute
+    return inlineCombineBlock(rewriter, newCombine, currentBlock,
+                              rewriter.getInsertionPoint(), combineArgs);
+  }
+
+  // Slow case, create an if to only execute region when pred is true
+  // #currentBlock
+  // if (pred) {
+  //   #newCombine
+  //   results = combineOp(cur, acc)
+  //   yield results
+  // } else {
+  //   yield undef
+  // }
+  // #thenBlock
+  Block *thenBlock =
+      rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
+
+  auto returnOp = newCombine.getTerminator();
+  auto results = SmallVector<Value>(returnOp->getOperands());
+
+  rewriter.setInsertionPointToEnd(currentBlock);
+  SmallVector<Value> thenBlockArgs;
+  thenBlockArgs.reserve(results.size());
+  for (auto result : results) {
+    auto ty = result.getType();
+    auto undef = rewriter.create<LLVM::UndefOp>(loc, ty);
+    thenBlockArgs.push_back(undef);
+    thenBlock->addArgument(ty, loc);
+  }
+  rewriter.create<cf::CondBranchOp>(loc, pred, &newCombine, combineArgs,
+                                    thenBlock, thenBlockArgs);
+
+  // Split a block after the call.
+  rewriter.setInsertionPointToEnd(&newCombine);
+  rewriter.replaceOpWithNewOp<cf::BranchOp>(returnOp, thenBlock, results);
+  rewriter.setInsertionPointToStart(thenBlock);
+  return SmallVector<Value>(thenBlock->getArguments());
+}
+
 } // namespace mlir::triton
 template <typename SourceOp>
diff --git a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
index 675bf5a34213..b07f654a1eaf 100644
--- a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
@@ -1,5 +1,3 @@
-#include
-
 #include "ReduceScanCommon.h"
 #include "mlir/Support/LLVM.h"
 #include "triton/Analysis/Utility.h"
@@ -16,37 +14,13 @@ using ::mlir::LLVM::linearize;
 using ::mlir::triton::gpu::getTotalElemsPerThread;
 // apply combine region to acc and cur and accumulate it into acc
-// TODO(Lezcano) This is now duplicated with ReduceOpConversion::reduce.
-// Deduplicate -static SmallVector accumulate(ConversionPatternRewriter &rewriter, - Region &combineOp, ValueRange acc, - ValueRange cur) { - // Allows for passing an unitialized acc and use cur as the neutral element - if (acc.size() == 0) { - return cur; - } - assert(cur.size() == acc.size()); - // Create a new copy of the reduce block, and inline it - Block *currentBlock = rewriter.getBlock(); - Region &parent = *currentBlock->getParent(); - rewriter.cloneRegionBefore(combineOp, &parent.front()); - auto &newScan = parent.front(); - auto returnOp = dyn_cast(newScan.getTerminator()); - - SmallVector combineArgs(2 * acc.size()); - for (unsigned i = 0; i < acc.size(); ++i) { - combineArgs[i] = acc[i]; - combineArgs[acc.size() + i] = cur[i]; - } - - rewriter.inlineBlockBefore(&newScan, &*rewriter.getInsertionPoint(), - combineArgs); - SmallVector results; - llvm::transform(returnOp.getResult(), std::back_inserter(results), - [&](Value res) { return rewriter.getRemappedValue(res); }); - // Delete the terminator, which is no longer used - rewriter.eraseOp(returnOp); - return results; +static SmallVector accumulate(ScanLoweringHelper &helper, + ConversionPatternRewriter &rewriter, + ValueRange acc, ValueRange cur, + Value pred = {}) { + auto loc = helper.getLoc(); + auto &combineOp = helper.getCombineOp(); + return applyCombineOp(loc, rewriter, combineOp, acc, cur, pred); } // Scan a contiguous elements within a thread and update `srcValues` in place. 
@@ -66,8 +40,8 @@ scanThreadContiguousElements(SmallVector> &srcValues, unsigned accIndex = (srcIndex % stride) + ((srcIndex / stride) / scanElementsPerThreads) * stride; - accs[accIndex] = accumulate(rewriter, helper.getCombineOp(), accs[accIndex], - srcValues[srcIndex]); + accs[accIndex] = + accumulate(helper, rewriter, accs[accIndex], srcValues[srcIndex]); srcValues[srcIndex] = accs[accIndex]; } } @@ -95,11 +69,11 @@ static void warpScan(SmallVector> &srcValues, for (unsigned j = 0; j < acc.size(); ++j) { shfl[j] = targetInfo.shuffleUp(rewriter, loc, acc[j], i * threadStride); } + Value mask = icmp_sge(laneIdAxis, i32_val(i)); SmallVector tempAcc = - accumulate(rewriter, helper.getCombineOp(), shfl, acc); - Value mask = icmp_slt(laneIdAxis, i32_val(i)); + accumulate(helper, rewriter, shfl, acc, mask); for (unsigned j = 0; j < acc.size(); ++j) { - acc[j] = select(mask, acc[j], tempAcc[j]); + acc[j] = select(mask, tempAcc[j], acc[j]); } } srcValues[srcIndex] = acc; @@ -164,9 +138,9 @@ static void AddPartialReduce(SmallVector> &srcValues, unsigned elementStride = helper.getAxisElementStride(); unsigned threadStride = helper.getAxisThreadStride(); unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); - Value maskFirstWarp = icmp_eq(warpId, i32_val(0)); - Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0)); - Value maskFirstThread = and_(maskFirstWarp, maskFirstLane); + Value maskNotFirstWarp = icmp_ne(warpId, i32_val(0)); + Value maskNotFirstLane = icmp_ne(laneIdAxis, i32_val(0)); + Value maskNotFirstThread = or_(maskNotFirstWarp, maskNotFirstLane); struct Accumulator { SmallVector acc; SmallVector maskedAcc; @@ -212,22 +186,24 @@ static void AddPartialReduce(SmallVector> &srcValues, accumulator.maskedAcc = partialReduce; continue; } - accumulator.acc = accumulate(rewriter, helper.getCombineOp(), - accumulator.acc, partialReduce); - Value mask = icmp_slt(warpId, i32_val(i + 1)); + Value mask = icmp_sge(warpId, i32_val(i + 1)); + accumulator.acc = + 
accumulate(helper, rewriter, accumulator.acc, partialReduce, mask); for (unsigned j = 0; j < helper.getNumOperands(); ++j) { accumulator.maskedAcc[j] = - select(mask, accumulator.maskedAcc[j], accumulator.acc[j]); + select(mask, accumulator.acc[j], accumulator.maskedAcc[j]); } } - auto temp = accumulate(rewriter, helper.getCombineOp(), - accumulator.maskedAcc, srcValues[srcIndex]); + + Value pred = axisBlockId == 0 ? maskNotFirstWarp : Value{}; + auto temp = accumulate(helper, rewriter, accumulator.maskedAcc, + srcValues[srcIndex], pred); if (axisBlockId == 0) { // For the first warp and first chunk we don't have anything to // accumulate. auto val = srcValues[srcIndex]; for (unsigned i = 0; i < helper.getNumOperands(); ++i) { - temp[i] = select(maskFirstWarp, val[i], temp[i]); + temp[i] = select(maskNotFirstWarp, temp[i], val[i]); } } srcValues[srcIndex] = temp; @@ -235,19 +211,18 @@ static void AddPartialReduce(SmallVector> &srcValues, SmallVector lastElement(helper.getNumOperands()); for (unsigned i = 0; i < helper.getNumOperands(); ++i) { auto elem = targetInfo.shuffleUp(rewriter, loc, temp[i], threadStride); - lastElement[i] = select(maskFirstLane, accumulator.maskedAcc[i], elem); + lastElement[i] = select(maskNotFirstLane, elem, accumulator.maskedAcc[i]); } for (unsigned i = 1; i < scanElementsPerThreads; ++i) { + pred = axisBlockId == 0 ? maskNotFirstThread : Value{}; auto laneValue = srcValues[srcIndex - i * elementStride]; - laneValue = - accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue); + laneValue = accumulate(helper, rewriter, lastElement, laneValue, pred); if (axisBlockId == 0) { // For the first warp and first chunk we don't have anything to // accumulate. 
for (unsigned j = 0; j < helper.getNumOperands(); ++j) { - laneValue[j] = - select(maskFirstThread, - srcValues[srcIndex - i * elementStride][j], laneValue[j]); + laneValue[j] = select(maskNotFirstThread, laneValue[j], + srcValues[srcIndex - i * elementStride][j]); } } srcValues[srcIndex - i * elementStride] = laneValue; @@ -300,8 +275,8 @@ static void AddPartialReduceOneWarp(SmallVector> &srcValues, if (axisBlockId == 0) // First chunk and first block accumulator = srcValues[srcIndex]; else - srcValues[srcIndex] = accumulate(rewriter, helper.getCombineOp(), - accumulator, srcValues[srcIndex]); + srcValues[srcIndex] = + accumulate(helper, rewriter, accumulator, srcValues[srcIndex]); // Update the rest of the contiguous elements. auto lastElement = srcValues[srcIndex]; if (scanDim > 1) { @@ -319,8 +294,7 @@ static void AddPartialReduceOneWarp(SmallVector> &srcValues, } for (unsigned i = 1; i < scanElementsPerThreads; ++i) { auto laneValue = srcValues[srcIndex - i * elementStride]; - laneValue = - accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue); + laneValue = accumulate(helper, rewriter, lastElement, laneValue); if (axisBlockId == 0) { for (unsigned j = 0; j < helper.getNumOperands(); ++j) { // For the first warp and first chunk we don't have anything to diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index d84788d9d653..039f7ac1ac4f 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -5644,3 +5644,52 @@ def check_loop_unroll_count(ir, opStr, loop_unroll_factor): for unroll_factor in [1, 2, 4, 5, 8]: h = _kernel[(1, )](torch.empty(1, device=device), unroll_factor) check_loop_unroll_count(h.asm["ttir"], 'tt.atomic_rmw', unroll_factor) + + +@triton.jit +def sanitize_add(a, b): + a64 = a.to(tl.int64) + b64 = b.to(tl.int64) + r64 = a64 + b64 + tl.device_assert((r64 >= -2**31) & (r64 <= 2**31 - 1)) + return a + b + + +def test_side_effectful_reduction(device): 
+ if device != "cuda": + pytest.skip() + + @triton.jit(debug=True) + def sanitize_sum_kernel(Z, X, BLOCK: tl.constexpr): + vals = tl.load(X + tl.arange(0, BLOCK)) + z = tl.reduce(vals, 0, sanitize_add) + tl.store(Z, z) + + BLOCK = 512 + torch.manual_seed(42) + X = torch.randint(0, 10, [BLOCK], device="cuda", dtype=torch.int32) + X[:300] = 32 + X[300:] = 0 + Z = torch.zeros((), device="cuda", dtype=torch.int32) + sanitize_sum_kernel[(1, )](Z, X, BLOCK=BLOCK) + torch.testing.assert_close(Z, X.sum().to(torch.int32)) + + +def test_side_effectful_scan(device): + if device != "cuda": + pytest.skip() + + @triton.jit(debug=True) + def sanitize_cumsum_kernel(Z, X, BLOCK: tl.constexpr): + vals = tl.load(X + tl.arange(0, BLOCK)) + z = tl.associative_scan(vals, 0, sanitize_add) + tl.store(Z + tl.arange(0, BLOCK), z) + + BLOCK = 512 + torch.manual_seed(42) + X = torch.randint(0, 10, [BLOCK], device="cuda", dtype=torch.int32) + X[:300] = 32 + X[300:] = 0 + Z = torch.zeros_like(X) + sanitize_cumsum_kernel[(1, )](Z, X, BLOCK=BLOCK) + torch.testing.assert_close(Z, X.cumsum(0).to(torch.int32))