From f4a565ce6678f649e04d5cc8daf34703b825739c Mon Sep 17 00:00:00 2001
From: Peter Bell
Date: Tue, 24 Sep 2024 09:00:45 -0700
Subject: [PATCH] [Backend] Fix device assert inside reduction/scan region

Currently the reduction codegen unconditionally executes the combine
region. This is a problem because we conditionally load from shared
memory, so the combine region may read uninitialized registers.
Generally combine regions should be pure, so this shouldn't be
observable, but with the overflow sanitizer the frontend injects
assertions into the combine region.

This changes the `accumulate` function to take a predicate; if the
combine region isn't speculatable, we only run it on threads where the
predicate is true. In the common case, the codegen is unchanged.
---
 .../TritonGPUToLLVM/ReduceOpToLLVM.cpp        | 56 +++++------
 .../TritonGPUToLLVM/ReduceScanCommon.h        | 94 ++++++++++++++++++-
 .../TritonGPUToLLVM/ScanOpToLLVM.cpp          | 90 +++++++-----------
 python/test/unit/language/test_core.py        | 49 ++++++++++
 4 files changed, 193 insertions(+), 96 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
index 30de13f6a88c..16c9991a17b0 100644
--- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -1,10 +1,7 @@
 #include "ReduceScanCommon.h"
-#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Support/LLVM.h"
 #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
-#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
-#include
 
 using namespace mlir;
 using namespace mlir::triton;
@@ -80,36 +77,16 @@ struct ReduceOpConversion
 private:
   const TargetInfoBase &targetInfo;
 
-  void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp,
-                  SmallVector<Value> &acc, ValueRange cur, bool isFirst) const {
-    if (isFirst) {
-      acc = SmallVector<Value>(cur.begin(), cur.end());
-      return;
+  void accumulate(Location loc, ConversionPatternRewriter &rewriter,
+                  Region &combineOp, SmallVector<Value> &acc, ValueRange cur,
+                  Value pred = {}) const {
+    auto results = applyCombineOp(loc, rewriter, combineOp, acc, cur, pred);
+    if (acc.size() < results.size()) {
+      acc.resize(results.size());
     }
-
-    // Create a new copy of the reduce block, and inline it
-    Block *currentBlock = rewriter.getBlock();
-    Region &parent = *currentBlock->getParent();
-    rewriter.cloneRegionBefore(combineOp, &parent.front());
-    auto &newReduce = parent.front();
-    auto returnOp = dyn_cast<triton::ReduceReturnOp>(newReduce.getTerminator());
-
-    llvm::SmallVector<Value> combineArgs(2 * acc.size());
-    for (unsigned i = 0; i < acc.size(); ++i) {
-      combineArgs[i] = acc[i];
-      combineArgs[acc.size() + i] = cur[i];
-    }
-
-    rewriter.inlineBlockBefore(&newReduce, &*rewriter.getInsertionPoint(),
-                               combineArgs);
-
-    auto results = returnOp.getResult();
     for (unsigned i = 0; i < acc.size(); ++i) {
       acc[i] = results[i];
     }
-
-    // Delete the terminator, which is no longer used
-    rewriter.eraseOp(returnOp);
   }
 
   SmallVector<SmallVector<unsigned>>
@@ -165,7 +142,7 @@ struct ReduceOpConversion
       SmallVector<unsigned> key = offsets[i];
       key[op.getAxis()] = 0;
       bool isFirst = accs.find(key) == accs.end();
-      accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst);
+      accumulate(op.getLoc(), rewriter, *combineOp, accs[key], srcValues[i]);
       if (isFirst)
         indices[key] = srcIndices[i];
     }
@@ -175,17 +152,29 @@ struct ReduceOpConversion
   // region and the accumulator values as source.
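+  // Reduce `acc` across `numLaneToReduce` lanes with a butterfly shuffle:
+  // each round XORs the lane id by N so pairs of lanes exchange and combine
+  // partial results. `pred` marks the lanes holding valid inputs; when the
+  // combine region is not speculatable, accumulate() only executes it where
+  // `pred` is true.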
   void warpReduce(ConversionPatternRewriter &rewriter, Location loc,
                   SmallVector<Value> &acc, triton::ReduceOp op,
-                  unsigned numLaneToReduce, unsigned interleave) const {
+                  unsigned numLaneToReduce, unsigned interleave,
+                  Value pred = {}) const {
     auto success = targetInfo.warpReduce(rewriter, loc, acc, op,
                                          numLaneToReduce, interleave);
     if (success)
       return;
+
+    auto mod = op->getParentOfType<ModuleOp>();
+    unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
+    if (iWarpSize > numLaneToReduce) {
+      Value threadId = getThreadId(rewriter, loc);
+      Value warpSize = i32_val(iWarpSize);
+      Value laneId = urem(threadId, warpSize);
+      Value lanePred = icmp_slt(laneId, i32_val(numLaneToReduce));
+      pred = pred ? and_(pred, lanePred) : lanePred;
+    }
+
     for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) {
       SmallVector<Value> shfl(acc.size());
       for (unsigned i = 0; i < acc.size(); ++i) {
         shfl[i] = targetInfo.shuffleXor(rewriter, loc, acc[i], N * interleave);
       }
-      accumulate(rewriter, op.getCombineOp(), acc, shfl, false);
+      accumulate(op.getLoc(), rewriter, op.getCombineOp(), acc, shfl, pred);
     }
   }
 
@@ -344,7 +333,8 @@ struct ReduceOpConversion
       acc[i] = targetInfo.loadShared(rewriter, loc, readPtr, elemTy,
                                      threadIsNeeded);
     }
-    warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */);
+    warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */,
+               threadIsNeeded);
     // only the first thread in each sizeInterWarps is writing
     Value writeOffset = readOffset;
     SmallVector<Value> writePtrs(op.getNumOperands());
diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h b/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h
index 3130001cc5d4..9f823e2e135a 100644
--- a/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h
+++ b/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h
@@ -4,15 +4,14 @@
 // TODO: refactor so that it doesn't fail if Allocation.h
 // is included after utility.h (due to conflict in `store` macro
 // and
-#include "triton/Analysis/Allocation.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Transforms/DialectConversion.h"
-#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
 // #include "mlir/IR/TypeUtilities.h"
-#include "triton/Analysis/AxisInfo.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
-#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
-#include
+#include
 #include
 
 #define DEBUG_TYPE "ttgpu_to_llvm"
@@ -32,6 +31,91 @@
 namespace ttng = ::mlir::triton::nvidia_gpu;
 
 namespace mlir::triton {
 class ReduceOp;
 class ScanOp;
+
+inline SmallVector<Value>
+inlineCombineBlock(ConversionPatternRewriter &rewriter, Block &combineBlock,
+                   Block *insertionBlock, Block::iterator insertionPoint,
+                   ValueRange combineArgs) {
+  auto returnOp = combineBlock.getTerminator();
+  rewriter.inlineBlockBefore(&combineBlock, insertionBlock, insertionPoint,
+                             combineArgs);
+
+  auto results = SmallVector<Value>(returnOp->getOperands());
+
+  // Delete the terminator, which is no longer used
+  rewriter.eraseOp(returnOp);
+  return results;
+}
+
+inline SmallVector<Value> applyCombineOp(Location loc,
+                                         ConversionPatternRewriter &rewriter,
+                                         Region &combineOp, ValueRange acc,
+                                         ValueRange cur, Value pred = {}) {
+  // Allows for passing an uninitialized acc and using cur as the neutral
+  // element
+  if (acc.size() == 0) {
+    return cur;
+  }
+  assert(cur.size() == acc.size());
+
+  // Create a new copy of the combine block, and try to speculatively inline it
+  Block *currentBlock = rewriter.getBlock();
+  Region &parent = *currentBlock->getParent();
+
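+  // The clone lands immediately after the current block: the fast path
+  // inlines it in place, while the slow path conditionally branches into it.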
+  rewriter.cloneRegionBefore(combineOp, parent,
+                             std::next(currentBlock->getIterator()));
+  Block &newCombine = *currentBlock->getNextNode();
+
+  llvm::SmallVector<Value> combineArgs(2 * acc.size());
+  for (unsigned i = 0; i < acc.size(); ++i) {
+    combineArgs[i] = acc[i];
+    combineArgs[acc.size() + i] = cur[i];
+  }
+
+  auto isRegionSpeculatable =
+      std::all_of(newCombine.begin(), newCombine.end(),
+                  [](auto &op) { return isSpeculatable(&op); });
+
+  if (!pred || isRegionSpeculatable) {
+    // Fast path, region has no side effects so we can unconditionally execute
+    return inlineCombineBlock(rewriter, newCombine, currentBlock,
+                              rewriter.getInsertionPoint(), combineArgs);
+  }
+
+  // Slow case, create an if to only execute region when pred is true
+  // #currentBlock
+  // if (pred) {
+  //   #newCombine
+  //   results = combineOp(cur, acc)
+  //   yield results
+  // } else {
+  //   yield undef
+  // }
+  // #thenBlock
+  Block *thenBlock =
+      rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
+
+  auto returnOp = newCombine.getTerminator();
+  auto results = SmallVector<Value>(returnOp->getOperands());
+
+  rewriter.setInsertionPointToEnd(currentBlock);
+  SmallVector<Value> thenBlockArgs;
+  thenBlockArgs.reserve(results.size());
+  for (auto result : results) {
+    auto ty = result.getType();
+    auto undef = rewriter.create<LLVM::UndefOp>(loc, ty);
+    thenBlockArgs.push_back(undef);
+    thenBlock->addArgument(ty, loc);
+  }
+  rewriter.create<cf::CondBranchOp>(loc, pred, &newCombine, combineArgs,
+                                    thenBlock, thenBlockArgs);
+
+  // Split a block after the call.
+  rewriter.setInsertionPointToEnd(&newCombine);
+  rewriter.replaceOpWithNewOp<cf::BranchOp>(returnOp, thenBlock, results);
+  rewriter.setInsertionPointToStart(thenBlock);
+  return SmallVector<Value>(thenBlock->getArguments());
+}
+
 } // namespace mlir::triton
 
 template <typename SourceOp>
diff --git a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
index 675bf5a34213..b07f654a1eaf 100644
--- a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
@@ -1,5 +1,3 @@
-#include
-
 #include "ReduceScanCommon.h"
 #include "mlir/Support/LLVM.h"
 #include "triton/Analysis/Utility.h"
@@ -16,37 +14,13 @@
 using ::mlir::LLVM::linearize;
 using ::mlir::triton::gpu::getTotalElemsPerThread;
 
 // apply combine region to acc and cur and accumulate it into acc
-// TODO(Lezcano) This is now duplicated with ReduceOpConversion::reduce.
-// Deduplicate
-static SmallVector<Value> accumulate(ConversionPatternRewriter &rewriter,
-                                     Region &combineOp, ValueRange acc,
-                                     ValueRange cur) {
-  // Allows for passing an unitialized acc and use cur as the neutral element
-  if (acc.size() == 0) {
-    return cur;
-  }
-  assert(cur.size() == acc.size());
-  // Create a new copy of the reduce block, and inline it
-  Block *currentBlock = rewriter.getBlock();
-  Region &parent = *currentBlock->getParent();
-  rewriter.cloneRegionBefore(combineOp, &parent.front());
-  auto &newScan = parent.front();
-  auto returnOp = dyn_cast<triton::ScanReturnOp>(newScan.getTerminator());
-
-  SmallVector<Value> combineArgs(2 * acc.size());
-  for (unsigned i = 0; i < acc.size(); ++i) {
-    combineArgs[i] = acc[i];
-    combineArgs[acc.size() + i] = cur[i];
-  }
-
-  rewriter.inlineBlockBefore(&newScan, &*rewriter.getInsertionPoint(),
-                             combineArgs);
-  SmallVector<Value> results;
-  llvm::transform(returnOp.getResult(), std::back_inserter(results),
-                  [&](Value res) { return rewriter.getRemappedValue(res); });
-  // Delete the terminator, which is no longer used
-  rewriter.eraseOp(returnOp);
-  return results;
+static SmallVector<Value> accumulate(ScanLoweringHelper &helper,
+                                     ConversionPatternRewriter &rewriter,
+                                     ValueRange acc, ValueRange cur,
+                                     Value pred = {}) {
+  auto loc = helper.getLoc();
+  auto &combineOp = helper.getCombineOp();
+  return applyCombineOp(loc, rewriter, combineOp, acc, cur, pred);
 }
 
 // Scan a contiguous elements within a thread and update `srcValues` in place.
@@ -66,8 +40,8 @@ scanThreadContiguousElements(SmallVector<SmallVector<Value>> &srcValues,
     unsigned accIndex = (srcIndex % stride) +
                         ((srcIndex / stride) / scanElementsPerThreads) * stride;
 
-    accs[accIndex] = accumulate(rewriter, helper.getCombineOp(), accs[accIndex],
-                                srcValues[srcIndex]);
+    accs[accIndex] =
+        accumulate(helper, rewriter, accs[accIndex], srcValues[srcIndex]);
     srcValues[srcIndex] = accs[accIndex];
   }
 }
@@ -95,11 +69,11 @@ static void warpScan(SmallVector<SmallVector<Value>> &srcValues,
     for (unsigned j = 0; j < acc.size(); ++j) {
       shfl[j] = targetInfo.shuffleUp(rewriter, loc, acc[j], i * threadStride);
     }
+    Value mask = icmp_sge(laneIdAxis, i32_val(i));
     SmallVector<Value> tempAcc =
-        accumulate(rewriter, helper.getCombineOp(), shfl, acc);
-    Value mask = icmp_slt(laneIdAxis, i32_val(i));
+        accumulate(helper, rewriter, shfl, acc, mask);
     for (unsigned j = 0; j < acc.size(); ++j) {
-      acc[j] = select(mask, acc[j], tempAcc[j]);
+      acc[j] = select(mask, tempAcc[j], acc[j]);
     }
   }
   srcValues[srcIndex] = acc;
@@ -164,9 +138,9 @@ static void AddPartialReduce(SmallVector<SmallVector<Value>> &srcValues,
   unsigned elementStride = helper.getAxisElementStride();
   unsigned threadStride = helper.getAxisThreadStride();
   unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData();
-  Value maskFirstWarp = icmp_eq(warpId, i32_val(0));
-  Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0));
-  Value maskFirstThread = and_(maskFirstWarp, maskFirstLane);
+  Value maskNotFirstWarp = icmp_ne(warpId, i32_val(0));
+  Value maskNotFirstLane = icmp_ne(laneIdAxis, i32_val(0));
+  Value maskNotFirstThread = or_(maskNotFirstWarp, maskNotFirstLane);
   struct Accumulator {
     SmallVector<Value> acc;
     SmallVector<Value> maskedAcc;
@@ -212,22 +186,24 @@
       accumulator.maskedAcc = partialReduce;
       continue;
     }
-    accumulator.acc = accumulate(rewriter, helper.getCombineOp(),
-                                 accumulator.acc, partialReduce);
-    Value mask = icmp_slt(warpId, i32_val(i + 1));
+    Value mask = icmp_sge(warpId, i32_val(i + 1));
+    accumulator.acc =
+        accumulate(helper, rewriter, accumulator.acc, partialReduce, mask);
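+    // `mask` is now true for the warps that must fold in `partialReduce`: it
+    // predicates the combine region and picks which value the masked
+    // accumulator keeps below.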
     for (unsigned j = 0; j < helper.getNumOperands(); ++j) {
       accumulator.maskedAcc[j] =
-          select(mask, accumulator.maskedAcc[j], accumulator.acc[j]);
+          select(mask, accumulator.acc[j], accumulator.maskedAcc[j]);
     }
   }
-  auto temp = accumulate(rewriter, helper.getCombineOp(),
-                         accumulator.maskedAcc, srcValues[srcIndex]);
+
+  Value pred = axisBlockId == 0 ? maskNotFirstWarp : Value{};
+  auto temp = accumulate(helper, rewriter, accumulator.maskedAcc,
+                         srcValues[srcIndex], pred);
   if (axisBlockId == 0) {
     // For the first warp and first chunk we don't have anything to
     // accumulate.
     auto val = srcValues[srcIndex];
     for (unsigned i = 0; i < helper.getNumOperands(); ++i) {
-      temp[i] = select(maskFirstWarp, val[i], temp[i]);
+      temp[i] = select(maskNotFirstWarp, temp[i], val[i]);
     }
   }
   srcValues[srcIndex] = temp;
@@ -235,19 +211,18 @@
   SmallVector<Value> lastElement(helper.getNumOperands());
   for (unsigned i = 0; i < helper.getNumOperands(); ++i) {
     auto elem = targetInfo.shuffleUp(rewriter, loc, temp[i], threadStride);
-    lastElement[i] = select(maskFirstLane, accumulator.maskedAcc[i], elem);
+    lastElement[i] = select(maskNotFirstLane, elem, accumulator.maskedAcc[i]);
   }
   for (unsigned i = 1; i < scanElementsPerThreads; ++i) {
+    pred = axisBlockId == 0 ? maskNotFirstThread : Value{};
     auto laneValue = srcValues[srcIndex - i * elementStride];
-    laneValue =
-        accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue);
+    laneValue = accumulate(helper, rewriter, lastElement, laneValue, pred);
     if (axisBlockId == 0) {
       // For the first warp and first chunk we don't have anything to
       // accumulate.
       for (unsigned j = 0; j < helper.getNumOperands(); ++j) {
-        laneValue[j] =
-            select(maskFirstThread,
-                   srcValues[srcIndex - i * elementStride][j], laneValue[j]);
+        laneValue[j] = select(maskNotFirstThread, laneValue[j],
+                              srcValues[srcIndex - i * elementStride][j]);
       }
     }
     srcValues[srcIndex - i * elementStride] = laneValue;
@@ -300,8 +275,8 @@ static void AddPartialReduceOneWarp(SmallVector<SmallVector<Value>> &srcValues,
     if (axisBlockId == 0) // First chunk and first block
       accumulator = srcValues[srcIndex];
     else
-      srcValues[srcIndex] = accumulate(rewriter, helper.getCombineOp(),
-                                       accumulator, srcValues[srcIndex]);
+      srcValues[srcIndex] =
+          accumulate(helper, rewriter, accumulator, srcValues[srcIndex]);
     // Update the rest of the contiguous elements.
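+    // `lastElement` carries the running prefix that the remaining elements
+    // of this thread's chunk are combined with below.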
     auto lastElement = srcValues[srcIndex];
     if (scanDim > 1) {
@@ -319,8 +294,7 @@ static void AddPartialReduceOneWarp(SmallVector<SmallVector<Value>> &srcValues,
     }
     for (unsigned i = 1; i < scanElementsPerThreads; ++i) {
       auto laneValue = srcValues[srcIndex - i * elementStride];
-      laneValue =
-          accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue);
+      laneValue = accumulate(helper, rewriter, lastElement, laneValue);
       if (axisBlockId == 0) {
         for (unsigned j = 0; j < helper.getNumOperands(); ++j) {
           // For the first warp and first chunk we don't have anything to
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index d84788d9d653..039f7ac1ac4f 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -5644,3 +5644,52 @@ def check_loop_unroll_count(ir, opStr, loop_unroll_factor):
     for unroll_factor in [1, 2, 4, 5, 8]:
         h = _kernel[(1, )](torch.empty(1, device=device), unroll_factor)
         check_loop_unroll_count(h.asm["ttir"], 'tt.atomic_rmw', unroll_factor)
+
+
+@triton.jit
+def sanitize_add(a, b):
+    a64 = a.to(tl.int64)
+    b64 = b.to(tl.int64)
+    r64 = a64 + b64
+    tl.device_assert((r64 >= -2**31) & (r64 <= 2**31 - 1))
+    return a + b
+
+
+def test_side_effectful_reduction(device):
+    if device != "cuda":
+        pytest.skip()
+
+    @triton.jit(debug=True)
+    def sanitize_sum_kernel(Z, X, BLOCK: tl.constexpr):
+        vals = tl.load(X + tl.arange(0, BLOCK))
+        z = tl.reduce(vals, 0, sanitize_add)
+        tl.store(Z, z)
+
+    BLOCK = 512
+    torch.manual_seed(42)
+    X = torch.randint(0, 10, [BLOCK], device="cuda", dtype=torch.int32)
+    X[:300] = 32
+    X[300:] = 0
+    Z = torch.zeros((), device="cuda", dtype=torch.int32)
+    sanitize_sum_kernel[(1, )](Z, X, BLOCK=BLOCK)
+    torch.testing.assert_close(Z, X.sum().to(torch.int32))
+
+
+def test_side_effectful_scan(device):
+    if device != "cuda":
+        pytest.skip()
+
+    @triton.jit(debug=True)
+    def sanitize_cumsum_kernel(Z, X, BLOCK: tl.constexpr):
+        vals = tl.load(X + tl.arange(0, BLOCK))
+        z = tl.associative_scan(vals, 0, sanitize_add)
+        tl.store(Z + tl.arange(0, BLOCK), z)
+
+    BLOCK = 512
+    torch.manual_seed(42)
+    X = torch.randint(0, 10, [BLOCK], device="cuda", dtype=torch.int32)
+    X[:300] = 32
+    X[300:] = 0
+    Z = torch.zeros_like(X)
+    sanitize_cumsum_kernel[(1, )](Z, X, BLOCK=BLOCK)
+    torch.testing.assert_close(Z, X.cumsum(0).to(torch.int32))
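
A minimal standalone repro of the failure mode these tests guard against (a
hypothetical script, not part of the patch; `checked_add` and `sum_kernel` are
illustrative names, and a CUDA device is assumed). With `debug=True` the
frontend keeps `tl.device_assert` inside the combine function; before this
patch, lanes whose shared-memory load was predicated off could feed
uninitialized values into the combine region, so the assert could fire even
though the reduction result itself was correct:

import torch
import triton
import triton.language as tl

@triton.jit
def checked_add(a, b):
    # Combine function with a side effect: assert the int32 add cannot
    # overflow, mimicking what the overflow sanitizer injects.
    r = a.to(tl.int64) + b.to(tl.int64)
    tl.device_assert((r >= -2**31) & (r <= 2**31 - 1))
    return a + b

@triton.jit(debug=True)  # debug=True keeps device-side asserts enabled
def sum_kernel(Z, X, BLOCK: tl.constexpr):
    vals = tl.load(X + tl.arange(0, BLOCK))
    tl.store(Z, tl.reduce(vals, 0, checked_add))

x = torch.full((512, ), 32, device="cuda", dtype=torch.int32)
z = torch.zeros((), device="cuda", dtype=torch.int32)
sum_kernel[(1, )](z, x, BLOCK=512)
# Pre-patch: the assert could trip on garbage from inactive lanes; post-patch
# the combine region only runs where the predicate holds and the sum is exact.
assert z.item() == x.sum().item()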