diff --git a/include/triton/Dialect/TritonInstrument/IR/CMakeLists.txt b/include/triton/Dialect/TritonInstrument/IR/CMakeLists.txt index ffd4ec1dcd02..981af117551e 100644 --- a/include/triton/Dialect/TritonInstrument/IR/CMakeLists.txt +++ b/include/triton/Dialect/TritonInstrument/IR/CMakeLists.txt @@ -5,11 +5,13 @@ mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=tti) mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=tti) add_mlir_doc(TritonInstrumentDialect TritonInstrumentDialect dialects/ -gen-dialect-doc -dialect=tti) +set(LLVM_TARGET_DEFINITIONS TritonInstrumentAttrDefs.td) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) + set(LLVM_TARGET_DEFINITIONS TritonInstrumentOps.td) mlir_tablegen(Ops.h.inc -gen-op-decls) mlir_tablegen(Ops.cpp.inc -gen-op-defs) -mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) -mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) add_mlir_doc(TritonInstrumentOps TritonInstrumentOps dialects/ -gen-op-doc -dialect=tti) add_public_tablegen_target(TritonInstrumentTableGen) diff --git a/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td b/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td index 92d91b52071e..1d981b5b7742 100644 --- a/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td +++ b/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td @@ -4,6 +4,8 @@ include "triton/Dialect/TritonInstrument/IR/TritonInstrumentDialect.td" include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td" include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "triton/Dialect/TritonInstrument/IR/TritonInstrumentAttrDefs.td" @@ -56,11 +58,6 @@ def TTI_ExperimentalMemDescToI32Op : TTI_Op<"experimental_memdesc_to_i32", [Pure }]; let arguments = (ins TTG_MemDescType:$memdesc); let results = (outs I32:$result); - let builders = [ - OpBuilder<(ins "Value":$memdesc), [{ - build($_builder, $_state, $_builder.getI32Type(), memdesc); - }]> - ]; let assemblyFormat = "$memdesc attr-dict `:` type($memdesc)"; } @@ -118,6 +115,55 @@ def TTI_ExperimentalGSanTensorAccessOp }]; } +def TTI_ExperimentalGSanAtomicRMWOp : TTI_Op<"experimental_gsan_atomic_rmw", [ + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + TypesMatchWith<"ptr type matches value type", "val", "ptr", + "getPointerTypeSameShape($_self)">, + TypesMatchWith<"mask type matches value type", + "val", "mask", "getI1SameShape($_self)", + "($_op.getOperands().size() <= 2) || std::equal_to<>()"> +]> { + let summary = "Lower a GSan-instrumented atomic rmw"; + let arguments = (ins + TT_AtomicRMWAttr:$atomic_rmw_op, + Arg, MemWrite]>:$ptr, + TT_Type:$val, + Optional:$mask, + TT_MemSemanticAttr:$sem, + TT_MemSyncScopeAttr:$scope + ); + let results = (outs TT_Type:$result); + let assemblyFormat = [{ + $atomic_rmw_op `,` $sem `,` $scope `,` $ptr `,` $val (`,` $mask^)? attr-dict `:` + functional-type(operands, $result) + }]; +} + +def TTI_ExperimentalGSanAtomicCASOp : TTI_Op<"experimental_gsan_atomic_cas", [ + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + TypesMatchWith<"ptr type matches cmp type", "cmp", "ptr", + "getPointerTypeSameShape($_self)">, + TypesMatchWith<"ptr type matches value type", "val", "ptr", + "getPointerTypeSameShape($_self)"> +]> { + let summary = "Lower a GSan-instrumented atomic cas"; + let arguments = (ins + Arg, MemWrite]>:$ptr, + TT_Type:$cmp, + TT_Type:$val, + TT_MemSemanticAttr:$sem, + TT_MemSyncScopeAttr:$scope + ); + let results = (outs TT_Type:$result); + let assemblyFormat = [{ + $sem `,` $scope `,` $ptr `,` $cmp `,` $val attr-dict `:` + functional-type(operands, $result) + }]; +} + + // ===== Critical section lock ops ===== diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 347cedbe8e8a..9dea1c928d46 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -10,6 +10,7 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonInstrument/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Tools/GenericSwizzling.h" #include "triton/Tools/LayoutUtils.h" @@ -30,6 +31,7 @@ static size_t getPartitionIndex(size_t offset, size_t partitionSize) { } namespace ttng = mlir::triton::nvidia_gpu; +namespace tti = mlir::triton::instrument; namespace mlir { @@ -115,7 +117,8 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) { auto elems = getNumScratchElemsSwizzledCvt(srcTy, dstTy); return elems * getBitwidth(srcTy) / 8; } - if (isa(op)) { + if (isa(op)) { auto value = op->getOperand(0); auto smemShape = getRepShapeForAtomic(op->getResult(0)); auto elems = getNumScratchElements(smemShape); diff --git a/lib/Conversion/TritonInstrumentToLLVM/GSanToLLVM.cpp b/lib/Conversion/TritonInstrumentToLLVM/GSanToLLVM.cpp index c77b0fd016fe..21d9c77f1980 100644 --- a/lib/Conversion/TritonInstrumentToLLVM/GSanToLLVM.cpp +++ b/lib/Conversion/TritonInstrumentToLLVM/GSanToLLVM.cpp @@ -1,5 +1,7 @@ #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/TypeUtilities.h" +#include "third_party/nvidia/include/TritonNVIDIAGPUToLLVM/AtomicPTXBuilder.h" #include "third_party/nvidia/include/TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" @@ -27,6 +29,10 @@ static constexpr StringLiteral kGSanLoadTensorRuntimeFn = "__triton_gsan_load_tensor"; static constexpr StringLiteral kGSanStoreTensorRuntimeFn = "__triton_gsan_store_tensor"; +static constexpr StringLiteral kGSanAtomicBeginRuntimeFn = + "__triton_gsan_atomic_begin_scalar"; +static constexpr StringLiteral kGSanAtomicEndRuntimeFn = + "__triton_gsan_atomic_end_scalar"; static constexpr StringLiteral kGSanInitRuntimeFn = "__triton_gsan_init"; static constexpr StringLiteral kGSanGlobalStateArgAttr = "tti.gsan_global_state"; @@ -42,8 +48,16 @@ getOrCreateGSanRuntimeFunction(ConversionPatternRewriter &rewriter, SmallVector argTys; if (funcName == kGSanInitRuntimeFn) { argTys = {ptr_ty(ctx), ptr_ty(ctx), i32_ty}; - } else { + } else if (funcName == kGSanLoadTensorRuntimeFn || + funcName == kGSanStoreTensorRuntimeFn) { argTys = {ptr_ty(ctx), ptr_ty(ctx), i32_ty, i32_ty, ptr_ty(ctx), i32_ty}; + } else if (funcName == kGSanAtomicBeginRuntimeFn) { + argTys = {ptr_ty(ctx), ptr_ty(ctx), i32_ty, i64_ty, i32_ty, + i32_ty, i32_ty, ptr_ty(ctx), i32_ty}; + } else if (funcName == kGSanAtomicEndRuntimeFn) { + argTys = {ptr_ty(ctx), i32_ty, i32_ty, i32_ty, i32_ty, ptr_ty(ctx), i32_ty}; + } else { + llvm_unreachable("unexpected GSan runtime symbol"); } auto funcTy = LLVM::LLVMFunctionType::get(void_ty(ctx), argTys); RewriterBase::InsertionGuard guard(rewriter); @@ -52,6 +66,13 @@ getOrCreateGSanRuntimeFunction(ConversionPatternRewriter &rewriter, funcTy); } +LLVM::LLVMStructType +getGSanAtomicEventStateType(ConversionPatternRewriter &rewriter) { + auto *ctx = rewriter.getContext(); + return LLVM::LLVMStructType::getLiteral( + ctx, {ptr_ty(ctx), array_ty(ptr_ty(ctx), 3), i8_ty}); +} + FileLineColLoc extractSourceLocation(Location loc) { if (auto fileLoc = dyn_cast(loc)) return fileLoc; @@ -140,6 +161,88 @@ void emitTensorAccessRuntimeCall(ConversionPatternRewriter &rewriter, b.i32_val(bytesPerElem), sourceLoc.file, sourceLoc.line}); } +void createBarrier(ConversionPatternRewriter &rewriter, Location loc, + int numCTAs, const TargetInfoBase &targetInfo) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + if (numCTAs == 1) { + b.barrier(ttg::AddrSpace::Local); + } else { + targetInfo.clusterBarrier(loc, rewriter); + } +} + +unsigned getCanonicalIndex(unsigned index, unsigned freeVarMask) { + return index & ~freeVarMask; +} + +Value broadcastScalarAtomicResult(Operation *op, Type valueElemTy, + Value resultVal, + ConversionPatternRewriter &rewriter, + TritonLLVMOpBuilder &b, Value threadPred, + const TargetInfoBase &targetInfo) { + if (!op->hasAttr("allocation.offset")) + return resultVal; + + auto loc = op->getLoc(); + Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op); + targetInfo.storeShared(rewriter, loc, smemBase, resultVal, threadPred); + b.barrier(ttg::AddrSpace::Local); + return targetInfo.loadShared(rewriter, loc, smemBase, valueElemTy, + b.true_val()); +} + +Value materializeI32Bool(ConversionPatternRewriter &rewriter, + TritonLLVMOpBuilder &b, Value pred) { + if (!pred) + return b.i32_val(1); + return b.zext(i32_ty, pred); +} + +void emitGSanAtomicBeginCall(ConversionPatternRewriter &rewriter, Location loc, + Value gsanGlobalStatePtr, Value eventStatePtr, + Value pred, Value ptr, int32_t bytesPerElem, + int32_t sem, int32_t scope, + GSanSourceLocation sourceLoc) { + auto *ctx = rewriter.getContext(); + TritonLLVMOpBuilder b(loc, rewriter); + if (gsanGlobalStatePtr.getType() != ptr_ty(ctx)) + gsanGlobalStatePtr = b.addrspacecast(ptr_ty(ctx), gsanGlobalStatePtr); + Value statePtr = b.bitcast(eventStatePtr, ptr_ty(ctx)); + auto runtimeFunc = + getOrCreateGSanRuntimeFunction(rewriter, kGSanAtomicBeginRuntimeFn); + b.call(runtimeFunc, + ValueRange{gsanGlobalStatePtr, statePtr, + materializeI32Bool(rewriter, b, pred), + b.ptrtoint(i64_ty, ptr), b.i32_val(bytesPerElem), + b.i32_val(sem), b.i32_val(scope), sourceLoc.file, + sourceLoc.line}); +} + +void emitGSanAtomicEndCall(ConversionPatternRewriter &rewriter, Location loc, + Value eventStatePtr, Value pred, Value didWrite, + int32_t sem, int32_t scope, + GSanSourceLocation sourceLoc) { + TritonLLVMOpBuilder b(loc, rewriter); + auto runtimeFunc = + getOrCreateGSanRuntimeFunction(rewriter, kGSanAtomicEndRuntimeFn); + Value statePtr = b.bitcast(eventStatePtr, ptr_ty(rewriter.getContext())); + b.call(runtimeFunc, + ValueRange{statePtr, materializeI32Bool(rewriter, b, pred), + materializeI32Bool(rewriter, b, didWrite), b.i32_val(sem), + b.i32_val(scope), sourceLoc.file, sourceLoc.line}); +} + +Value bitcastToScalarInt(ConversionPatternRewriter &rewriter, Location loc, + Value value) { + Type ty = value.getType(); + if (ty.isInteger()) + return value; + auto intTy = + IntegerType::get(rewriter.getContext(), ty.getIntOrFloatBitWidth()); + TritonLLVMOpBuilder b(loc, rewriter); + return b.bitcast(value, intTy); +} + Value getGSanGlobalStateArg(FunctionOpInterface funcOp) { for (unsigned i = 0; i < funcOp.getNumArguments(); ++i) { if (funcOp.getArgAttr(i, kGSanGlobalStateArgAttr)) @@ -304,6 +407,216 @@ struct GSanTensorAccessOpConversion } }; +struct GSanAtomicRMWOpConversion + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern< + tti::ExperimentalGSanAtomicRMWOp>::ConvertOpToLLVMPattern; + const TargetInfoBase *targetInfo; + + GSanAtomicRMWOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(&targetInfo) {} + + LogicalResult + matchAndRewrite(tti::ExperimentalGSanAtomicRMWOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *ctx = rewriter.getContext(); + Location loc = op.getLoc(); + auto func = op->getParentOfType(); + Value gsanGlobalStatePtr = getGSanGlobalStateArg(func); + if (!gsanGlobalStatePtr) + return emitError(op.getLoc(), "Failed to find pointer to gsan state"); + + auto moduleOp = op->getParentOfType(); + assert(moduleOp && "Parent ModuleOp not found for atomic op"); + auto rmwOp = op.getAtomicRmwOp(); + auto sem = op.getSem(); + auto scope = op.getScope(); + + TritonLLVMOpBuilder b(loc, rewriter); + Value llPtr = adaptor.getPtr(); + Value llVal = adaptor.getVal(); + Value llMask = adaptor.getMask(); + + auto ptrElements = unpackLLElements(loc, llPtr, rewriter); + auto valElements = unpackLLElements(loc, llVal, rewriter); + SmallVector maskElements; + if (llMask) + maskElements = unpackLLElements(loc, llMask, rewriter); + + auto valueTy = op.getType(); + auto tensorTy = dyn_cast(valueTy); + Type valueElemTy = valElements[0].getType(); + unsigned valueElemNBits = valueElemTy.getIntOrFloatBitWidth(); + int32_t bytesPerElem = std::max(1, valueElemNBits / 8); + auto elemsPerThread = ttg::getTotalElemsPerThread(op.getVal().getType()); + auto freeVarMasks = getFreeVariableMasks(op.getPtr().getType()); + Value threadPred = ttg::emitRedundantThreadPredicate(freeVarMasks, rewriter, + loc, *targetInfo); + uint32_t regMask = freeVarMasks.lookup(str_attr("reg")); + auto sourceLoc = materializeSourceLocation(rewriter, loc); + auto eventStateTy = getGSanAtomicEventStateType(rewriter); + Value eventState = LLVM::AllocaOp::create(rewriter, loc, ptr_ty(ctx), + eventStateTy, b.i32_val(1), + /*alignment=*/0); + + SmallVector resultVals(elemsPerThread); + + for (size_t i = 0; i < elemsPerThread; ++i) { + if (auto canonicalIdx = getCanonicalIndex(i, regMask); + i != canonicalIdx) { + resultVals[i] = resultVals[canonicalIdx]; + continue; + } + + Value pred = + llMask ? ttg::maybeAnd(rewriter, loc, threadPred, maskElements[i]) + : threadPred; + Value rmwPtr = ptrElements[i]; + Value rmwVal = valElements[i]; + + emitGSanAtomicBeginCall(rewriter, loc, gsanGlobalStatePtr, eventState, + pred, rmwPtr, bytesPerElem, + static_cast(sem), + static_cast(scope), sourceLoc); + + SmallVector rmwVals{rmwVal}; + auto old = NVIDIA::emitPtxAtomicRMW(rewriter, loc, valueElemTy, rmwPtr, + rmwVals, rmwOp, sem, scope, pred); + if (failed(old)) + return failure(); + + emitGSanAtomicEndCall(rewriter, loc, eventState, pred, pred, + static_cast(sem), + static_cast(scope), sourceLoc); + resultVals[i] = *old; + } + + if (op.getResult().use_empty()) { + rewriter.eraseOp(op); + return success(); + } + + if (!tensorTy) { + Value scalarResult = broadcastScalarAtomicResult( + op, valueElemTy, resultVals[0], rewriter, b, threadPred, *targetInfo); + rewriter.replaceOp(op, {scalarResult}); + return success(); + } + + finalizeTensorAtomicResults(op, tensorTy, rewriter, resultVals, valueElemTy, + b, threadPred, *targetInfo, getTypeConverter()); + return success(); + } +}; + +struct GSanAtomicCASOpConversion + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern< + tti::ExperimentalGSanAtomicCASOp>::ConvertOpToLLVMPattern; + const TargetInfoBase *targetInfo; + + GSanAtomicCASOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(&targetInfo) {} + + LogicalResult + matchAndRewrite(tti::ExperimentalGSanAtomicCASOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *ctx = rewriter.getContext(); + Location loc = op.getLoc(); + auto func = op->getParentOfType(); + Value gsanGlobalStatePtr = getGSanGlobalStateArg(func); + if (!gsanGlobalStatePtr) + return emitError(op.getLoc(), "Failed to find pointer to gsan state"); + + auto moduleOp = op->getParentOfType(); + assert(moduleOp && "Parent ModuleOp not found for atomic op"); + auto sem = op.getSem(); + auto scope = op.getScope(); + + TritonLLVMOpBuilder b(loc, rewriter); + Value llPtr = adaptor.getPtr(); + Value llCmp = adaptor.getCmp(); + Value llVal = adaptor.getVal(); + + auto ptrElements = unpackLLElements(loc, llPtr, rewriter); + auto cmpElements = unpackLLElements(loc, llCmp, rewriter); + auto valElements = unpackLLElements(loc, llVal, rewriter); + + auto valueTy = op.getType(); + auto tensorTy = dyn_cast(valueTy); + Type valueElemTy = valElements[0].getType(); + unsigned valueElemNBits = valueElemTy.getIntOrFloatBitWidth(); + int32_t bytesPerElem = valueElemNBits / 8; + auto elemsPerThread = ttg::getTotalElemsPerThread(op.getVal().getType()); + auto freeVarMasks = getFreeVariableMasks(op.getPtr().getType()); + Value threadPred = ttg::emitRedundantThreadPredicate(freeVarMasks, rewriter, + loc, *targetInfo); + uint32_t regMask = freeVarMasks.lookup(str_attr("reg")); + auto sourceLoc = materializeSourceLocation(rewriter, loc); + auto eventStateTy = getGSanAtomicEventStateType(rewriter); + Value eventState = LLVM::AllocaOp::create(rewriter, loc, ptr_ty(ctx), + eventStateTy, b.i32_val(1), + /*alignment=*/0); + + SmallVector resultVals(elemsPerThread); + + for (size_t i = 0; i < elemsPerThread; ++i) { + if (auto canonicalIdx = getCanonicalIndex(i, regMask); + canonicalIdx != i) { + resultVals[i] = resultVals[canonicalIdx]; + continue; + } + + Value pred = threadPred; + Value casPtr = ptrElements[i]; + Value casCmp = cmpElements[i]; + Value casVal = valElements[i]; + + emitGSanAtomicBeginCall(rewriter, loc, gsanGlobalStatePtr, eventState, + pred, casPtr, bytesPerElem, + static_cast(sem), + static_cast(scope), sourceLoc); + + Value old = NVIDIA::emitPtxAtomicCAS(rewriter, loc, valueElemTy, casPtr, + casCmp, casVal, sem, scope, pred); + + auto oldInt = bitcastToScalarInt(rewriter, loc, old); + auto cmpInt = bitcastToScalarInt(rewriter, loc, casCmp); + Value didWrite = LLVM::ICmpOp::create( + rewriter, loc, i1_ty, LLVM::ICmpPredicate::eq, oldInt, cmpInt); + didWrite = ttg::maybeAnd(rewriter, loc, pred, didWrite); + emitGSanAtomicEndCall(rewriter, loc, eventState, pred, didWrite, + static_cast(sem), + static_cast(scope), sourceLoc); + resultVals[i] = old; + } + + if (op.getResult().use_empty()) { + rewriter.eraseOp(op); + return success(); + } + + if (!tensorTy) { + Value scalarResult = broadcastScalarAtomicResult( + op, valueElemTy, resultVals[0], rewriter, b, threadPred, *targetInfo); + rewriter.replaceOp(op, {scalarResult}); + return success(); + } + + finalizeTensorAtomicResults(op, tensorTy, rewriter, resultVals, valueElemTy, + b, threadPred, *targetInfo, getTypeConverter()); + return success(); + } +}; + struct GSanTensorDescInfoOpConversion : public ConvertOpToLLVMPattern { public: @@ -389,6 +702,8 @@ void mlir::triton::populateGSanToLLVMPatterns( const TargetInfoBase &targetInfo) { patterns.add(typeConverter); patterns.add(typeConverter); + patterns.add(typeConverter, targetInfo); + patterns.add(typeConverter, targetInfo); patterns.add(typeConverter, axisInfoAnalysis, targetInfo); } diff --git a/lib/Dialect/TritonInstrument/Transforms/GlobalSanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/GlobalSanitizer.cpp index 18afb5607b9f..0d12450a7302 100644 --- a/lib/Dialect/TritonInstrument/Transforms/GlobalSanitizer.cpp +++ b/lib/Dialect/TritonInstrument/Transforms/GlobalSanitizer.cpp @@ -404,7 +404,7 @@ class GlobalSanitizerPass } module.walk([&](Operation *op) { - OpBuilder b(op); + IRRewriter b(op); mlir::TypeSwitch(op) .Case([&](tt::LoadOp op) { ExperimentalGSanTensorAccessOp::create( @@ -436,12 +436,25 @@ class GlobalSanitizerPass }) .Case([&](ttng::AsyncTMAScatterOp op) { instrumentAsyncTMAScatter(op); + }) + .Case([&](tt::AtomicRMWOp op) { + auto newOp = ExperimentalGSanAtomicRMWOp::create( + b, op.getLoc(), op.getType(), op.getAtomicRmwOp(), op.getPtr(), + op.getVal(), op.getMask(), op.getSem(), op.getScope()); + newOp->setAttrs(op->getAttrs()); + b.replaceOp(op, newOp); + }) + .Case([&](tt::AtomicCASOp op) { + auto newOp = ExperimentalGSanAtomicCASOp::create( + b, op.getLoc(), op.getType(), op.getPtr(), op.getCmp(), + op.getVal(), op.getSem(), op.getScope()); + newOp->setAttrs(op->getAttrs()); + b.replaceOp(op, newOp); + }) + .Case([&](ttg::WarpSpecializeOp op) { + op->setAttr(kDisableSetMaxRegisterAttr, builder.getUnitAttr()); }); }); - - module.walk([&](ttg::WarpSpecializeOp op) { - op->setAttr(kDisableSetMaxRegisterAttr, builder.getUnitAttr()); - }); } }; diff --git a/python/test/gsan/test_gsan.py b/python/test/gsan/test_gsan.py index 0cfe910b93b8..9b143491dc99 100644 --- a/python/test/gsan/test_gsan.py +++ b/python/test/gsan/test_gsan.py @@ -12,7 +12,7 @@ from triton._internal_testing import is_blackwell, is_cuda, is_ampere_or_newer from triton.experimental.gsan import create_mem_pool from triton._C.libtriton.gsan_testing import AtomicScope, SHADOW_GRANULARITY_BYTES, ScalarClock -from triton.experimental.gsan._testing_utils import (load_one_i32, shadow_cell_from_address, store_one_i32, +from triton.experimental.gsan._testing_utils import (atomic_poll, load_one_i32, shadow_cell_from_address, store_one_i32, thread_state_from_smid) @@ -24,6 +24,85 @@ def with_gsan(fresh_knobs): yield +def _clock_buffer_snapshot_idx(token: int, state, tid: int) -> int: + return (token % state.clock_buffer_size) * state.num_threads + tid + + +ATOMIC_SCOPE_CASES = ( + pytest.param("cta", AtomicScope.CTA, id="scope-cta"), + pytest.param("gpu", AtomicScope.GPU, id="scope-gpu"), + pytest.param("sys", AtomicScope.SYSTEM, id="scope-sys"), +) + +ATOMIC_SEMANTIC_CASES = ( + pytest.param("relaxed", False, id="sem-relaxed"), + pytest.param("acquire", False, id="sem-acquire"), + pytest.param("release", True, id="sem-release"), + pytest.param("acq_rel", True, id="sem-acq-rel"), +) + +RELEASE_SEMANTIC_CASES = ( + pytest.param("release", id="sem-release"), + pytest.param("acq_rel", id="sem-acq-rel"), +) + +ACQUIRE_SEMANTIC_CASES = ( + pytest.param("acquire", id="sem-acquire"), + pytest.param("acq_rel", id="sem-acq-rel"), +) + + +def _assert_atomic_rmw_shadow(real_address: int, expected_scope: AtomicScope, *, is_release: bool) -> None: + cell = shadow_cell_from_address(real_address) + tid = cell.write_clock.thread_id + state = thread_state_from_smid(tid) + + if is_release: + token = cell.write_clock.epoch + snapshot_idx = _clock_buffer_snapshot_idx(token, state, tid) + published_epoch = state.clock_buffer[snapshot_idx] + + assert cell.write_clock == ScalarClock(token, tid, expected_scope, is_release=True) + assert token == state.clock_buffer_head + assert state.clock_buffer_dirty + assert cell.read_clocks[0] == ScalarClock(published_epoch, tid, expected_scope) + assert state.vector_clock[tid] == published_epoch + 1 + else: + epoch = state.vector_clock[tid] + assert cell.write_clock == ScalarClock(epoch, tid, expected_scope) + assert cell.read_clocks[0] == ScalarClock(epoch, tid, expected_scope) + + assert cell.num_reads == 1 + + +def _assert_atomic_read_only_shadow(real_address: int, expected_scope: AtomicScope) -> None: + cell = shadow_cell_from_address(real_address) + tid = cell.read_clocks[0].thread_id + epoch = thread_state_from_smid(tid).vector_clock[tid] + + assert cell.write_clock == ScalarClock(0, 0, AtomicScope.NON_ATOMIC) + assert cell.read_clocks[0] == ScalarClock(epoch, tid, expected_scope) + assert cell.num_reads == 1 + + +def _assert_cross_sm_sync(payload_ptr: torch.Tensor, flag_ptr: torch.Tensor, expected_scope: AtomicScope) -> None: + payload_cell = shadow_cell_from_address(payload_ptr.data_ptr()) + flag_cell = shadow_cell_from_address(flag_ptr.data_ptr()) + producer_tid = payload_cell.write_clock.thread_id + producer_epoch = payload_cell.write_clock.epoch + consumer_tid = payload_cell.read_clocks[0].thread_id + consumer_state = thread_state_from_smid(consumer_tid) + + assert flag_cell.write_clock.scope == expected_scope + assert flag_cell.write_clock.is_release + assert consumer_state.vector_clock[producer_tid] >= producer_epoch + + +def _assert_no_gsan_runtime_output(capfd) -> None: + captured = capfd.readouterr() + assert "GSanLibrary.cu" not in captured.out + captured.err + + @pytest.mark.skipif(not is_cuda(), reason="GSan requires CUDA") def test_load_store_updates_shadow(with_gsan): target = torch.zeros(1, dtype=torch.int32, device="cuda") @@ -83,6 +162,158 @@ def test_gluon_warp_specialize_completes(with_gsan): torch.testing.assert_close(out, expected) +@triton.jit +def atomic_add_kernel(ptr, sem: tl.constexpr, scope: tl.constexpr = "gpu"): + tl.atomic_add(ptr, 1, sem=sem, scope=scope) + + +@triton.jit +def atomic_cas_kernel(ptr, out_ptr, expect, sem: tl.constexpr, scope: tl.constexpr = "gpu"): + old = tl.atomic_cas(ptr, expect, 2, sem=sem, scope=scope) + tl.store(out_ptr, old) + + +@triton.jit +def _cross_sm_atomic_sync_kernel(payload_ptr, flag_ptr, out_ptr, producer_sem: tl.constexpr, consumer_sem: tl.constexpr, + scope: tl.constexpr): + pid = tl.program_id(0) + if pid == 0: + tl.store(payload_ptr, 1000) + tl.atomic_xchg(flag_ptr, 1, sem=producer_sem, scope=scope) + elif pid == 1: + atomic_poll(flag_ptr, 1, sem=consumer_sem, scope=scope) + result = tl.load(payload_ptr) + tl.store(out_ptr, result) + + +@triton.jit +def _transitive_atomic_sync_kernel(payload_ptr, flag0_ptr, flag1_ptr, out_ptr, release_sem: tl.constexpr, + acquire_sem: tl.constexpr, scope: tl.constexpr): + pid = tl.program_id(0) + if pid == 0: + tl.store(payload_ptr, 1000) + tl.atomic_xchg(flag0_ptr, 1, sem=release_sem, scope=scope) + elif pid == 1: + atomic_poll(flag0_ptr, 1, sem=acquire_sem, scope=scope) + tl.atomic_xchg(flag1_ptr, 1, sem=release_sem, scope=scope) + elif pid == 2: + atomic_poll(flag1_ptr, 1, sem=acquire_sem, scope=scope) + result = tl.load(payload_ptr) + tl.store(out_ptr, result) + + +@pytest.mark.skipif(not is_cuda(), reason="GSan requires CUDA") +@pytest.mark.parametrize("scope, expected_scope", ATOMIC_SCOPE_CASES) +@pytest.mark.parametrize("sem, is_release", ATOMIC_SEMANTIC_CASES) +def test_atomic_add_updates_atomic_shadow(with_gsan, sem, is_release, scope, expected_scope): + target = torch.zeros(1, dtype=torch.int32, device="cuda") + + atomic_add_kernel[(1, )](target, sem=sem, scope=scope, num_warps=1) + assert target.item() == 1 + + _assert_atomic_rmw_shadow(target.data_ptr(), expected_scope, is_release=is_release) + + +@pytest.mark.skipif(not is_cuda(), reason="GSan requires CUDA") +@pytest.mark.parametrize("scope, expected_scope", ATOMIC_SCOPE_CASES) +@pytest.mark.parametrize("sem, _", ATOMIC_SEMANTIC_CASES) +def test_atomic_cas_failed_only_records_read(with_gsan, sem, _, scope, expected_scope): + target = torch.zeros(1, dtype=torch.int32, device="cuda") + out = torch.zeros(1, dtype=torch.int32, device="cuda") + + atomic_cas_kernel[(1, )](target, out, expect=1, sem=sem, scope=scope, num_warps=1) + + assert target.item() == 0 + assert out.item() == 0 + + _assert_atomic_read_only_shadow(target.data_ptr(), expected_scope) + + +@pytest.mark.skipif(not is_cuda(), reason="GSan requires CUDA") +@pytest.mark.parametrize("scope, expected_scope", ATOMIC_SCOPE_CASES) +@pytest.mark.parametrize("sem, is_release", ATOMIC_SEMANTIC_CASES) +def test_atomic_cas_success_updates_atomic_shadow(with_gsan, sem, is_release, scope, expected_scope): + target = torch.zeros(1, dtype=torch.int32, device="cuda") + out = torch.zeros(1, dtype=torch.int32, device="cuda") + + atomic_cas_kernel[(1, )](target, out, expect=0, sem=sem, scope=scope, num_warps=1) + + assert target.item() == 2 + assert out.item() == 0 + + _assert_atomic_rmw_shadow(target.data_ptr(), expected_scope, is_release=is_release) + + +@pytest.mark.skipif(not is_cuda(), reason="GSan requires CUDA") +@pytest.mark.parametrize("scope, expected_scope", ATOMIC_SCOPE_CASES[1:]) +@pytest.mark.parametrize("producer_sem", RELEASE_SEMANTIC_CASES) +@pytest.mark.parametrize("consumer_sem", ACQUIRE_SEMANTIC_CASES) +def test_atomic_release_acquire_synchronizes_cross_sm(with_gsan, capfd, producer_sem, consumer_sem, scope, + expected_scope): + payload = torch.zeros(1, dtype=torch.int32, device="cuda") + flags = torch.zeros(1, dtype=torch.int32, device="cuda") + out = torch.full((1, ), -1, dtype=torch.int32, device="cuda") + _cross_sm_atomic_sync_kernel[(2, )]( + payload, + flags, + out, + producer_sem=producer_sem, + consumer_sem=consumer_sem, + scope=scope, + num_warps=1, + ) + torch.cuda.synchronize() + + assert out.item() == 1000 + + _assert_cross_sm_sync(payload, flags, expected_scope) + _assert_no_gsan_runtime_output(capfd) + + +@pytest.mark.skipif(not is_cuda(), reason="GSan requires CUDA") +@pytest.mark.parametrize("scope, expected_scope", ATOMIC_SCOPE_CASES[1:]) +@pytest.mark.parametrize("release_sem", RELEASE_SEMANTIC_CASES) +@pytest.mark.parametrize("acquire_sem", ACQUIRE_SEMANTIC_CASES) +def test_atomic_release_acquire_transitively_synchronizes_cross_sm(with_gsan, capfd, release_sem, acquire_sem, scope, + expected_scope): + payload = torch.zeros(1, dtype=torch.int32, device="cuda") + flag0 = torch.zeros(1, dtype=torch.int32, device="cuda") + flag1 = torch.zeros(1, dtype=torch.int32, device="cuda") + out = torch.full((1, ), -1, dtype=torch.int32, device="cuda") + _transitive_atomic_sync_kernel[(3, )]( + payload, + flag0, + flag1, + out, + release_sem=release_sem, + acquire_sem=acquire_sem, + scope=scope, + num_warps=1, + ) + torch.cuda.synchronize() + + assert out.item() == 1000 + + payload_cell = shadow_cell_from_address(payload.data_ptr()) + flag1_cell = shadow_cell_from_address(flag1.data_ptr()) + producer_tid = payload_cell.write_clock.thread_id + producer_epoch = payload_cell.write_clock.epoch + + relay_state = thread_state_from_smid(flag1_cell.write_clock.thread_id) + snapshot_idx = _clock_buffer_snapshot_idx(flag1_cell.write_clock.epoch, relay_state, producer_tid) + + assert flag1_cell.write_clock.scope == expected_scope + assert flag1_cell.write_clock.is_release + assert relay_state.clock_buffer[snapshot_idx] >= producer_epoch + + consumer_tid = payload_cell.read_clocks[0].thread_id + consumer_state = thread_state_from_smid(consumer_tid) + + assert consumer_state.vector_clock[producer_tid] >= producer_epoch + + _assert_no_gsan_runtime_output(capfd) + + @triton.jit def _write_blocks_kernel(ptr, n_elements, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(0) diff --git a/python/test/gsan/test_gsan_failures.py b/python/test/gsan/test_gsan_failures.py index 6c801946de31..e40e389f5425 100644 --- a/python/test/gsan/test_gsan_failures.py +++ b/python/test/gsan/test_gsan_failures.py @@ -1,5 +1,6 @@ from __future__ import annotations +import functools import inspect from pathlib import Path @@ -10,21 +11,31 @@ from triton._internal_testing import is_blackwell, is_cuda, run_in_process from triton.experimental.gsan import create_mem_pool +from triton.experimental.gsan._testing_utils import atomic_poll from triton.tools.tensor_descriptor import TensorDescriptor pytestmark = pytest.mark.skipif(not is_cuda(), reason="requires CUDA backend") +RELEASE_ACQUIRE_SYNC_CASES = ( + pytest.param("release", "acquire", id="release-acquire"), + pytest.param("release", "acq_rel", id="release-acq-rel"), + pytest.param("acq_rel", "acquire", id="acq-rel-acquire"), + pytest.param("acq_rel", "acq_rel", id="acq-rel-acq-rel"), +) -@triton.jit -def nanosleep(duration): - duration = tl.to_tensor(duration) - tl.inline_asm_elementwise("nanosleep.u32 $1; mov.b32 $0, 0;", "=r, r", [duration], tl.int32, is_pure=False, pack=1) - +CROSS_SM_SEMANTIC_MISMATCH_CASES = ( + pytest.param("relaxed", "acquire", "gpu", id="producer-relaxed-consumer-acquire-scope-gpu"), + pytest.param("relaxed", "acquire", "sys", id="producer-relaxed-consumer-acquire-scope-sys"), + pytest.param("release", "relaxed", "gpu", id="producer-release-consumer-relaxed-scope-gpu"), + pytest.param("release", "relaxed", "sys", id="producer-release-consumer-relaxed-scope-sys"), +) -@triton.jit -def atomic_poll(counter_ptr, expected): - while tl.atomic_add(counter_ptr, 0, sem="relaxed") < expected: - nanosleep(100) +TRANSITIVE_RELAY_MISMATCH_CASES = ( + pytest.param("release", "relaxed", "gpu", id="relay-relaxed-scope-gpu"), + pytest.param("release", "relaxed", "sys", id="relay-relaxed-scope-sys"), + pytest.param("acq_rel", "release", "gpu", id="relay-release-scope-gpu"), + pytest.param("acq_rel", "release", "sys", id="relay-release-scope-sys"), +) @triton.jit @@ -62,6 +73,47 @@ def _waw_kernel(ptr, scratch_ptr, counter_ptr): tl.store(ptr, 2) +@triton.jit +def _cross_sm_atomic_sync_kernel(payload_ptr, flag_ptr, counter_ptr, scratch_ptr, producer_sem: tl.constexpr, + consumer_sem: tl.constexpr, scope: tl.constexpr): + pid = tl.program_id(0) + if pid == 0: + tl.store(payload_ptr, 1000) + tl.atomic_xchg(flag_ptr, 1, sem=producer_sem, scope=scope) + tl.atomic_add(counter_ptr, 1, sem="relaxed") + elif pid == 1: + atomic_poll(counter_ptr, 1) + ready = 0 + while ready != 1: + ready = tl.atomic_add(flag_ptr, 0, sem=consumer_sem, scope=scope) + result = tl.load(payload_ptr) + tl.store(scratch_ptr, result) + + +@triton.jit +def _transitive_atomic_sync_kernel(payload_ptr, flag0_ptr, flag1_ptr, counter_ptr, scratch_ptr, + release_sem: tl.constexpr, relay_sem: tl.constexpr, scope: tl.constexpr): + pid = tl.program_id(0) + if pid == 0: + tl.store(payload_ptr, 1000) + tl.atomic_xchg(flag0_ptr, 1, sem=release_sem, scope=scope) + tl.atomic_add(counter_ptr, 1, sem="relaxed") + elif pid == 1: + atomic_poll(counter_ptr, 1) + ready = 0 + while ready != 1: + ready = tl.atomic_add(flag0_ptr, 0, sem=relay_sem, scope=scope) + tl.atomic_xchg(flag1_ptr, 1, sem=release_sem, scope=scope) + tl.atomic_add(counter_ptr, 1, sem="relaxed") + elif pid == 2: + atomic_poll(counter_ptr, 2) + ready = 0 + while ready != 1: + ready = tl.atomic_add(flag1_ptr, 0, sem="acquire", scope=scope) + result = tl.load(payload_ptr) + tl.store(scratch_ptr, result) + + @triton.jit def _tma_raw_kernel(ptr, scratch_ptr, counter_ptr, m_size, n_size, row_idx, col_idx, stride_0, BLOCK: tl.constexpr): pid = tl.program_id(0) @@ -116,176 +168,267 @@ def _host_tma_scatter_war_kernel(target_ptr, target_desc, x_offsets_ptr, src_ptr tl.atomic_add(counter_ptr, 1, sem="relaxed") else: atomic_poll(counter_ptr, 1) + x_offsets = tl.load(x_offsets_ptr + tl.arange(0, BLOCK_X)) indices_x = tl.arange(0, BLOCK_X)[:, None] * src_stride_0 indices_y = tl.arange(0, BLOCK_Y)[None, :] * src_stride_1 values = tl.load(src_ptr + indices_x + indices_y) - x_offsets = tl.load(x_offsets_ptr + tl.arange(0, BLOCK_X)) target_desc.scatter(values, x_offsets, y_offset) -def _run_case(case: str) -> None: +def _cuda_byte_allocator(size: int, _align: int, _stream): + return torch.empty(size, dtype=torch.int8, device="cuda") + + +def run_with_gsan(fn): + + @functools.wraps(fn) + def wrapped(*args, **kwargs) -> None: + triton.knobs.compilation.instrumentation_mode = "gsan" + pool = create_mem_pool() + with torch.cuda.use_mem_pool(pool): + fn(*args, **kwargs) + + return wrapped + + +@run_with_gsan +def _run_raw_case() -> None: + target = torch.zeros(1, dtype=torch.int32, device="cuda") + scratch = torch.zeros(1, dtype=torch.int32, device="cuda") + counter = torch.zeros(1, dtype=torch.int32, device="cuda") + _raw_kernel[(2, )](target, scratch, counter, num_warps=1) + + +@run_with_gsan +def _run_war_case() -> None: + target = torch.zeros(1, dtype=torch.int32, device="cuda") + scratch = torch.zeros(1, dtype=torch.int32, device="cuda") + counter = torch.zeros(1, dtype=torch.int32, device="cuda") + _war_kernel[(2, )](target, scratch, counter, num_warps=1) + + +@run_with_gsan +def _run_waw_case() -> None: + target = torch.zeros(1, dtype=torch.int32, device="cuda") + scratch = torch.zeros(1, dtype=torch.int32, device="cuda") + counter = torch.zeros(1, dtype=torch.int32, device="cuda") + _waw_kernel[(2, )](target, scratch, counter, num_warps=1) + + +@run_with_gsan +def _run_tma_raw_case() -> None: block = 32 m_size = 35 n_size = 37 padded_n = 40 row_idx = 5 col_idx = 8 - gather_block_x = 8 - gather_block_y = 8 - gather_m_size = 11 - gather_n_size = 13 - gather_padded_m = 16 - gather_padded_n = 16 - gather_row_idx = 5 - gather_y_offset = 8 - gather_x_offsets = [5, 7, 9, 10, 1, 3, 11, 13] - pool = create_mem_pool() - with torch.cuda.use_mem_pool(pool): - if case == "tma_raw": - target_storage = torch.zeros((m_size, padded_n), dtype=torch.int32, device="cuda") - target = target_storage[:, :n_size] - scratch = torch.zeros(1, dtype=torch.int32, device="cuda") - - def alloc_fn(size: int, _align: int, _stream): - return torch.empty(size, dtype=torch.int8, device="cuda") - - triton.set_allocator(alloc_fn) - elif case == "host_tma_war": - target_storage = torch.zeros((m_size, padded_n), dtype=torch.int32, device="cuda") - scratch_storage = torch.zeros_like(target_storage) - target = target_storage[:, :n_size] - scratch = scratch_storage[:, :n_size] - target_desc = TensorDescriptor.from_tensor(target, [block, block]) - scratch_desc = TensorDescriptor.from_tensor(scratch, [block, block]) - elif case in {"host_tma_gather_war", "host_tma_scatter_war"}: - target_storage = torch.zeros((gather_padded_m, gather_padded_n), dtype=torch.int32, device="cuda") - target = target_storage[:gather_m_size, :gather_n_size] - target_desc = TensorDescriptor.from_tensor(target, [1, gather_block_y]) - x_offsets = torch.tensor(gather_x_offsets, dtype=torch.int32, device="cuda") - if case == "host_tma_gather_war": - scratch = torch.zeros((gather_block_x, gather_block_y), dtype=torch.int32, device="cuda") - else: - src = torch.arange(1, gather_block_x * gather_block_y + 1, dtype=torch.int32, - device="cuda").reshape(gather_block_x, gather_block_y) - scratch = torch.zeros(1, dtype=torch.int32, device="cuda") - else: - target = torch.zeros(1, dtype=torch.int32, device="cuda") - scratch = torch.zeros(1, dtype=torch.int32, device="cuda") - counter = torch.zeros(1, dtype=torch.int32, device="cuda") - - triton.knobs.compilation.instrumentation_mode = "gsan" - kernel = globals()[f"_{case}_kernel"] - if case == "tma_raw": - kernel[(2, )](target, scratch, counter, m_size, n_size, row_idx, col_idx, target.stride(0), BLOCK=block, - num_warps=4) - elif case == "host_tma_war": - kernel[(2, )](target, target_desc, scratch_desc, counter, row_idx, col_idx, target.stride(0), num_warps=4) - elif case == "host_tma_gather_war": - kernel[(2, )](target, target_desc, x_offsets, scratch, counter, gather_row_idx, gather_y_offset, - target.stride(0), scratch.stride(0), scratch.stride(1), BLOCK_X=gather_block_x, num_warps=4) - elif case == "host_tma_scatter_war": - kernel[(2, )](target, target_desc, x_offsets, src, src.stride(0), src.stride(1), scratch, counter, - gather_row_idx, gather_y_offset, target.stride(0), BLOCK_X=gather_block_x, num_warps=4) - else: - kernel[(2, )](target, scratch, counter, num_warps=1) - - -CASE_INFO = { - "raw": { - "error": "Read after write race detected", - "function": _raw_kernel.fn, - "marker": "value = tl.load(ptr)", - }, - "war": { - "error": "Write after read race detected", - "function": _war_kernel.fn, - "marker": "tl.store(ptr, 1)", - }, - "waw": { - "error": "Write after write race detected", - "function": _waw_kernel.fn, - "marker": "tl.store(ptr, 2)", - }, - "tma_raw": { - "error": "Read after write race detected", - "function": _tma_raw_kernel.fn, - "marker": "value = tl.load(ptr + row_idx * stride_0 + col_idx)", - }, - "host_tma_war": { - "error": "Write after read race detected", - "function": _host_tma_war_kernel.fn, - "marker": "tl.store(target_ptr + row_idx * stride_0 + col_idx, 1)", - }, - "host_tma_gather_war": { - "error": "Write after read race detected", - "function": _host_tma_gather_war_kernel.fn, - "marker": "tl.store(target_ptr + row_idx * stride_0 + y_offset, 1)", - }, - "host_tma_scatter_war": { - "error": "Write after read race detected", - "function": _host_tma_scatter_war_kernel.fn, - "marker": "target_desc.scatter(values, x_offsets, y_offset)", - }, -} - - -def _expected_file_line(case: str) -> str: - source_lines, starting_line = inspect.getsourcelines(CASE_INFO[case]["function"]) - markers = CASE_INFO[case]["marker"] - if isinstance(markers, str): - markers = (markers, ) - - matches = [] - for marker in markers: - for line_offset, line in enumerate(source_lines): - if marker in line: - matches.append(f"{Path(__file__).name}:{starting_line + line_offset}") - break - else: - raise AssertionError(f"Could not find marker {marker!r} for case {case!r}") - return matches[0] if len(matches) == 1 else tuple(matches) - - -def _run_failure_case(case: str) -> None: + + counter = torch.zeros(1, dtype=torch.int32, device="cuda") + target_storage = torch.zeros((m_size, padded_n), dtype=torch.int32, device="cuda") + target = target_storage[:, :n_size] + scratch = torch.zeros(1, dtype=torch.int32, device="cuda") + triton.set_allocator(_cuda_byte_allocator) + _tma_raw_kernel[(2, )](target, scratch, counter, m_size, n_size, row_idx, col_idx, target.stride(0), BLOCK=block) + + +@run_with_gsan +def _run_host_tma_war_case() -> None: + block = 32 + m_size = 35 + n_size = 37 + padded_n = 40 + row_idx = 5 + col_idx = 8 + + target_storage = torch.zeros((m_size, padded_n), dtype=torch.int32, device="cuda") + scratch_storage = torch.zeros_like(target_storage) + target = target_storage[:, :n_size] + scratch = scratch_storage[:, :n_size] + target_desc = TensorDescriptor.from_tensor(target, [block, block]) + scratch_desc = TensorDescriptor.from_tensor(scratch, [block, block]) + counter = torch.zeros(1, dtype=torch.int32, device="cuda") + _host_tma_war_kernel[(2, )](target, target_desc, scratch_desc, counter, row_idx, col_idx, target.stride(0)) + + +@run_with_gsan +def _run_host_tma_gather_war_case() -> None: + block_x = 8 + block_y = 8 + m_size = 11 + n_size = 13 + padded_m = 16 + padded_n = 16 + row_idx = 5 + y_offset = 8 + x_offsets_values = [5, 7, 9, 10, 1, 3, 11, 13] + + target_storage = torch.zeros((padded_m, padded_n), dtype=torch.int32, device="cuda") + target = target_storage[:m_size, :n_size] + x_offsets = torch.tensor(x_offsets_values, dtype=torch.int32, device="cuda") + target_desc = TensorDescriptor.from_tensor(target, [1, block_y]) + scratch = torch.zeros((block_x, block_y), dtype=torch.int32, device="cuda") + counter = torch.zeros(1, dtype=torch.int32, device="cuda") + _host_tma_gather_war_kernel[(2, )](target, target_desc, x_offsets, scratch, counter, row_idx, y_offset, + target.stride(0), scratch.stride(0), scratch.stride(1), BLOCK_X=block_x) + + +@run_with_gsan +def _run_host_tma_scatter_war_case() -> None: + block_x = 8 + block_y = 8 + m_size = 11 + n_size = 13 + padded_m = 16 + padded_n = 16 + row_idx = 5 + y_offset = 8 + x_offsets_values = [5, 7, 9, 10, 1, 3, 11, 13] + + target_storage = torch.zeros((padded_m, padded_n), dtype=torch.int32, device="cuda") + target = target_storage[:m_size, :n_size] + x_offsets = torch.tensor(x_offsets_values, dtype=torch.int32, device="cuda") + target_desc = TensorDescriptor.from_tensor(target, [1, block_y]) + src = torch.arange(1, block_x * block_y + 1, dtype=torch.int32, device="cuda").reshape(block_x, block_y) + scratch = torch.zeros(1, dtype=torch.int32, device="cuda") + counter = torch.zeros(1, dtype=torch.int32, device="cuda") + _host_tma_scatter_war_kernel[(2, )](target, target_desc, x_offsets, src, src.stride(0), src.stride(1), scratch, + counter, row_idx, y_offset, target.stride(0), BLOCK_X=block_x) + + +@run_with_gsan +def _run_cross_sm_atomic_sync_case(producer_sem: str, consumer_sem: str, scope: str) -> None: + payload = torch.zeros(1, dtype=torch.int32, device="cuda") + flags = torch.zeros(1, dtype=torch.int32, device="cuda") + counter = torch.zeros(1, dtype=torch.int32, device="cuda") + scratch = torch.full((1, ), -1, dtype=torch.int32, device="cuda") + _cross_sm_atomic_sync_kernel[(2, )]( + payload, + flags, + counter, + scratch, + producer_sem=producer_sem, + consumer_sem=consumer_sem, + scope=scope, + num_warps=1, + ) + + +@run_with_gsan +def _run_transitive_atomic_sync_case(release_sem: str, relay_sem: str, scope: str) -> None: + payload = torch.zeros(1, dtype=torch.int32, device="cuda") + flag0 = torch.zeros(1, dtype=torch.int32, device="cuda") + flag1 = torch.zeros(1, dtype=torch.int32, device="cuda") + counter = torch.zeros(1, dtype=torch.int32, device="cuda") + scratch = torch.full((1, ), -1, dtype=torch.int32, device="cuda") + _transitive_atomic_sync_kernel[(3, )]( + payload, + flag0, + flag1, + counter, + scratch, + release_sem=release_sem, + relay_sem=relay_sem, + scope=scope, + num_warps=1, + ) + + +def _expected_file_line(source_function, marker: str) -> str: + source_lines, starting_line = inspect.getsourcelines(source_function) + for line_offset, line in enumerate(source_lines): + if marker in line: + return f"{Path(__file__).name}:{starting_line + line_offset}" + raise AssertionError(f"Could not find marker {marker!r} for function {source_function!r}") + + +def _run_failure_case(case: str, *, runner, source_function, marker: str, error: str, runner_args=(), + runner_kwargs=None) -> None: if torch.cuda.device_count() < 1: pytest.skip("requires at least 1 CUDA device") - result = run_in_process(_run_case, (case, )) + if runner_kwargs is None: + runner_kwargs = {} + + result = run_in_process(runner, runner_args, runner_kwargs) print(result.driver_stderr_output) assert isinstance(result.exc, RuntimeError), (f"case={case} completed without the expected GSan failure\n" f"exc={result.exc!r}\n" f"driver stderr:\n{result.driver_stderr_output}") assert "GSanLibrary.cu" not in result.driver_stderr_output assert Path(__file__).name in result.driver_stderr_output - assert _expected_file_line(case) in result.driver_stderr_output - assert CASE_INFO[case]["error"] in result.driver_stderr_output + assert _expected_file_line(source_function, marker) in result.driver_stderr_output + assert error in result.driver_stderr_output def test_read_after_write(): - _run_failure_case("raw") + _run_failure_case("raw", runner=_run_raw_case, source_function=_raw_kernel.fn, marker="value = tl.load(ptr)", + error="Read after write race detected") def test_write_after_read(): - _run_failure_case("war") + _run_failure_case("war", runner=_run_war_case, source_function=_war_kernel.fn, marker="tl.store(ptr, 1)", + error="Write after read race detected") def test_write_after_write(): - _run_failure_case("waw") + _run_failure_case("waw", runner=_run_waw_case, source_function=_waw_kernel.fn, marker="tl.store(ptr, 2)", + error="Write after write race detected") def test_tma_read_after_write(): - _run_failure_case("tma_raw") + _run_failure_case("tma_raw", runner=_run_tma_raw_case, source_function=_tma_raw_kernel.fn, + marker="value = tl.load(ptr + row_idx * stride_0 + col_idx)", + error="Read after write race detected") def test_host_tma_write_after_read(): - _run_failure_case("host_tma_war") + _run_failure_case("host_tma_war", runner=_run_host_tma_war_case, source_function=_host_tma_war_kernel.fn, + marker="tl.store(target_ptr + row_idx * stride_0 + col_idx, 1)", + error="Write after read race detected") @pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell") def test_host_tma_gather_write_after_read(): - _run_failure_case("host_tma_gather_war") + _run_failure_case("host_tma_gather_war", runner=_run_host_tma_gather_war_case, + source_function=_host_tma_gather_war_kernel.fn, + marker="tl.store(target_ptr + row_idx * stride_0 + y_offset, 1)", + error="Write after read race detected") @pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell") def test_host_tma_scatter_write_after_read(): - _run_failure_case("host_tma_scatter_war") + _run_failure_case("host_tma_scatter_war", runner=_run_host_tma_scatter_war_case, + source_function=_host_tma_scatter_war_kernel.fn, + marker="target_desc.scatter(values, x_offsets, y_offset)", error="Write after read race detected") + + +@pytest.mark.parametrize("producer_sem, consumer_sem, scope", CROSS_SM_SEMANTIC_MISMATCH_CASES) +def test_cross_sm_semantic_mismatch_read_after_write(producer_sem, consumer_sem, scope): + _run_failure_case(f"cross_sm_semantic_mismatch_{producer_sem}_{consumer_sem}_{scope}", + runner=_run_cross_sm_atomic_sync_case, runner_args=(producer_sem, consumer_sem, scope), + source_function=_cross_sm_atomic_sync_kernel.fn, marker="result = tl.load(payload_ptr)", + error="Read after write race detected") + + +@pytest.mark.parametrize("producer_sem, consumer_sem", RELEASE_ACQUIRE_SYNC_CASES) +def test_cross_sm_cta_scope_read_after_write(producer_sem, consumer_sem): + _run_failure_case(f"cross_sm_cta_scope_{producer_sem}_{consumer_sem}", runner=_run_cross_sm_atomic_sync_case, + runner_args=(producer_sem, consumer_sem, "cta"), source_function=_cross_sm_atomic_sync_kernel.fn, + marker="ready = tl.atomic_add(flag_ptr, 0, sem=consumer_sem, scope=scope)", + error="Read after write race detected") + + +@pytest.mark.parametrize("release_sem, relay_sem, scope", TRANSITIVE_RELAY_MISMATCH_CASES) +def test_transitive_release_acquire_requires_middle_acquire(release_sem, relay_sem, scope): + _run_failure_case(f"transitive_sync_{release_sem}_{relay_sem}_{scope}", runner=_run_transitive_atomic_sync_case, + runner_args=(release_sem, relay_sem, scope), source_function=_transitive_atomic_sync_kernel.fn, + marker="result = tl.load(payload_ptr)", error="Read after write race detected") + + +@pytest.mark.parametrize("release_sem, relay_sem", RELEASE_ACQUIRE_SYNC_CASES) +def test_transitive_cta_scope_read_after_write(release_sem, relay_sem): + _run_failure_case(f"transitive_cta_scope_{release_sem}_{relay_sem}", runner=_run_transitive_atomic_sync_case, + runner_args=(release_sem, relay_sem, "cta"), source_function=_transitive_atomic_sync_kernel.fn, + marker="ready = tl.atomic_add(flag0_ptr, 0, sem=relay_sem, scope=scope)", + error="Read after write race detected") diff --git a/python/test/gsan/test_symmetric_memory.py b/python/test/gsan/test_symmetric_memory.py index 68ea476d823a..a6cfb07ef4d0 100644 --- a/python/test/gsan/test_symmetric_memory.py +++ b/python/test/gsan/test_symmetric_memory.py @@ -7,11 +7,13 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp +import triton +import triton.language as tl -from triton._internal_testing import is_cuda +from triton._internal_testing import is_cuda, run_in_process from triton.experimental.gsan import symmetric_memory from triton.experimental.gsan._allocator import get_runtime_state_layout -from triton.experimental.gsan._testing_utils import shadow_tensor_for +from triton.experimental.gsan._testing_utils import atomic_poll, shadow_tensor_for from triton.experimental.gsan._utils import uint8_cuda_tensor_from_ptr @@ -33,6 +35,27 @@ def _local_vector_clocks(device_index: int) -> tuple[torch.Tensor, dict[str, int return clocks, layout +@triton.jit +def _single_cta_atomic_sync_kernel(counter_ptr, payload_ptr, peer_payload_ptr, num_ready_ptr, seen_peer_ptr, + payload_value, num_gpus): + tl.store(payload_ptr, payload_value) + + num_ready = tl.atomic_add(counter_ptr, 1, sem="acq_rel", scope="sys") + if num_ready != num_gpus - 1: + atomic_poll(counter_ptr, num_gpus, sem="acquire", scope="sys") + + seen_peer = tl.load(peer_payload_ptr) + tl.store(num_ready_ptr, num_ready) + tl.store(seen_peer_ptr, seen_peer) + + +@triton.jit +def _single_cta_no_atomic_sync_kernel(payload_ptr, peer_payload_ptr, seen_peer_ptr, payload_value): + tl.store(payload_ptr, payload_value) + seen_peer = tl.load(peer_payload_ptr) + tl.store(seen_peer_ptr, seen_peer) + + def _run_symmetric_memory_checks(rank: int, world_size: int) -> None: dev = torch.device(f"cuda:{rank}") torch.cuda.set_device(dev) @@ -166,6 +189,74 @@ def _run_subgroup_symmetric_memory_checks(rank: int) -> None: dist.destroy_process_group(subgroup) +def _run_single_cta_atomic_sync_check(rank: int, world_size: int) -> None: + dev = torch.device(f"cuda:{rank}") + torch.cuda.set_device(dev) + + peer = (rank + 1) % world_size + state = symmetric_memory.empty((2, ), dtype=torch.int32, device=dev) + state.zero_() + + hdl = symmetric_memory.rendezvous(state, group=dist.group.WORLD) + counter = hdl.get_buffer(0, (1, ), state.dtype, storage_offset=0) + peer_payload = hdl.get_buffer(peer, (1, ), state.dtype, storage_offset=1) + local_payload = state[1:] + num_ready = torch.full((1, ), -1, dtype=torch.int32, device=dev) + seen_peer = torch.full((1, ), -1, dtype=torch.int32, device=dev) + + hdl.barrier(channel=0) + _single_cta_atomic_sync_kernel[(1, )]( + counter, + local_payload, + peer_payload, + num_ready, + seen_peer, + rank + 1, + world_size, + num_warps=1, + ) + torch.cuda.synchronize() + + assert 0 <= int(num_ready.item()) < world_size + assert int(seen_peer.item()) == peer + 1 + + all_num_ready = [None] * world_size + dist.all_gather_object(all_num_ready, int(num_ready.item())) + if rank == 0: + assert sorted(all_num_ready) == list(range(world_size)) + assert int(state[0].item()) == world_size + + dist.barrier() + torch.cuda.synchronize() + hdl.close() + hdl.close() + + +def _run_single_cta_no_atomic_sync_check(rank: int, world_size: int) -> None: + triton.knobs.compilation.instrumentation_mode = "gsan" + + dev = torch.device(f"cuda:{rank}") + torch.cuda.set_device(dev) + + peer = (rank + 1) % world_size + payload = symmetric_memory.empty((1, ), dtype=torch.int32, device=dev) + payload.zero_() + + hdl = symmetric_memory.rendezvous(payload, group=dist.group.WORLD) + peer_payload = hdl.get_buffer(peer, payload.shape, payload.dtype) + seen_peer = torch.full((1, ), -1, dtype=torch.int32, device=dev) + + hdl.barrier(channel=0) + _single_cta_no_atomic_sync_kernel[(1, )]( + payload, + peer_payload, + seen_peer, + rank + 1, + num_warps=1, + ) + torch.cuda.synchronize() + + def _distributed_worker(rank: int, world_size: int, master_port: int, run_subgroup_check: bool) -> None: dev = f"cuda:{rank}" os.environ["WORLD_SIZE"] = str(world_size) @@ -182,6 +273,44 @@ def _distributed_worker(rank: int, world_size: int, master_port: int, run_subgro dist.destroy_process_group() +def _distributed_worker_single_cta_atomic_sync(rank: int, world_size: int, master_port: int) -> None: + dev = f"cuda:{rank}" + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(master_port) + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size, device_id=torch.device(dev)) + try: + _run_single_cta_atomic_sync_check(rank, world_size) + dist.barrier() + finally: + dist.destroy_process_group() + + +def _distributed_worker_single_cta_no_atomic_sync(rank: int, world_size: int, master_port: int) -> None: + dev = f"cuda:{rank}" + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(master_port) + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + torch.cuda.set_device(dev) + try: + _run_single_cta_no_atomic_sync_check(rank, world_size) + dist.barrier() + finally: + dist.destroy_process_group() + + +def _run_single_cta_no_atomic_sync_failure_case() -> None: + world_size = 2 + master_port = _get_free_tcp_port() + mp.spawn( + _distributed_worker_single_cta_no_atomic_sync, + args=(world_size, master_port), + nprocs=world_size, + join=True, + ) + + @pytest.mark.skipif(not is_cuda(), reason="requires CUDA backend") def test_gsan_symmetric_memory_rendezvous(): if torch.cuda.device_count() < 2: @@ -212,6 +341,31 @@ def test_gsan_symmetric_memory_rendezvous_subgroup_without_global_zero(): ) +@pytest.mark.skipif(not is_cuda(), reason="requires CUDA backend") +def test_gsan_symmetric_memory_single_cta_atomic_sync(): + if torch.cuda.device_count() < 2: + pytest.skip("requires 2 CUDA devices") + + world_size = 2 + master_port = _get_free_tcp_port() + mp.spawn( + _distributed_worker_single_cta_atomic_sync, + args=(world_size, master_port), + nprocs=world_size, + join=True, + ) + + +@pytest.mark.skipif(not is_cuda(), reason="requires CUDA backend") +def test_gsan_symmetric_memory_single_cta_no_atomic_sync_fails(): + if torch.cuda.device_count() < 2: + pytest.skip("requires 2 CUDA devices") + + result = run_in_process(_run_single_cta_no_atomic_sync_failure_case, env={"CUDA_LAUNCH_BLOCKING": "1"}) + assert result.exc is not None + assert "race detected" in result.driver_stderr_output + + def _run_triton_kernels_convert_dp_to_ep_with_gsan_pool(rank: int, world_size: int) -> None: from triton_kernels.distributed import (SymmetricMemoryPool, convert_dp_to_ep, make_expt_assignment, make_expt_dict_uniform) diff --git a/python/triton/experimental/gsan/_testing_utils.py b/python/triton/experimental/gsan/_testing_utils.py index b468044c81b2..f9696f6dd764 100644 --- a/python/triton/experimental/gsan/_testing_utils.py +++ b/python/triton/experimental/gsan/_testing_utils.py @@ -10,6 +10,18 @@ from ._utils import uint8_cuda_tensor_from_ptr +@triton.jit +def nanosleep(duration): + duration = tl.to_tensor(duration) + tl.inline_asm_elementwise("nanosleep.u32 $1; mov.b32 $0, 0;", "=r, r", [duration], tl.int32, is_pure=False, pack=1) + + +@triton.jit +def atomic_poll(ptr, expect, sem: tl.constexpr = "relaxed", scope: tl.constexpr = "gpu"): + while tl.atomic_add(ptr, 0, sem=sem, scope=scope) != expect: + nanosleep(100) + + def shadow_cell_tensor_from_address(real_address: int, *, device_index: int | None = None) -> torch.Tensor: if device_index is None: device_index = torch.cuda.current_device() diff --git a/python/triton/experimental/gsan/src/GSan.h b/python/triton/experimental/gsan/src/GSan.h index d8a6b3104757..1c4d1fc0e39c 100644 --- a/python/triton/experimental/gsan/src/GSan.h +++ b/python/triton/experimental/gsan/src/GSan.h @@ -31,6 +31,9 @@ struct alignas(4) ScalarClock { epoch_t epoch; thread_id_t threadId : 12; // Supports 4096 threads AtomicScope scope : 2; + // For a release write, the epoch is actually an index into the thread's + // circular clock buffer where the full vector clock is stored. + bool isRelease : 1; }; static constexpr int kMaxThreads = 1 << 12; static_assert(sizeof(ScalarClock) == 4); @@ -84,10 +87,18 @@ struct ThreadState { epoch_t vectorClock[]; }; +static constexpr int kMaxAtomicShadowCells = 3; + +struct AtomicEventState { + ThreadState *threadState; + ShadowCell *cells[kMaxAtomicShadowCells]; + uint8_t numCells; +}; + // Place the thread state for each device at a fixed stride for ease of // address calculation. static constexpr uintptr_t kPerDeviceStateStride = 1ull << 30; -static constexpr uintptr_t kMaxGPUs = 16; +static constexpr uintptr_t kMaxGPUs = 32; static constexpr uintptr_t kGlobalsReserveSize = kPerDeviceStateStride * kMaxGPUs; @@ -118,4 +129,8 @@ inline GSAN_HOST_DEVICE bool isGsanManaged(uintptr_t addr, return getReserveBaseFromAddress(addr) == reserveBase; } +inline GSAN_HOST_DEVICE bool isAtomicScope(AtomicScope scope) { + return scope != AtomicScope::NonAtomic; +} + } // namespace gsan diff --git a/python/triton/experimental/gsan/src/GSanLibrary.cu b/python/triton/experimental/gsan/src/GSanLibrary.cu index 85304014dad1..f90f47c63783 100644 --- a/python/triton/experimental/gsan/src/GSanLibrary.cu +++ b/python/triton/experimental/gsan/src/GSanLibrary.cu @@ -1,6 +1,7 @@ #include "GSan.h" #include "Hash.cuh" +#include #include #include #include @@ -30,6 +31,13 @@ namespace gsan { namespace { static constexpr uint32_t writerFlag = 1u << 31; +enum class AtomicSem : uint8_t { + Relaxed = 1, + Acquire = 2, + Release = 3, + AcquireRelease = 4, +}; + __device__ void rwLockAcquireRead(uint32_t &lock) { uint32_t old = __scoped_atomic_fetch_add(&lock, 1, __ATOMIC_ACQUIRE, __MEMORY_SCOPE_WRKGRP); @@ -81,10 +89,27 @@ __device__ thread_id_t getDeviceThreadId(GlobalState *globals, uint32_t smid) { return static_cast(deviceIdx * globals->numSms + smid); } +__device__ uintptr_t getThreadStateBaseAddress(uintptr_t globalsAddr) { + uintptr_t stateBase = globalsAddr; + stateBase = roundUp(stateBase + sizeof(GlobalState), alignof(ThreadState)); + return stateBase; +} + +__device__ ThreadState *getThreadStateById(GlobalState *globals, + thread_id_t tid) { + uint32_t deviceIdx = tid / globals->numSms; + uint32_t smid = tid % globals->numSms; + uintptr_t stateBase = static_cast(globals->globalsBase) + + deviceIdx * kPerDeviceStateStride; + stateBase = getThreadStateBaseAddress(stateBase); + auto stateStride = getThreadStateStrideBytes(globals); + return reinterpret_cast(stateBase + stateStride * smid); +} + __device__ ThreadState *getThreadState(GlobalState *globals) { uint32_t smid = getSmId(); - uintptr_t stateBase = reinterpret_cast(globals); - stateBase = roundUp(stateBase + sizeof(GlobalState), alignof(ThreadState)); + uintptr_t stateBase = + getThreadStateBaseAddress(reinterpret_cast(globals)); auto stateStride = getThreadStateStrideBytes(globals); auto *state = reinterpret_cast(stateBase + stateStride * smid); @@ -103,6 +128,97 @@ __device__ ThreadState *getThreadState(GlobalState *globals) { return state; } +__device__ epoch_t *getClockBufferBase(ThreadState *state) { + auto *globals = getGlobalState(state); + return state->vectorClock + globals->numThreads; +} + +__device__ epoch_t *getClockBufferSlot(ThreadState *state, epoch_t token, + Location loc) { + assert_msg(loc, token != 0, "Invalid GSan clock token"); + assert_msg(loc, token <= state->clockBufferHead, "Future GSan clock token"); + auto *globals = getGlobalState(state); + assert_msg(loc, state->clockBufferHead - token < globals->clockBufferSize, + "GSan clock buffer token overwritten"); + uint32_t slot = token % globals->clockBufferSize; + return getClockBufferBase(state) + slot * globals->numThreads; +} + +__device__ epoch_t publishClockBuffer(ThreadState *state, Location loc) { + auto *globals = getGlobalState(state); + uint32_t nextHead = state->clockBufferHead + 1; + assert_msg(loc, nextHead <= std::numeric_limits::max(), + "GSan clock buffer token overflowed"); + epoch_t *slot = + getClockBufferBase(state) + + ((nextHead - 1) % globals->clockBufferSize) * globals->numThreads; + for (int i = 0; i < globals->numThreads; ++i) + slot[i] = state->vectorClock[i]; + state->clockBufferHead = nextHead; + state->clockBufferDirty = 0; + return static_cast(nextHead); +} + +__device__ AtomicSem decodeAtomicSem(uint32_t sem) { + switch (sem) { + case 1: + return AtomicSem::Relaxed; + case 2: + return AtomicSem::Acquire; + case 3: + return AtomicSem::Release; + case 4: + return AtomicSem::AcquireRelease; + default: + assert(false || !"Unexpected atomic semantic type"); + } +} + +__device__ AtomicScope decodeAtomicScope(uint32_t scope) { + switch (scope) { + case 1: + return AtomicScope::GPU; + case 2: + return AtomicScope::CTA; + case 3: + return AtomicScope::System; + default: + assert(false || !"Unexpected atomic scope"); + } +} + +__device__ bool hasAcquire(AtomicSem sem) { + return sem == AtomicSem::Acquire || sem == AtomicSem::AcquireRelease; +} + +__device__ bool hasRelease(AtomicSem sem) { + return sem == AtomicSem::Release || sem == AtomicSem::AcquireRelease; +} + +__device__ bool scopeCoversPair(AtomicScope scope, thread_id_t lhs, + thread_id_t rhs, GlobalState *globals) { + switch (scope) { + case AtomicScope::CTA: + return lhs == rhs; + case AtomicScope::GPU: + return lhs / globals->numSms == rhs / globals->numSms; + case AtomicScope::System: + return true; + case AtomicScope::NonAtomic: + return false; + } + return false; +} + +__device__ bool areAtomicScopesCompatible(AtomicScope lhs, thread_id_t lhsTid, + AtomicScope rhs, thread_id_t rhsTid, + GlobalState *globals) { + if (!isAtomicScope(lhs) || !isAtomicScope(rhs)) + return false; + return scopeCoversPair(lhs, lhsTid, rhsTid, globals) && + scopeCoversPair(rhs, lhsTid, rhsTid, globals); +} + __device__ void initThread(GlobalState *globals, Location loc) { auto *state = getThreadState(globals); @@ -116,6 +232,7 @@ __device__ void initThread(GlobalState *globals, Location loc) { assert_msg(loc, clock[tid] != std::numeric_limits::max(), "Vector clock overflowed"); clock[tid] += 1; + state->clockBufferDirty = 1; } } @@ -150,21 +267,161 @@ __device__ void releaseShadow(ShadowCell *cell) { __MEMORY_SCOPE_SYSTEM); } +__device__ epoch_t appendClockBufferSnapshot(ThreadState *state, + const epoch_t *snapshot, + Location loc) { + auto *globals = getGlobalState(state); + assert_msg(loc, globals->clockBufferSize != 0, + "GSan clock buffer size must be non-zero"); + uint32_t curHead = state->clockBufferHead; + uint32_t nextHead = curHead + 1; + assert_msg(loc, nextHead <= std::numeric_limits::max(), + "GSan clock buffer token overflowed"); + epoch_t *slot = getClockBufferBase(state) + + (nextHead % globals->clockBufferSize) * globals->numThreads; + for (int i = 0; i < globals->numThreads; ++i) + slot[i] = snapshot[i]; + state->clockBufferHead = nextHead; + return static_cast(nextHead); +} + +__device__ epoch_t publishCurrentVectorClock(ThreadState *state, Location loc) { + if (state->clockBufferDirty) { + auto token = appendClockBufferSnapshot(state, state->vectorClock, loc); + state->clockBufferDirty = 0; + return token; + } + return state->clockBufferHead; +} + +__device__ const epoch_t *getSnapshotForWrite(ThreadState *state, + const ScalarClock &write, + Location loc) { + if (!write.isRelease) + return nullptr; + auto *writerState = getThreadStateById(getGlobalState(state), write.threadId); + return getClockBufferSlot(writerState, write.epoch, loc); +} + +__device__ epoch_t propagateClockBufferSnapshot(ThreadState *state, + const ScalarClock &write, + Location loc) { + auto *snapshot = getSnapshotForWrite(state, write, loc); + assert_msg(loc, snapshot != nullptr, "Invalid GSan propagated clock token"); + auto token = appendClockBufferSnapshot(state, snapshot, loc); + state->clockBufferDirty = 1; + return token; +} + +__device__ void incrementThreadEpoch(ThreadState *state, Location loc) { + auto tid = state->threadId; + auto *clock = state->vectorClock; + assert_msg(loc, clock[tid] != std::numeric_limits::max(), + "Vector clock overflowed"); + clock[tid] += 1; + state->clockBufferDirty = 1; +} + +__device__ bool dominatesSnapshot(ThreadState *state, const epoch_t *snapshot) { + auto *globals = getGlobalState(state); + for (int i = 0; i < globals->numThreads; ++i) { + if (state->vectorClock[i] < snapshot[i]) + return false; + } + return true; +} + +__device__ bool clockHappensBefore(ThreadState *state, const ScalarClock &clock, + Location loc) { + if (clock.epoch == 0) + return true; + if (const epoch_t *snapshot = getSnapshotForWrite(state, clock, loc)) + return dominatesSnapshot(state, snapshot); + return state->vectorClock[clock.threadId] >= clock.epoch; +} + +__device__ void assertOrderedOrCompatible(ThreadState *state, + AtomicScope currentScope, + const ScalarClock &prior, + Location loc, const char *message) { + if (prior.epoch == 0) + return; + if (isAtomicScope(currentScope) && + areAtomicScopesCompatible(currentScope, state->threadId, prior.scope, + prior.threadId, getGlobalState(state))) { + return; + } + assert_msg(loc, clockHappensBefore(state, prior, loc), message); +} + +__device__ void maybeMergeAcquire(ThreadState *state, AtomicScope currentScope, + const ScalarClock &prior, Location loc) { + if (!prior.isRelease) + return; + if (!areAtomicScopesCompatible(currentScope, state->threadId, prior.scope, + prior.threadId, getGlobalState(state))) { + return; + } + auto *snapshot = getSnapshotForWrite(state, prior, loc); + bool changed = false; + auto *globals = getGlobalState(state); + for (int i = 0; i < globals->numThreads; ++i) { + if (state->vectorClock[i] < snapshot[i]) { + state->vectorClock[i] = snapshot[i]; + changed = true; + } + } + if (changed) + state->clockBufferDirty = 1; +} + +__device__ ScalarClock makeScalarClock(ThreadState *state, AtomicScope scope) { + auto tid = state->threadId; + return ScalarClock{state->vectorClock[tid], tid, scope, false}; +} + +__device__ ScalarClock makePublishedClock(ThreadState *state, AtomicScope scope, + epoch_t token) { + return ScalarClock{token, state->threadId, scope, true}; +} + +__device__ void recordRead(ThreadState *state, ShadowCell *cell, + AtomicScope scope) { + auto numReads = cell->numReads; + if (numReads < std::numeric_limitsnumReads)>::max()) + ++cell->numReads; + + auto scalarClock = makeScalarClock(state, scope); + for (int iRead = 0; iRead < ShadowCell::kReadClockSize; ++iRead) { + auto readClock = cell->readClocks[iRead]; + if (readClock.threadId == state->threadId || readClock.epoch == 0) { + cell->readClocks[iRead] = scalarClock; + return; + } + } + + auto threadNumReads = __scoped_atomic_fetch_add( + &state->numReads, 1, __ATOMIC_RELAXED, __MEMORY_SCOPE_WRKGRP); + auto seed = getGlobalState(state)->rngSeed; + uint32_t rand = hash2x32(threadNumReads, state->threadId, seed); + rand = rand % numReads; + if (rand < ShadowCell::kReadClockSize) { + cell->readClocks[rand] = scalarClock; + } +} + __device__ void doWrite(ThreadState *state, ShadowCell *cell, Location loc) { - epoch_t *clock = state->vectorClock; // Check WAR for (int iRead = 0; iRead < ShadowCell::kReadClockSize; ++iRead) { - auto read = cell->readClocks[iRead]; - assert_msg(loc, clock[read.threadId] >= read.epoch, - "Write after read race detected"); + assertOrderedOrCompatible(state, AtomicScope::NonAtomic, + cell->readClocks[iRead], loc, + "Write after read race detected"); } // Check WAW - auto write = cell->writeClock; - assert_msg(loc, clock[write.threadId] >= write.epoch, - "Write after write race detected"); + assertOrderedOrCompatible(state, AtomicScope::NonAtomic, cell->writeClock, + loc, "Write after write race detected"); // Update write - auto tid = state->threadId; - cell->writeClock = ScalarClock{clock[tid], tid, AtomicScope::NonAtomic}; + cell->writeClock = makeScalarClock(state, AtomicScope::NonAtomic); } __device__ void writeRange(ThreadState *state, uintptr_t write_addr, int nBytes, @@ -201,37 +458,9 @@ __device__ void tensorStore(ThreadState *state, const char *stackPtr, } __device__ void doRead(ThreadState *state, ShadowCell *cell, Location loc) { - // Update read count - auto numReads = cell->numReads; - if (numReads < std::numeric_limitsnumReads)>::max()) - ++cell->numReads; - - epoch_t *clock = state->vectorClock; - // Check RAW - auto write = cell->writeClock; - assert_msg(loc, clock[write.threadId] >= write.epoch, - "Read after write race detected"); - - auto tid = state->threadId; - auto scalarClock = ScalarClock{clock[tid], tid, AtomicScope::NonAtomic}; - // First, try to update in-place - for (int iRead = 0; iRead < ShadowCell::kReadClockSize; ++iRead) { - auto readClock = cell->readClocks[iRead]; - if (readClock.threadId == tid || readClock.epoch == 0) { - cell->readClocks[iRead] = scalarClock; - return; - } - } - - // Otherwise, do stochastic replacement - auto threadNumReads = __scoped_atomic_fetch_add( - &state->numReads, 1, __ATOMIC_RELAXED, __MEMORY_SCOPE_WRKGRP); - auto seed = getGlobalState(state)->rngSeed; - uint32_t rand = hash2x32(threadNumReads, state->threadId, seed); - if ((rand >> 8) % numReads != 0) - return; - auto clockIdx = rand % ShadowCell::kReadClockSize; - cell->readClocks[clockIdx] = scalarClock; + assertOrderedOrCompatible(state, AtomicScope::NonAtomic, cell->writeClock, + loc, "Read after write race detected"); + recordRead(state, cell, AtomicScope::NonAtomic); } __device__ void readRange(ThreadState *state, uintptr_t read_addr, int nBytes, @@ -267,6 +496,128 @@ __device__ void tensorLoad(ThreadState *state, const char *stackPtr, int nElems, } } +__device__ void initAtomicEventState(AtomicEventState *event) { + event->threadState = nullptr; + event->numCells = 0; + for (auto &cell : event->cells) + cell = nullptr; +} + +__device__ void acquireAtomicShadowRange(ThreadState *state, + AtomicEventState *event, + uintptr_t address, int nBytes, + Location loc) { + auto range = roundRange(Range{address, address + nBytes}); + auto reserveBase = state->reserveBase; + uint8_t numCells = 0; + for (uintptr_t addr = range.start; addr < range.end; + addr += kShadowMemGranularityBytes) { + if (isGsanManaged(addr, reserveBase)) + ++numCells; + } + assert_msg(loc, numCells <= kMaxAtomicShadowCells, + "Atomic access spans too many GSan shadow cells"); + if (numCells == 0) + return; + + // FIXME: Deadlock risk. If two concurrent accesses have different types, they + // may partially acquire the shadow cells and block other threads from making + // progress. + rwLockAcquireWrite(state->lock); + event->threadState = state; + event->numCells = 0; + for (uintptr_t addr = range.start; addr < range.end; + addr += kShadowMemGranularityBytes) { + if (!isGsanManaged(addr, reserveBase)) + continue; + event->cells[event->numCells++] = acquireShadow(getShadowAddress(addr)); + } +} + +__device__ void releaseAtomicShadowRange(AtomicEventState *event) { + if (event->threadState == nullptr) + return; + for (uint8_t i = 0; i < event->numCells; ++i) + releaseShadow(event->cells[i]); + rwLockReleaseWrite(event->threadState->lock); + initAtomicEventState(event); +} + +__device__ void beginAtomicAccess(GlobalState *globals, AtomicEventState *event, + bool pred, uintptr_t address, int nBytes, + uint32_t semRaw, uint32_t scopeRaw, + Location loc) { + initAtomicEventState(event); + if (!pred) + return; + + auto *state = getThreadState(globals); + acquireAtomicShadowRange(state, event, address, nBytes, loc); + if (event->threadState == nullptr) + return; + + auto sem = decodeAtomicSem(semRaw); + auto scope = decodeAtomicScope(scopeRaw); + for (uint8_t i = 0; i < event->numCells; ++i) { + auto *cell = event->cells[i]; + auto write = cell->writeClock; + assertOrderedOrCompatible(state, scope, write, loc, + "Read after write race detected"); + recordRead(state, cell, scope); + } + if (hasAcquire(sem)) { + for (uint8_t i = 0; i < event->numCells; ++i) { + auto write = event->cells[i]->writeClock; + maybeMergeAcquire(state, scope, write, loc); + } + } +} + +__device__ void endAtomicAccess(AtomicEventState *event, bool pred, + bool didWrite, uint32_t semRaw, + uint32_t scopeRaw, Location loc) { + if (!pred || event->threadState == nullptr) + return; + + auto *state = event->threadState; + auto sem = decodeAtomicSem(semRaw); + auto scope = decodeAtomicScope(scopeRaw); + + if (didWrite) { + for (uint8_t i = 0; i < event->numCells; ++i) { + auto *cell = event->cells[i]; + for (int iRead = 0; iRead < ShadowCell::kReadClockSize; ++iRead) { + assertOrderedOrCompatible(state, scope, cell->readClocks[iRead], loc, + "Write after read race detected"); + } + assertOrderedOrCompatible(state, scope, cell->writeClock, loc, + "Write after write race detected"); + } + + ScalarClock newWriteClock; + if (hasRelease(sem)) { + auto token = publishCurrentVectorClock(state, loc); + newWriteClock = makePublishedClock(state, scope, token); + } else { + auto previousWrite = event->cells[0]->writeClock; + if (previousWrite.isRelease) { + auto token = propagateClockBufferSnapshot(state, previousWrite, loc); + newWriteClock = makePublishedClock(state, scope, token); + } else { + newWriteClock = makeScalarClock(state, scope); + } + } + + for (uint8_t i = 0; i < event->numCells; ++i) + event->cells[i]->writeClock = newWriteClock; + + if (hasRelease(sem)) + incrementThreadEpoch(state, loc); + } + + releaseAtomicShadowRange(event); +} + } // namespace } // namespace gsan @@ -294,3 +645,23 @@ __triton_gsan_store_tensor(void *globalState, const char *stackPtr, gsan::getThreadState(reinterpret_cast(globalState)); gsan::tensorStore(threadState, stackPtr, numElems, bytesPerElem, loc); } + +extern "C" __device__ void +__triton_gsan_atomic_begin_scalar(void *globalState, void *eventState, int pred, + uintptr_t address, int bytesPerElem, int sem, + int scope, const char *file, unsigned line) { + auto loc = gsan::Location{file, line}; + gsan::beginAtomicAccess( + reinterpret_cast(globalState), + reinterpret_cast(eventState), pred != 0, + address, bytesPerElem, sem, scope, loc); +} + +extern "C" __device__ void +__triton_gsan_atomic_end_scalar(void *eventState, int pred, int didWrite, + int sem, int scope, const char *file, + unsigned line) { + auto loc = gsan::Location{file, line}; + gsan::endAtomicAccess(reinterpret_cast(eventState), + pred != 0, didWrite != 0, sem, scope, loc); +} diff --git a/python/triton/experimental/gsan/src/gsan_testing.cc b/python/triton/experimental/gsan/src/gsan_testing.cc index 7ec4b826a585..209c1bb96d30 100644 --- a/python/triton/experimental/gsan/src/gsan_testing.cc +++ b/python/triton/experimental/gsan/src/gsan_testing.cc @@ -18,6 +18,7 @@ struct PyScalarClock { uint32_t epoch = 0; uint32_t threadId = 0; gsan::AtomicScope scope = gsan::AtomicScope::NonAtomic; + bool isRelease = false; }; struct PyShadowCell { @@ -85,7 +86,8 @@ std::string scalarClockStr(const PyScalarClock &c) { std::ostringstream oss; oss << "ScalarClock(epoch=" << static_cast(c.epoch) << ", thread_id=" << static_cast(c.threadId) - << ", scope=" << atomicScopeStr(c.scope) << ")"; + << ", scope=" << atomicScopeStr(c.scope) + << ", is_release=" << (c.isRelease ? "True" : "False") << ")"; return oss.str(); } @@ -151,6 +153,7 @@ PyScalarClock toPyScalarClock(const gsan::ScalarClock &clock) { out.epoch = clock.epoch; out.threadId = static_cast(clock.threadId); out.scope = clock.scope; + out.isRelease = clock.isRelease; return out; } @@ -242,15 +245,17 @@ void init_gsan_testing(py::module &&m) { [](gsan::AtomicScope scope) { return atomicScopeStr(scope); }); py::class_(m, "ScalarClock") - .def(py::init( - [](uint64_t epoch, uint64_t threadId, gsan::AtomicScope scope) { - PyScalarClock out; - out.epoch = static_cast(epoch); - out.threadId = static_cast(threadId); - out.scope = scope; - return out; - }), - py::arg("epoch"), py::arg("thread_id"), py::arg("scope")) + .def(py::init([](uint64_t epoch, uint64_t threadId, + gsan::AtomicScope scope, bool isRelease) { + PyScalarClock out; + out.epoch = static_cast(epoch); + out.threadId = static_cast(threadId); + out.scope = scope; + out.isRelease = isRelease; + return out; + }), + py::arg("epoch"), py::arg("thread_id"), py::arg("scope"), + py::arg("is_release") = false) .def_property_readonly( "epoch", [](const PyScalarClock &c) { return static_cast(c.epoch); }) @@ -263,11 +268,13 @@ void init_gsan_testing(py::module &&m) { .def_property_readonly( "scope_name", [](const PyScalarClock &c) { return atomicScopeName(c.scope); }) + .def_property_readonly("is_release", + [](const PyScalarClock &c) { return c.isRelease; }) .def( "__eq__", [](const PyScalarClock &lhs, const PyScalarClock &rhs) { return lhs.epoch == rhs.epoch && lhs.threadId == rhs.threadId && - lhs.scope == rhs.scope; + lhs.scope == rhs.scope && lhs.isRelease == rhs.isRelease; }, py::is_operator()) .def("__str__", [](const PyScalarClock &c) { return scalarClockStr(c); }) diff --git a/python/triton/experimental/gsan/symmetric_memory.py b/python/triton/experimental/gsan/symmetric_memory.py index 3c913d2a29e1..aaf923ea03e8 100644 --- a/python/triton/experimental/gsan/symmetric_memory.py +++ b/python/triton/experimental/gsan/symmetric_memory.py @@ -202,8 +202,8 @@ def barrier(self, channel: int = 0, timeout_ms: int = 0) -> None: if channel != 0: raise NotImplementedError("Only channel=0 is supported in GSan symmetric memory.") _ = timeout_ms - dist.barrier(group=self._group) if self._world_size > 1: + dist.barrier(group=self._group) _stream_sync.synchronize_process_group_barrier(self._device_index, self._peer_device_indices) dist.barrier(group=self._group) diff --git a/test/Conversion/tritongpu_to_llvm_gsan.mlir b/test/Conversion/tritongpu_to_llvm_gsan.mlir index d09c147b318b..2022780ce86d 100644 --- a/test/Conversion/tritongpu_to_llvm_gsan.mlir +++ b/test/Conversion/tritongpu_to_llvm_gsan.mlir @@ -24,6 +24,14 @@ module attributes {"ttg.instrumentation_mode" = "gsan", "ttg.num-ctas" = 1 : i32 tt.store %ptrs, %vals : tensor<128x!tt.ptr, #blocked> tt.return } + + // CHECK-LABEL: llvm.func @unmasked_atomic_add + // CHECK: llvm.call @__triton_gsan_atomic_begin_scalar + // CHECK: llvm.call @__triton_gsan_atomic_end_scalar + tt.func @unmasked_atomic_add(%ptr: !tt.ptr, %val: i32) { + %0 = tt.atomic_rmw add, relaxed, gpu, %ptr, %val : (!tt.ptr, i32) -> i32 + tt.return + } } // ----- diff --git a/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/AtomicPTXBuilder.h b/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/AtomicPTXBuilder.h new file mode 100644 index 000000000000..c53e67f0aea4 --- /dev/null +++ b/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/AtomicPTXBuilder.h @@ -0,0 +1,157 @@ +#ifndef TRITONGPU_CONVERSION_TRITONNVIDIAGPUTOLLVM_ATOMICPTXBUILDER_H +#define TRITONGPU_CONVERSION_TRITONNVIDIAGPUTOLLVM_ATOMICPTXBUILDER_H + +#include "PTXAsmFormat.h" +#include "mlir/Support/LogicalResult.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include + +namespace mlir::triton::NVIDIA { + +inline std::string getPtxRegisterSizeCode(int size, bool isFloat) { + switch (size) { + case 1: + return "b"; + case 16: + return "h"; + case 32: + return isFloat ? "f" : "r"; + case 64: + return isFloat ? "d" : "l"; + case 128: + return "q"; + default: + llvm_unreachable("Unsupported register size"); + } +} + +inline FailureOr +emitPtxAtomicRMW(ConversionPatternRewriter &rewriter, Location loc, + Type valueElemTy, Value ptr, ArrayRef vals, + RMWOp rmwOpAttr, MemSemantic sem, MemSyncScope scope, + Value pred, unsigned vec = 1, unsigned packed = 1) { + assert((vec == 1 || packed == 1) && "packed or vec must be 1"); + assert(vals.size() == (vec > 1 ? vec : packed) && + "Expected atomic RMW operand count to match vectorization"); + + TritonLLVMOpBuilder b(loc, rewriter); + unsigned valueElemNBits = valueElemTy.getIntOrFloatBitWidth(); + Type packedTy = vec_ty(valueElemTy, packed); + + PTXBuilder ptxBuilderAtomicRMW; + std::string tyId = + getPtxRegisterSizeCode(valueElemNBits * packed, /*isFloat=*/false); + + PTXBuilder::Operand *dstOpr; + if (vec > 1) { + dstOpr = ptxBuilderAtomicRMW.newListOperand(); + for (unsigned ii = 0; ii < vec; ++ii) { + dstOpr->listAppend( + ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true)); + } + } else { + dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true); + } + + auto *ptrOpr = ptxBuilderAtomicRMW.newAddrOperand(ptr, "l"); + + PTXBuilder::Operand *valOpr; + if (vec > 1) { + valOpr = ptxBuilderAtomicRMW.newListOperand(); + for (Value val : vals) + valOpr->listAppend(ptxBuilderAtomicRMW.newOperand(val, tyId)); + } else if (packed > 1) { + Value packedVal = b.undef(packedTy); + for (auto [idx, val] : llvm::enumerate(vals)) + packedVal = b.insert_element(packedTy, packedVal, val, b.i32_val(idx)); + valOpr = ptxBuilderAtomicRMW.newOperand(packedVal, tyId); + } else { + valOpr = ptxBuilderAtomicRMW.newOperand(vals.front(), tyId); + } + + auto &atom = ptxBuilderAtomicRMW.create("atom")->global().o( + stringifyMemSyncScope(scope).str()); + std::string rmwOp = stringifyRMWOp(rmwOpAttr).str(); + std::string sTy; + auto sBits = std::to_string(valueElemNBits); + switch (rmwOpAttr) { + case RMWOp::AND: + case RMWOp::OR: + case RMWOp::XOR: + case RMWOp::XCHG: + sTy = "b" + sBits; + break; + case RMWOp::ADD: + sTy = "u" + sBits; + break; + case RMWOp::FADD: + rmwOp = "add"; + rmwOp += (valueElemNBits == 16 ? ".noftz" : ""); + sTy = (valueElemTy.isBF16() ? "bf" : "f") + sBits; + sTy += (packed == 2 && valueElemNBits == 16) ? "x2" : ""; + break; + case RMWOp::MAX: + case RMWOp::MIN: + sTy = "s" + sBits; + break; + case RMWOp::UMAX: + rmwOp = "max"; + sTy = "u" + sBits; + break; + case RMWOp::UMIN: + rmwOp = "min"; + sTy = "u" + sBits; + break; + default: + return failure(); + } + + std::string semStr; + llvm::raw_string_ostream os(semStr); + os << sem; + atom.o(semStr).o(rmwOp).v(vec).o(sTy); + atom(dstOpr, ptrOpr, valOpr).maybePredicate(pred); + + Type retType; + if (vec > 1) { + SmallVector retTys(vec, valueElemTy); + retType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), retTys); + } else if (packed > 1) { + retType = packedTy; + } else { + retType = valueElemTy; + } + return ptxBuilderAtomicRMW.launch(rewriter, loc, retType); +} + +inline Value emitPtxAtomicCAS(ConversionPatternRewriter &rewriter, Location loc, + Type valueElemTy, Value ptr, Value cmp, Value val, + MemSemantic sem, MemSyncScope scope, Value pred) { + unsigned valueElemNBits = valueElemTy.getIntOrFloatBitWidth(); + PTXBuilder ptxBuilderAtomicCAS; + std::string tyId = getPtxRegisterSizeCode(valueElemNBits, /*isFloat=*/false); + auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=" + tyId, /*init=*/true); + auto *ptrOpr = ptxBuilderAtomicCAS.newAddrOperand(ptr, "l"); + auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(cmp, tyId); + auto *valOpr = ptxBuilderAtomicCAS.newOperand(val, tyId); + auto &atom = *ptxBuilderAtomicCAS.create("atom"); + auto sTy = "b" + std::to_string(valueElemNBits); + std::string semStr; + llvm::raw_string_ostream os(semStr); + os << sem; + atom.global().o(semStr).o(stringifyMemSyncScope(scope).str()).o("cas").o(sTy); + atom(dstOpr, ptrOpr, cmpOpr, valOpr).maybePredicate(pred); + return ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy); +} + +} // namespace mlir::triton::NVIDIA + +#endif // TRITONGPU_CONVERSION_TRITONNVIDIAGPUTOLLVM_ATOMICPTXBUILDER_H diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index 5a9c020ad3fd..e6ed46e3542a 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -6,6 +6,7 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/TypeUtilities.h" +#include "TritonNVIDIAGPUToLLVM/AtomicPTXBuilder.h" #include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" #include "PatternTritonGPUOpToLLVM.h" @@ -45,23 +46,6 @@ unsigned getCanonicalIndex(unsigned index, unsigned freeVarMask) { return index & ~freeVarMask; } -std::string getRegisterSizeCode(int size, bool is_float) { - switch (size) { - case 1: - return "b"; - case 16: - return "h"; - case 32: - return is_float ? "f" : "r"; - case 64: - return is_float ? "d" : "l"; - case 128: - return "q"; - default: - llvm_unreachable("Unsupported register size"); - } -} - Value createCachePolicy(triton::EvictionPolicy opEvict, ConversionPatternRewriter &rewriter, Location loc, int computeCapability) { @@ -565,6 +549,9 @@ struct AtomicCASOpConversion tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType()) : valueTy; auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth(); + std::string tyId = + NVIDIA::getPtxRegisterSizeCode(valueElemNBits, /*isFloat=*/false); + std::string sTy = "b" + std::to_string(valueElemNBits); auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType()); auto freeVarMasks = getFreeVariableMasks(op.getPtr().getType()); Value threadPred = ttg::emitRedundantThreadPredicate(freeVarMasks, rewriter, @@ -584,28 +571,13 @@ struct AtomicCASOpConversion Value casVal = valElements[i]; Value casCmp = cmpElements[i]; Value casPtr = ptrElements[i]; - PTXBuilder ptxBuilderAtomicCAS; - std::string tyId = - valueElemNBits == 64 ? "l" : (valueElemNBits == 32 ? "r" : "h"); - auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=" + tyId, /*init=*/true); - auto *ptrOpr = ptxBuilderAtomicCAS.newAddrOperand(casPtr, "l"); - auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, tyId); - auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, tyId); - auto &atom = *ptxBuilderAtomicCAS.create("atom"); - auto sTy = "b" + std::to_string(valueElemNBits); - std::string semStr; - llvm::raw_string_ostream os(semStr); - os << op.getSem(); - auto scope = stringifyMemSyncScope(op.getScope()).str(); - atom.global().o(semStr).o(scope).o("cas").o(sTy); - atom(dstOpr, ptrOpr, cmpOpr, valOpr).maybePredicate(threadPred); + Value old = NVIDIA::emitPtxAtomicCAS(rewriter, loc, valueElemTy, casPtr, + casCmp, casVal, op.getSem(), + op.getScope(), threadPred); if (tensorTy) { - auto retType = valueElemTy; - auto ret = ptxBuilderAtomicCAS.launch(rewriter, loc, retType); - resultVals[i] = ret; + resultVals[i] = old; } else { - auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy); if (op.getResult().use_empty()) { rewriter.eraseOp(op); return success(); @@ -919,104 +891,17 @@ struct AtomicRMWOpConversion continue; } - std::string sTy; - PTXBuilder ptxBuilderAtomicRMW; - // 16-bit -> "h", 32-bit -> "r", 64-bit -> "l" - std::string tyId = - getRegisterSizeCode(valueElemNBits * packed, /*is_float=*/false); - - PTXBuilder::Operand *dstOpr; - if (vec > 1) { - dstOpr = ptxBuilderAtomicRMW.newListOperand(); - for (unsigned ii = 0; ii < vec; ++ii) { - dstOpr->listAppend( - ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true)); - } - } else { - dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true); - } - - auto *ptrOpr = ptxBuilderAtomicRMW.newAddrOperand(rmwPtr, "l"); - - PTXBuilder::Operand *valOpr; - if (vec > 1) { - valOpr = ptxBuilderAtomicRMW.newListOperand(); - for (unsigned ii = 0; ii < vec; ++ii) { - valOpr->listAppend( - ptxBuilderAtomicRMW.newOperand(valElements[i + ii], tyId)); - } - } else if (packed > 1) { - Value rmwVal = b.undef(packedTy); - for (int ii = 0; ii < packed; ++ii) { - rmwVal = b.insert_element(packedTy, rmwVal, valElements[i + ii], - b.i32_val(ii)); - } - valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId); - } else { - valOpr = ptxBuilderAtomicRMW.newOperand(valElements[i], tyId); - } - - auto scope = stringifyMemSyncScope(op.getScope()).str(); - auto &atom = ptxBuilderAtomicRMW.create("atom")->global().o(scope); - auto rmwOp = stringifyRMWOp(atomicRmwAttr).str(); - auto sBits = std::to_string(valueElemNBits); - switch (atomicRmwAttr) { - case RMWOp::AND: - sTy = "b" + sBits; - break; - case RMWOp::OR: - sTy = "b" + sBits; - break; - case RMWOp::XOR: - sTy = "b" + sBits; - break; - case RMWOp::ADD: - sTy = "u" + sBits; - break; - case RMWOp::FADD: - rmwOp = "add"; - rmwOp += (valueElemNBits == 16 ? ".noftz" : ""); - sTy = (valueElemTy.isBF16() ? "bf" : "f") + sBits; - sTy += (packed == 2 && valueElemNBits == 16) ? "x2" : ""; - break; - case RMWOp::MAX: - sTy = "s" + sBits; - break; - case RMWOp::MIN: - sTy = "s" + sBits; - break; - case RMWOp::UMAX: - rmwOp = "max"; - sTy = "u" + sBits; - break; - case RMWOp::UMIN: - rmwOp = "min"; - sTy = "u" + sBits; - break; - case RMWOp::XCHG: - sTy = "b" + sBits; - break; - default: + SmallVector rmwVals; + rmwVals.reserve(vec > 1 ? vec : packed); + for (unsigned ii = 0; ii < (vec > 1 ? vec : packed); ++ii) + rmwVals.push_back(valElements[i + ii]); + auto old = NVIDIA::emitPtxAtomicRMW(rewriter, loc, valueElemTy, rmwPtr, + rmwVals, atomicRmwAttr, op.getSem(), + op.getScope(), pred, vec, packed); + if (failed(old)) return failure(); - } - std::string semStr; - llvm::raw_string_ostream os(semStr); - os << op.getSem(); - atom.o(semStr).o(rmwOp).v(vec).o(sTy); if (tensorTy) { - atom(dstOpr, ptrOpr, valOpr).maybePredicate(pred); - Type retType; - if (vec > 1) { - SmallVector retTys(vec, valueElemTy); - retType = struct_ty(retTys); - } else if (packed > 1) { - retType = packedTy; - } else { - retType = valueElemTy; - } - - Value ret = ptxBuilderAtomicRMW.launch(rewriter, loc, retType); - + Value ret = *old; if (vec > 1) { for (unsigned ii = 0; ii < vec; ++ii) { resultVals[i + ii] = b.extract_val(valueElemTy, ret, ii); @@ -1030,9 +915,6 @@ struct AtomicRMWOpConversion resultVals[i] = ret; } } else { - auto ASMReturnTy = void_ty(ctx); - atom(dstOpr, ptrOpr, valOpr).maybePredicate(pred); - auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy); if (op.getResult().use_empty()) { rewriter.eraseOp(op); return success(); @@ -1041,7 +923,7 @@ struct AtomicRMWOpConversion op.getOperation()); atomPtr = b.bitcast(atomPtr, ptr_ty(ctx, 3)); // Only threads with rmwMask = True store the result - targetInfo.storeShared(rewriter, loc, atomPtr, old, pred); + targetInfo.storeShared(rewriter, loc, atomPtr, *old, pred); createBarrier(rewriter, loc, numCTAs); Value ret = b.load(valueElemTy, atomPtr); rewriter.replaceOp(op, {ret});