diff --git a/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td b/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td index 1d981b5b7742..17251c2a4ebe 100644 --- a/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td +++ b/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td @@ -99,8 +99,7 @@ def TTI_ExperimentalGSanTensorDescInfoOp def TTI_ExperimentalGSanTensorAccessOp : TTI_Op<"experimental_gsan_tensor_access", - [MemoryEffects<[MemWrite, MemWrite]>, - TypesMatchWith<"mask type matches ptr type", "ptr", "mask", + [TypesMatchWith<"mask type matches ptr type", "ptr", "mask", "getI1SameShape(getPointeeType($_self))", "($_op.getOperands().size() <= 1) || std::equal_to<>()">]> { let summary = "Instrument a tensor load/store access for GSan"; @@ -108,13 +107,30 @@ def TTI_ExperimentalGSanTensorAccessOp Emits runtime instrumentation for a tensor pointer access. The pointer and optional mask are consumed by the GSan runtime. }]; - let arguments = (ins TT_PtrLike:$ptr, Optional:$mask, - BoolAttr:$isStore); + let arguments = (ins Arg, MemRead]>:$ptr, + Optional:$mask, BoolAttr:$isStore); let assemblyFormat = [{ $ptr `,` $isStore (`,` $mask^)? attr-dict `:` type($ptr) }]; } +def TTI_ExperimentalGSanAtomicTensorAccessOp + : TTI_Op<"experimental_gsan_atomic_tensor_access", + [TypesMatchWith<"mask type matches ptr type", "ptr", "mask", + "getI1SameShape(getPointeeType($_self))", + "($_op.getOperands().size() <= 1) || std::equal_to<>()">]> { + let summary = "Instrument a tensor atomic access for GSan"; + let description = [{ + Emits runtime instrumentation for a tensor pointer access whose individual + elements are atomic read-modify-write operations. + }]; + let arguments = (ins Arg, MemRead]>:$ptr, + Optional:$mask, TT_MemSemanticAttr:$sem, TT_MemSyncScopeAttr:$scope); + let assemblyFormat = [{ + $sem `,` $scope `,` $ptr (`,` $mask^)? attr-dict `:` type($ptr) + }]; +} + def TTI_ExperimentalGSanAtomicRMWOp : TTI_Op<"experimental_gsan_atomic_rmw", [ SameOperandsAndResultShape, SameOperandsAndResultEncoding, diff --git a/lib/Conversion/TritonInstrumentToLLVM/GSanToLLVM.cpp b/lib/Conversion/TritonInstrumentToLLVM/GSanToLLVM.cpp index de8834b11527..14e49d7f0db6 100644 --- a/lib/Conversion/TritonInstrumentToLLVM/GSanToLLVM.cpp +++ b/lib/Conversion/TritonInstrumentToLLVM/GSanToLLVM.cpp @@ -7,6 +7,7 @@ #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/TritonInstrument/IR/Dialect.h" #include "llvm/ADT/SmallString.h" +#include "llvm/Support/LogicalResult.h" #include namespace tt = mlir::triton; @@ -19,6 +20,7 @@ static constexpr unsigned kTensorMapStrideWordBase = 3; static constexpr unsigned kTensorMapShapeWordBase = 8; static constexpr unsigned kTensorMapScalarWordBase = 2; static constexpr unsigned kTensorMapNumQwords = 16; +static constexpr unsigned kGSanShadowGranularityBytes = 4; struct GSanSourceLocation { Value file; @@ -29,6 +31,8 @@ static constexpr StringLiteral kGSanLoadTensorRuntimeFn = "__triton_gsan_load_tensor"; static constexpr StringLiteral kGSanStoreTensorRuntimeFn = "__triton_gsan_store_tensor"; +static constexpr StringLiteral kGSanAtomicTensorRuntimeFn = + "__triton_gsan_atomic_tensor"; static constexpr StringLiteral kGSanAtomicBeginRuntimeFn = "__triton_gsan_atomic_begin_scalar"; static constexpr StringLiteral kGSanAtomicEndRuntimeFn = @@ -51,6 +55,9 @@ getOrCreateGSanRuntimeFunction(ConversionPatternRewriter &rewriter, } else if (funcName == kGSanLoadTensorRuntimeFn || funcName == kGSanStoreTensorRuntimeFn) { argTys = {ptr_ty(ctx), ptr_ty(ctx), i32_ty, i32_ty, ptr_ty(ctx), i32_ty}; + } else if (funcName == kGSanAtomicTensorRuntimeFn) { + argTys = {ptr_ty(ctx), ptr_ty(ctx), i32_ty, i32_ty, + i32_ty, i32_ty, ptr_ty(ctx), i32_ty}; } else if (funcName == kGSanAtomicBeginRuntimeFn) { argTys = {ptr_ty(ctx), ptr_ty(ctx), i32_ty, i64_ty, i32_ty, i32_ty, i32_ty, ptr_ty(ctx), i32_ty}; @@ -106,15 +113,10 @@ materializeSourceLocation(ConversionPatternRewriter &rewriter, Location loc) { // Utility functions //////////////////////////////////////////// -void emitTensorAccessRuntimeCall(ConversionPatternRewriter &rewriter, - Location loc, Value gsanGlobalStatePtr, - ArrayRef ptrElems, - ArrayRef maskElems, uint32_t regMask, - Value threadPred, int32_t bytesPerElem, - bool isStore, unsigned elemIndexStride = 1) { - if (ptrElems.empty()) - return; - +Value prepareTensorStackArg(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef ptrElems, ArrayRef maskElems, + uint32_t regMask, Value threadPred, + unsigned elemIndexStride) { auto *ctx = rewriter.getContext(); TritonLLVMOpBuilder b(loc, rewriter); Value one = b.i32_val(1); @@ -147,20 +149,57 @@ void emitTensorAccessRuntimeCall(ConversionPatternRewriter &rewriter, b.store(maskByte, maskSlot); } + return argsBuffer; +} + +void emitTensorAccessRuntimeCall(ConversionPatternRewriter &rewriter, + Location loc, Value gsanGlobalStatePtr, + ArrayRef ptrElems, + ArrayRef maskElems, uint32_t regMask, + Value threadPred, int32_t bytesPerElem, + bool isStore, unsigned elemIndexStride = 1) { + if (ptrElems.empty()) + return; + + TritonLLVMOpBuilder b(loc, rewriter); + auto stackPtr = prepareTensorStackArg(rewriter, loc, ptrElems, maskElems, + regMask, threadPred, elemIndexStride); StringRef funcName = isStore ? kGSanStoreTensorRuntimeFn : kGSanLoadTensorRuntimeFn; auto runtimeFunc = getOrCreateGSanRuntimeFunction(rewriter, funcName); - if (gsanGlobalStatePtr.getType() != ptr_ty(ctx)) { - gsanGlobalStatePtr = b.addrspacecast(ptr_ty(ctx), gsanGlobalStatePtr); - } - Value argsPtr = b.bitcast(argsBuffer, ptr_ty(ctx)); auto sourceLoc = materializeSourceLocation(rewriter, loc); b.call(runtimeFunc, - ValueRange{gsanGlobalStatePtr, argsPtr, b.i32_val(numElems), + ValueRange{gsanGlobalStatePtr, stackPtr, b.i32_val(ptrElems.size()), b.i32_val(bytesPerElem), sourceLoc.file, sourceLoc.line}); } +void emitAtomicTensorAccessRuntimeCall(ConversionPatternRewriter &rewriter, + Location loc, Value gsanGlobalStatePtr, + ArrayRef ptrElems, + ArrayRef maskElems, + uint32_t regMask, Value threadPred, + int32_t bytesPerElem, MemSemantic sem, + MemSyncScope scope) { + if (ptrElems.empty()) + return; + + TritonLLVMOpBuilder b(loc, rewriter); + auto stackPtr = + prepareTensorStackArg(rewriter, loc, ptrElems, maskElems, regMask, + threadPred, /*elemIndexStride=*/1); + auto runtimeFunc = + getOrCreateGSanRuntimeFunction(rewriter, kGSanAtomicTensorRuntimeFn); + auto sourceLoc = materializeSourceLocation(rewriter, loc); + + b.call(runtimeFunc, + ValueRange{gsanGlobalStatePtr, stackPtr, b.i32_val(ptrElems.size()), + b.i32_val(bytesPerElem), + b.i32_val(static_cast(sem)), + b.i32_val(static_cast(scope)), sourceLoc.file, + sourceLoc.line}); +} + void createBarrier(ConversionPatternRewriter &rewriter, Location loc, int numCTAs, const TargetInfoBase &targetInfo) { auto b = TritonLLVMOpBuilder(loc, rewriter); @@ -203,11 +242,8 @@ void emitGSanAtomicBeginCall(ConversionPatternRewriter &rewriter, Location loc, Value pred, Value ptr, int32_t bytesPerElem, int32_t sem, int32_t scope, GSanSourceLocation sourceLoc) { - auto *ctx = rewriter.getContext(); TritonLLVMOpBuilder b(loc, rewriter); - if (gsanGlobalStatePtr.getType() != ptr_ty(ctx)) - gsanGlobalStatePtr = b.addrspacecast(ptr_ty(ctx), gsanGlobalStatePtr); - Value statePtr = b.bitcast(eventStatePtr, ptr_ty(ctx)); + Value statePtr = b.bitcast(eventStatePtr, ptr_ty(rewriter.getContext())); auto runtimeFunc = getOrCreateGSanRuntimeFunction(rewriter, kGSanAtomicBeginRuntimeFn); b.call(runtimeFunc, @@ -232,6 +268,61 @@ void emitGSanAtomicEndCall(ConversionPatternRewriter &rewriter, Location loc, b.i32_val(scope), sourceLoc.file, sourceLoc.line}); } +template +unsigned getTensorAccessVecSize(OpT op, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + bool keepWithinSingleShadowCell) { + auto ptrTy = op.getPtr().getType(); + auto bytesPerElem = std::max(8u, tt::getPointeeBitWidth(ptrTy)) / 8; + auto contiguity = axisInfoAnalysis.getContiguity(op.getPtr()); + + if (keepWithinSingleShadowCell) { + if (bytesPerElem >= kGSanShadowGranularityBytes) + return 1; + contiguity = + std::min(contiguity, kGSanShadowGranularityBytes / bytesPerElem); + } + + if (!op.getMask()) + return contiguity; + + auto maskAlign = axisInfoAnalysis.getMaskAlignment(op.getMask()); + if (bytesPerElem < kGSanShadowGranularityBytes) { + maskAlign = std::max(maskAlign, kGSanShadowGranularityBytes / bytesPerElem); + } + return std::min(contiguity, maskAlign); +} + +void mergeTensorAccessElements(ConversionPatternRewriter &rewriter, + Location loc, SmallVector &ptrElems, + SmallVector &maskElems, unsigned mergeVec, + unsigned maskAlign, int32_t &bytesPerElem) { + if (mergeVec <= 1) + return; + + SmallVector mergedPtrElems; + SmallVector mergedMaskElems; + mergedPtrElems.reserve(ptrElems.size() / mergeVec); + if (!maskElems.empty()) + mergedMaskElems.reserve(ptrElems.size() / mergeVec); + + for (unsigned i = 0; i < ptrElems.size(); i += mergeVec) { + mergedPtrElems.push_back(ptrElems[i]); + if (maskElems.empty()) + continue; + Value mergedMask = maskElems[i]; + for (unsigned j = maskAlign; j < mergeVec; j += maskAlign) { + mergedMask = + arith::OrIOp::create(rewriter, loc, mergedMask, maskElems[i + j]); + } + mergedMaskElems.push_back(mergedMask); + } + + ptrElems = std::move(mergedPtrElems); + maskElems = std::move(mergedMaskElems); + bytesPerElem *= mergeVec; +} + Value bitcastToScalarInt(ConversionPatternRewriter &rewriter, Location loc, Value value) { Type ty = value.getType(); @@ -243,12 +334,21 @@ Value bitcastToScalarInt(ConversionPatternRewriter &rewriter, Location loc, return b.bitcast(value, intTy); } -Value getGSanGlobalStateArg(FunctionOpInterface funcOp) { +FailureOr getGSanGlobalStateArg(Operation *op, + ConversionPatternRewriter &rewriter, + Location loc) { + auto funcOp = op->getParentOfType(); for (unsigned i = 0; i < funcOp.getNumArguments(); ++i) { - if (funcOp.getArgAttr(i, kGSanGlobalStateArgAttr)) - return funcOp.getArgument(i); + if (!funcOp.getArgAttr(i, kGSanGlobalStateArgAttr)) + continue; + Value arg = funcOp.getArgument(i); + if (arg.getType() == ptr_ty(rewriter.getContext())) + return arg; + TritonLLVMOpBuilder b(loc, rewriter); + arg = b.addrspacecast(ptr_ty(rewriter.getContext()), arg); + return arg; } - return {}; + return emitError(loc, "Unable to find gsan global state"); } static LLVM::LLVMStructType @@ -326,81 +426,99 @@ struct GSanTensorAccessOpConversion axisInfoAnalysis(&axisInfoAnalysis) {} unsigned getVecSize(tti::ExperimentalGSanTensorAccessOp op) const { - auto ptrTy = op.getPtr().getType(); - auto contiguity = axisInfoAnalysis->getContiguity(op.getPtr()); - if (!op.getMask()) { - return contiguity; - } - - auto maskAlign = axisInfoAnalysis->getMaskAlignment(op.getMask()); - auto bytesPerElem = tt::getPointeeBitWidth(ptrTy) / 8; - // Round up to at least shadow granularity, and we will or together the - // masks - if (bytesPerElem < 4) { - maskAlign = std::max(maskAlign, 4 / bytesPerElem); - } - return std::min(contiguity, maskAlign); + return getTensorAccessVecSize(op, *axisInfoAnalysis, + /*keepWithinSingleShadowCell=*/false); } LogicalResult matchAndRewrite(tti::ExperimentalGSanTensorAccessOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto ctx = getContext(); Location loc = op.getLoc(); - auto func = op->getParentOfType(); - Value gsanGlobalStatePtr = getGSanGlobalStateArg(func); - if (!gsanGlobalStatePtr) - return emitError(op.getLoc(), "Failed to find pointer to gsan state"); - - Value llPtr = adaptor.getPtr(); auto ptrTy = op.getPtr().getType(); - unsigned numElems = ttg::getTotalElemsPerThread(ptrTy); - auto bytesPerElem = tt::getPointeeBitWidth(ptrTy) / 8; - SmallVector ptrElems = unpackLLElements(loc, llPtr, rewriter); - assert(ptrElems.size() == numElems && - "Expected pointer element count to match layout"); - + int32_t bytesPerElem = tt::getPointeeBitWidth(ptrTy) / 8; + auto ptrElems = unpackLLElements(loc, adaptor.getPtr(), rewriter); SmallVector maskElems; if (Value llMask = adaptor.getMask()) { maskElems = unpackLLElements(loc, llMask, rewriter); - assert(maskElems.size() == numElems && - "Expected mask element count to match layout"); } unsigned mergeVec = getVecSize(op); - if (mergeVec > 1) { - auto maskAlign = - op.getMask() ? axisInfoAnalysis->getMaskAlignment(op.getMask()) : 1; - SmallVector mergedPtrElems; - SmallVector mergedMaskElems; - mergedPtrElems.reserve(numElems / mergeVec); - if (!maskElems.empty()) - mergedMaskElems.reserve(numElems / mergeVec); - - for (unsigned i = 0; i < numElems; i += mergeVec) { - mergedPtrElems.push_back(ptrElems[i]); - if (maskElems.empty()) - continue; - Value mergedMask = maskElems[i]; - for (unsigned j = maskAlign; j < mergeVec; j += maskAlign) { - mergedMask = - arith::OrIOp::create(rewriter, loc, mergedMask, maskElems[i + j]); - } - mergedMaskElems.push_back(mergedMask); - } + auto maskAlign = + op.getMask() ? axisInfoAnalysis->getMaskAlignment(op.getMask()) : 1; + mergeTensorAccessElements(rewriter, loc, ptrElems, maskElems, mergeVec, + maskAlign, bytesPerElem); + + auto ctx = op.getContext(); + auto kReg = str_attr("reg"); + auto freeVarMasks = getFreeVariableMasks(ptrTy); + auto threadPred = ttg::emitRedundantThreadPredicate(freeVarMasks, rewriter, + loc, *targetInfo); + auto gsanGlobalStatePtr = getGSanGlobalStateArg(op, rewriter, loc); + if (failed(gsanGlobalStatePtr)) + return failure(); + emitTensorAccessRuntimeCall(rewriter, loc, *gsanGlobalStatePtr, ptrElems, + maskElems, freeVarMasks.lookup(kReg), + threadPred, bytesPerElem, op.getIsStore(), + mergeVec); + + rewriter.eraseOp(op); + return success(); + } +}; - ptrElems = std::move(mergedPtrElems); - maskElems = std::move(mergedMaskElems); - bytesPerElem *= mergeVec; +struct GSanAtomicTensorAccessOpConversion + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern< + tti::ExperimentalGSanAtomicTensorAccessOp>::ConvertOpToLLVMPattern; + const TargetInfoBase *targetInfo; + ModuleAxisInfoAnalysis *axisInfoAnalysis; + + GSanAtomicTensorAccessOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(&targetInfo), + axisInfoAnalysis(&axisInfoAnalysis) {} + + unsigned getVecSize(tti::ExperimentalGSanAtomicTensorAccessOp op) const { + // GSan tracks conflicts at shadow-cell granularity, so atomics may only be + // coalesced while they still fit inside a single shadow cell. + return getTensorAccessVecSize(op, *axisInfoAnalysis, + /*keepWithinSingleShadowCell=*/true); + } + + LogicalResult + matchAndRewrite(tti::ExperimentalGSanAtomicTensorAccessOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto ptrTy = op.getPtr().getType(); + auto ptrElems = unpackLLElements(loc, adaptor.getPtr(), rewriter); + SmallVector maskElems; + if (Value llMask = adaptor.getMask()) { + maskElems = unpackLLElements(loc, llMask, rewriter); } + int32_t bytesPerElem = std::max(1u, tt::getPointeeBitWidth(ptrTy) / 8); + unsigned mergeVec = getVecSize(op); + auto maskAlign = + op.getMask() ? axisInfoAnalysis->getMaskAlignment(op.getMask()) : 1; + mergeTensorAccessElements(rewriter, loc, ptrElems, maskElems, mergeVec, + maskAlign, bytesPerElem); + + auto ctx = op.getContext(); + auto kReg = str_attr("reg"); auto freeVarMasks = getFreeVariableMasks(ptrTy); - uint32_t regMask = freeVarMasks.lookup(str_attr("reg")); - Value threadPred = ttg::emitRedundantThreadPredicate(freeVarMasks, rewriter, - loc, *targetInfo); - emitTensorAccessRuntimeCall(rewriter, loc, gsanGlobalStatePtr, ptrElems, - maskElems, regMask, threadPred, bytesPerElem, - op.getIsStore(), mergeVec); + auto regMask = freeVarMasks.lookup(kReg); + auto threadPred = ttg::emitRedundantThreadPredicate(freeVarMasks, rewriter, + loc, *targetInfo); + auto gsanGlobalStatePtr = getGSanGlobalStateArg(op, rewriter, loc); + if (failed(gsanGlobalStatePtr)) + return failure(); + emitAtomicTensorAccessRuntimeCall(rewriter, loc, *gsanGlobalStatePtr, + ptrElems, maskElems, regMask, threadPred, + bytesPerElem, op.getSem(), op.getScope()); rewriter.eraseOp(op); return success(); @@ -425,10 +543,9 @@ struct GSanAtomicRMWOpConversion ConversionPatternRewriter &rewriter) const override { auto *ctx = rewriter.getContext(); Location loc = op.getLoc(); - auto func = op->getParentOfType(); - Value gsanGlobalStatePtr = getGSanGlobalStateArg(func); - if (!gsanGlobalStatePtr) - return emitError(op.getLoc(), "Failed to find pointer to gsan state"); + auto gsanGlobalStatePtr = getGSanGlobalStateArg(op, rewriter, loc); + if (failed(gsanGlobalStatePtr)) + return failure(); auto moduleOp = op->getParentOfType(); assert(moduleOp && "Parent ModuleOp not found for atomic op"); @@ -478,7 +595,7 @@ struct GSanAtomicRMWOpConversion Value rmwPtr = ptrElements[i]; Value rmwVal = valElements[i]; - emitGSanAtomicBeginCall(rewriter, loc, gsanGlobalStatePtr, eventState, + emitGSanAtomicBeginCall(rewriter, loc, *gsanGlobalStatePtr, eventState, pred, rmwPtr, bytesPerElem, static_cast(sem), static_cast(scope), sourceLoc); @@ -531,10 +648,9 @@ struct GSanAtomicCASOpConversion ConversionPatternRewriter &rewriter) const override { auto *ctx = rewriter.getContext(); Location loc = op.getLoc(); - auto func = op->getParentOfType(); - Value gsanGlobalStatePtr = getGSanGlobalStateArg(func); - if (!gsanGlobalStatePtr) - return emitError(op.getLoc(), "Failed to find pointer to gsan state"); + auto gsanGlobalStatePtr = getGSanGlobalStateArg(op, rewriter, loc); + if (failed(gsanGlobalStatePtr)) + return failure(); auto moduleOp = op->getParentOfType(); assert(moduleOp && "Parent ModuleOp not found for atomic op"); @@ -580,7 +696,7 @@ struct GSanAtomicCASOpConversion Value casCmp = cmpElements[i]; Value casVal = valElements[i]; - emitGSanAtomicBeginCall(rewriter, loc, gsanGlobalStatePtr, eventState, + emitGSanAtomicBeginCall(rewriter, loc, *gsanGlobalStatePtr, eventState, pred, casPtr, bytesPerElem, static_cast(sem), static_cast(scope), sourceLoc); @@ -671,23 +787,18 @@ struct GSanInitOpConversion matchAndRewrite(tti::ExperimentalGSanInitOp op, [[maybe_unused]] OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto func = op->getParentOfType(); - Value gsanGlobalStatePtr = getGSanGlobalStateArg(func); - if (!gsanGlobalStatePtr) - return emitError(op.getLoc(), "Failed to find pointer to gsan state"); + auto loc = op.getLoc(); + auto gsanGlobalStatePtr = getGSanGlobalStateArg(op, rewriter, loc); + if (failed(gsanGlobalStatePtr)) + return failure(); auto runtimeFunc = getOrCreateGSanRuntimeFunction(rewriter, kGSanInitRuntimeFn); - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); TritonLLVMOpBuilder b(loc, rewriter); - if (gsanGlobalStatePtr.getType() != ptr_ty(ctx)) { - gsanGlobalStatePtr = b.addrspacecast(ptr_ty(ctx), gsanGlobalStatePtr); - } auto sourceLoc = materializeSourceLocation(rewriter, loc); b.call(runtimeFunc, - ValueRange{gsanGlobalStatePtr, sourceLoc.file, sourceLoc.line}); + ValueRange{*gsanGlobalStatePtr, sourceLoc.file, sourceLoc.line}); b.barrier(ttg::AddrSpace::Local); rewriter.eraseOp(op); return success(); @@ -704,6 +815,8 @@ void mlir::triton::populateGSanToLLVMPatterns( patterns.add(typeConverter); patterns.add(typeConverter, targetInfo); patterns.add(typeConverter, targetInfo); + patterns.add( + typeConverter, axisInfoAnalysis, targetInfo); patterns.add(typeConverter, axisInfoAnalysis, targetInfo); } diff --git a/lib/Dialect/TritonInstrument/Transforms/GlobalSanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/GlobalSanitizer.cpp index 6dbbd5cbf35f..c1723a2c7378 100644 --- a/lib/Dialect/TritonInstrument/Transforms/GlobalSanitizer.cpp +++ b/lib/Dialect/TritonInstrument/Transforms/GlobalSanitizer.cpp @@ -315,6 +315,19 @@ static void instrumentAsyncTMAStore(Operation *op, Value descValue, access.second, /*isStore=*/true); } +static void instrumentAsyncTMAReduce(ttng::AsyncTMAReduceOp op) { + OpBuilder builder(op); + auto desc = getDescriptorInfo(op.getDesc(), builder); + + auto offsets = castToI64(builder, op.getLoc(), op.getCoord()); + auto access = createTiledAccess(builder, op.getLoc(), desc, + op.getSrc().getType().getShape(), offsets, + std::nullopt); + ExperimentalGSanAtomicTensorAccessOp::create( + builder, op.getLoc(), access.first, access.second, MemSemantic::RELAXED, + MemSyncScope::GPU); +} + static void instrumentAsyncTMAGather(ttng::AsyncTMAGatherOp op) { OpBuilder builder(op); auto desc = getDescriptorInfo(op.getDesc(), builder); @@ -427,12 +440,8 @@ class GlobalSanitizerPass op.getSrc().getType().getShape(), op.getCoord()); }) - .Case([&](ttng::AsyncTMAReduceOp op) { - // FIXME: This is just plain wrong. TMA reduce is atomic. - instrumentAsyncTMAStore(op, op.getDesc(), - op.getSrc().getType().getShape(), - op.getCoord()); - }) + .Case( + [&](ttng::AsyncTMAReduceOp op) { instrumentAsyncTMAReduce(op); }) .Case([&](ttng::AsyncTMAScatterOp op) { instrumentAsyncTMAScatter(op); }) diff --git a/python/test/gsan/test_gsan.py b/python/test/gsan/test_gsan.py index 9b143491dc99..63eea44e2be2 100644 --- a/python/test/gsan/test_gsan.py +++ b/python/test/gsan/test_gsan.py @@ -391,6 +391,15 @@ def _host_tma_scatter_kernel(desc, x_offsets_ptr, y_offset, src_ptr, src_stride_ desc.scatter(src, x_offsets, y_offset) +@triton.jit +def _host_tma_reduce_add_kernel(desc, src_ptr, src_stride_0, src_stride_1, BLOCK_X: tl.constexpr): + BLOCK_Y: tl.constexpr = desc.block_shape[1] + indices_x = tl.arange(0, BLOCK_X)[:, None] * src_stride_0 + indices_y = tl.arange(0, BLOCK_Y)[None, :] * src_stride_1 + src = tl.load(src_ptr + indices_x + indices_y) + desc.atomic_add([0, 0], src) + + def _shadow_cell_state(cell) -> tuple[int, object, tuple[object, ...]]: return (cell.num_reads, cell.write_clock, tuple(cell.read_clocks)) @@ -586,3 +595,19 @@ def test_host_tma_scatter_updates_shadow(with_gsan): shadow1 = _shadow_cells_for_tensor(target_storage) _assert_shadow_mask(shadow0, shadow1, changed_mask, access_kind="write") assert target_storage[m_size, y_offset].item() == 0 + + +@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires Hopper or newer") +def test_host_tma_reduce_updates_atomic_shadow(with_gsan): + block_x = 1 + block_y = 16 + target = torch.zeros((block_x, block_y), dtype=torch.int32, device="cuda") + src = torch.arange(1, block_y + 1, dtype=torch.int32, device="cuda").reshape(block_x, block_y) + target_desc = TensorDescriptor.from_tensor(target, [block_x, block_y]) + + compiled = _host_tma_reduce_add_kernel[(1, )](target_desc, src, src.stride(0), src.stride(1), BLOCK_X=block_x) + assert "ttng.async_tma_reduce" in compiled.asm["ttgir"] + torch.cuda.synchronize() + + torch.testing.assert_close(target, src) + _assert_atomic_rmw_shadow(target[0, 0].data_ptr(), AtomicScope.GPU, is_release=False) diff --git a/python/test/gsan/test_gsan_failures.py b/python/test/gsan/test_gsan_failures.py index e40e389f5425..eb5420a6b9cc 100644 --- a/python/test/gsan/test_gsan_failures.py +++ b/python/test/gsan/test_gsan_failures.py @@ -175,6 +175,25 @@ def _host_tma_scatter_war_kernel(target_ptr, target_desc, x_offsets_ptr, src_ptr target_desc.scatter(values, x_offsets, y_offset) +@triton.jit +def _host_tma_atomic_flag_publish_kernel(payload_ptr, flag_ptr, flag_desc, counter_ptr, scratch_ptr): + pid = tl.program_id(0) + if pid == 0: + tl.store(payload_ptr, 1000) + tl.atomic_xchg(flag_ptr, 1, sem="release", scope="gpu") + tl.atomic_add(counter_ptr, 1, sem="relaxed") + else: + atomic_poll(counter_ptr, 1) + BLOCK_X: tl.constexpr = flag_desc.block_shape[0] + BLOCK_Y: tl.constexpr = flag_desc.block_shape[1] + values = tl.full((BLOCK_X, BLOCK_Y), 1, dtype=tl.int32) + # TMA atomics on the released flag are relaxed.gpu and must not acquire + # the producer's prior payload store. + flag_desc.atomic_add([0, 0], values) + result = tl.load(payload_ptr) + tl.store(scratch_ptr, result) + + def _cuda_byte_allocator(size: int, _align: int, _stream): return torch.empty(size, dtype=torch.int8, device="cuda") @@ -296,6 +315,16 @@ def _run_host_tma_scatter_war_case() -> None: counter, row_idx, y_offset, target.stride(0), BLOCK_X=block_x) +@run_with_gsan +def _run_host_tma_atomic_flag_publish_case() -> None: + flag = torch.zeros((1, 16), dtype=torch.int32, device="cuda") + flag_desc = TensorDescriptor.from_tensor(flag, [1, 16]) + payload = torch.zeros(1, dtype=torch.int32, device="cuda") + counter = torch.zeros(1, dtype=torch.int32, device="cuda") + scratch = torch.full((1, ), -1, dtype=torch.int32, device="cuda") + _host_tma_atomic_flag_publish_kernel[(2, )](payload, flag, flag_desc, counter, scratch, num_warps=1) + + @run_with_gsan def _run_cross_sm_atomic_sync_case(producer_sem: str, consumer_sem: str, scope: str) -> None: payload = torch.zeros(1, dtype=torch.int32, device="cuda") @@ -403,6 +432,13 @@ def test_host_tma_scatter_write_after_read(): marker="target_desc.scatter(values, x_offsets, y_offset)", error="Write after read race detected") +@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires Hopper or newer") +def test_host_tma_atomic_on_release_flag_does_not_publish_data(): + _run_failure_case("host_tma_atomic_flag_publish", runner=_run_host_tma_atomic_flag_publish_case, + source_function=_host_tma_atomic_flag_publish_kernel.fn, marker="result = tl.load(payload_ptr)", + error="Read after write race detected") + + @pytest.mark.parametrize("producer_sem, consumer_sem, scope", CROSS_SM_SEMANTIC_MISMATCH_CASES) def test_cross_sm_semantic_mismatch_read_after_write(producer_sem, consumer_sem, scope): _run_failure_case(f"cross_sm_semantic_mismatch_{producer_sem}_{consumer_sem}_{scope}", diff --git a/python/triton/experimental/gsan/src/GSanLibrary.cu b/python/triton/experimental/gsan/src/GSanLibrary.cu index 52e2f422cd82..bedd288822a6 100644 --- a/python/triton/experimental/gsan/src/GSanLibrary.cu +++ b/python/triton/experimental/gsan/src/GSanLibrary.cu @@ -656,6 +656,26 @@ __triton_gsan_store_tensor(void *globalState, const char *stackPtr, gsan::tensorStore(threadState, stackPtr, numElems, bytesPerElem, loc); } +extern "C" GSAN_DEVICE void +__triton_gsan_atomic_tensor(void *globalState, const char *stackPtr, + int numElems, int bytesPerElem, int sem, int scope, + const char *file, unsigned line) { + auto loc = gsan::Location{file, line}; + auto *globals = reinterpret_cast(globalState); + const auto *ptrsPtr = reinterpret_cast(stackPtr); + const auto *maskPtr = stackPtr + numElems * sizeof(gsan::uintptr_t); + + for (int i = 0; i < numElems; ++i) { + if (!maskPtr[i]) + continue; + gsan::AtomicEventState event; + gsan::beginAtomicAccess(globals, &event, /*pred=*/true, ptrsPtr[i], + bytesPerElem, sem, scope, loc); + gsan::endAtomicAccess(&event, /*pred=*/true, /*didWrite=*/true, sem, scope, + loc); + } +} + extern "C" GSAN_DEVICE void __triton_gsan_atomic_begin_scalar( void *globalState, void *eventState, int pred, gsan::uintptr_t address, int bytesPerElem, int sem, int scope, const char *file, unsigned line) {