From f08a2b31776d0447a73997d5f49cb505a1fe2a21 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Thu, 21 May 2026 19:02:11 -0700 Subject: [PATCH 1/6] Add Gluon async shared store support --- .../TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td | 30 +++ lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp | 27 +++ .../Transforms/ConSanNVIDIA.cpp | 11 ++ python/src/gluon_ir.cc | 5 + python/test/gluon/test_core.py | 29 +++ python/test/gluon/test_frontend.py | 16 ++ .../experimental/gluon/language/_semantic.py | 10 + .../language/nvidia/blackwell/__init__.py | 3 +- .../gluon/language/nvidia/hopper/__init__.py | 14 ++ test/Conversion/tritonnvidiagpu_to_llvm.mlir | 19 +- test/TritonNvidiaGPU/invalid.mlir | 27 +++ .../TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp | 178 ++++++++++++++++++ 12 files changed, 367 insertions(+), 2 deletions(-) diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td index 006361d9fb47..552aa3604817 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -409,6 +409,36 @@ def TTNG_AsyncCopyMbarrierArriveOp : TTNG_Op<"async_copy_mbarrier_arrive", [ } +def TTNG_AsyncSharedStoreOp : TTNG_Op<"async_shared_store", [ + DeclareOpInterfaceMethods]> { + let summary = "store a distributed tensor to shared memory asynchronously"; + + let description = [{ + Store a distributed tensor into shared memory using PTX st.async.shared. + The store completion decrements the transaction count of `mbarrier`. + }]; + + let arguments = (ins + TT_Tensor:$src, + Arg]>:$dst, + Arg]>:$mbarrier + ); + + let assemblyFormat = [{ + $src `,` $dst `,` $mbarrier attr-dict `:` type($src) `->` + qualified(type($dst)) `,` qualified(type($mbarrier)) + }]; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + static bool isSupported(int computeCapability) { + return ::mlir::triton::nvidia_gpu::TargetFeatures(computeCapability) + .supportClusterOps(); + } + }]; +} + + def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local", [ AttrSizedOperandSegments, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, TMALoadLikeOpInterface]> { diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index 7d80e0427ed1..d7c791ae01b1 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -257,6 +257,33 @@ Type ArriveBarrierOp::getPredicateOperandTypeLike() { return IntegerType::get(getContext(), 1); } +static LogicalResult verifyCompletionBarrierLayout(Operation *op, + Value barrier); + +// -- AsyncSharedStoreOp -- +LogicalResult AsyncSharedStoreOp::verify() { + if (gpu::lookupNumCTAs(getOperation()) < 2) + return emitOpError("requires at least two CTAs in the cluster"); + if (!getDst().getType().getMutableMemory()) + return emitOpError("cannot store into immutable memory"); + if (failed(triton::gpu::verifyMemoryOpTypes(*this, getSrc().getType(), + getDst().getType()))) + return failure(); + if (failed(verifyBarrierType(*this, getMbarrier().getType()))) + return failure(); + if (failed(verifyCompletionBarrierLayout(getOperation(), getMbarrier()))) + return failure(); + + unsigned bitwidth = getSrc().getType().getElementTypeBitWidth(); + if (bitwidth != 32 && bitwidth != 64) + return emitOpError("requires 32-bit or 64-bit element types"); + return success(); +} + +TypedValue AsyncSharedStoreOp::getBarrier() { + return getMbarrier(); +} + // -- FenceMBarrierInitReleaseClusterOp -- LogicalResult FenceMBarrierInitReleaseClusterOp::verify() { int numCTAs = triton::gpu::lookupNumCTAs(getOperation()); diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp index e8c6080fd710..af4052a5a5b2 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp @@ -230,6 +230,17 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks { info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, loadOp.getResult()); } + if (auto storeOp = dyn_cast(op)) { + info.emplace(); + info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; + info->barriers.push_back( + {storeOp.getBarrier(), nullptr, /*count=*/0, + MemEffectsOpInfo::BarrierTrackingMode::EffectWrites, + /*txCount=*/-static_cast(tti::getMemDescLength( + storeOp.getDst()))}); + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, + storeOp.getDst()); + } if (auto tryCancelOp = dyn_cast(op)) { info.emplace(); info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; diff --git a/python/src/gluon_ir.cc b/python/src/gluon_ir.cc index 4d98119243a2..0c54444039f1 100644 --- a/python/src/gluon_ir.cc +++ b/python/src/gluon_ir.cc @@ -691,6 +691,11 @@ void init_gluon_ir(py::module &&m) { [](GluonOpBuilder &self, Value memDesc, Value value) { self.create(value, memDesc); }) + .def("create_async_shared_store", + [](GluonOpBuilder &self, Value memDesc, Value value, + Value mbarrier) { + self.create(value, memDesc, mbarrier); + }) .def("create_local_load", [](GluonOpBuilder &self, Type resultTy, Value memDesc) -> Value { return self.create(resultTy, memDesc); diff --git a/python/test/gluon/test_core.py b/python/test/gluon/test_core.py index 2cec543e4ceb..3be2782c5846 100644 --- a/python/test/gluon/test_core.py +++ b/python/test/gluon/test_core.py @@ -143,6 +143,35 @@ def test_local_store_transposed_cga_to_non_transposed_alloc(): torch.testing.assert_close(out, inp.T, atol=0, rtol=0) +@gluon.jit +def async_shared_store_kernel(out, BLOCK: ttgl.constexpr): + layout: ttgl.constexpr = ttgl.BlockedLayout([1], [32], [4], [0], cga_layout=[[0]]) + shared_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[0], cga_layout=[[0]]) + + offsets = ttgl.arange(0, BLOCK, layout=layout) + values = offsets.to(ttgl.int32) + smem = ttgl.allocate_shared_memory(ttgl.int32, [BLOCK], shared_layout) + bar = mbarrier.allocate_mbarrier() + mbarrier.init(bar, count=1) + mbarrier.expect(bar, smem.nbytes_per_cta) + hopper.async_store(smem, values, bar) + mbarrier.wait(bar, phase=0, deps=[smem]) + result = smem.load(layout) + mbarrier.invalidate(bar) + ttgl.store(out + offsets, result) + + +@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper") +def test_async_shared_store(): + block = 128 + out = torch.empty((block,), device="cuda", dtype=torch.int32) + + compiled = async_shared_store_kernel[(1, )](out, block, num_warps=4, num_ctas=2) + + assert "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes" in compiled.asm["ptx"] + torch.testing.assert_close(out, torch.arange(block, device="cuda", dtype=torch.int32)) + + @gluon.jit def tma_kernel(desc): layout: ttgl.constexpr = ttgl.BlockedLayout([1, 2], [4, 8], [4, 1], [1, 0]) diff --git a/python/test/gluon/test_frontend.py b/python/test/gluon/test_frontend.py index 1144802fabd6..c8575cab5aa1 100644 --- a/python/test/gluon/test_frontend.py +++ b/python/test/gluon/test_frontend.py @@ -635,6 +635,22 @@ def test_mbarrier(target): """) +@gluon.jit +def async_shared_store_kernel(): + layout: ttgl.constexpr = ttgl.BlockedLayout([1], [32], [4], [0], cga_layout=[[0]]) + shared_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[0], cga_layout=[[0]]) + values = ttgl.arange(0, 128, layout=layout).to(ttgl.int32) + dst = ttgl.allocate_shared_memory(ttgl.int32, [128], shared_layout) + bar = mbarrier.allocate_mbarrier() + hopper.async_store(dst, values, bar) + + +@pytest.mark.parametrize("target", [HOPPER_TARGET, BLACKWELL_TARGET]) +def test_async_shared_store(target): + mod = run_parser(async_shared_store_kernel, *make_args(num_ctas=2), target=target) + assert "ttng.async_shared_store" in anonymize_ir(mod.str_nodebug()) + + @gluon.jit def tcgen05_mma_kernel(nvmma_layout: ttgl.constexpr, acc_layout: ttgl.constexpr): a = ttgl.allocate_shared_memory(ttgl.float16, [128, 128], nvmma_layout) diff --git a/python/triton/experimental/gluon/language/_semantic.py b/python/triton/experimental/gluon/language/_semantic.py index 902da6081771..3f7f87d6c652 100644 --- a/python/triton/experimental/gluon/language/_semantic.py +++ b/python/triton/experimental/gluon/language/_semantic.py @@ -276,6 +276,16 @@ def shared_store(self, mem_desc, value): lambda: f"source dtype {value.dtype} and destination dtype {mem_desc.dtype} must match") self.builder.create_local_store(mem_desc.handle, value.handle) + def async_shared_store(self, mem_desc, value, mbarrier): + _check(isinstance(value, ttgl.tensor), lambda: f"expected 'value' to be a tensor, but got a {type(value)}") + _check(isinstance(mbarrier, ttgl.shared_memory_descriptor), + lambda: f"expected 'mbarrier' to be a shared_memory_descriptor, but got a {type(mbarrier)}") + _check(value.shape == mem_desc.shape, + lambda: f"source shape {value.shape} and destination shape {mem_desc.shape} must match") + _check(value.dtype == mem_desc.dtype, + lambda: f"source dtype {value.dtype} and destination dtype {mem_desc.dtype} must match") + self.builder.create_async_shared_store(mem_desc.handle, value.handle, mbarrier.handle) + def _check_int_indices_and_normalize_axis(self, mem_desc, indices, axis): _check(isinstance(indices, ttgl.tensor), lambda: f"expected 'indices' to be a tensor, but got a {type(indices)}") diff --git a/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py b/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py index 3da1220133ee..da9d4fd7552a 100644 --- a/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py +++ b/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py @@ -9,7 +9,7 @@ from . import tma from . import clc -from ..hopper import fence_async_shared, mbarrier +from ..hopper import async_store, fence_async_shared, mbarrier from ..ampere import async_copy, mma_v2 from triton._C.libtriton import ir @@ -21,6 +21,7 @@ __all__ = [ "allocate_tensor_memory", "async_copy", + "async_store", "clc", "fence_async_shared", "mbarrier", diff --git a/python/triton/experimental/gluon/language/nvidia/hopper/__init__.py b/python/triton/experimental/gluon/language/nvidia/hopper/__init__.py index b64216792530..c6171f7dab40 100644 --- a/python/triton/experimental/gluon/language/nvidia/hopper/__init__.py +++ b/python/triton/experimental/gluon/language/nvidia/hopper/__init__.py @@ -10,6 +10,7 @@ __all__ = [ "async_copy", + "async_store", "cluster", "fence_async_shared", "mbarrier", @@ -32,6 +33,19 @@ def fence_async_shared(cluster=False, _semantic=None): _semantic.builder.create_fence_async_shared(cluster) +@_core.builtin +def async_store(dst, value, mbarrier, _semantic=None): + """ + Store a tensor to shared memory asynchronously and signal an mbarrier on completion. + + Args: + dst (shared_memory_descriptor): Destination shared memory descriptor. + value (tensor): Tensor whose contents to store. + mbarrier (shared_memory_descriptor): Barrier signaled when the store completes. + """ + _semantic.async_shared_store(dst, value, mbarrier) + + class warpgroup_mma_accumulator_type(_core.base_type): tensor_type: _core.dtype diff --git a/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/test/Conversion/tritonnvidiagpu_to_llvm.mlir index ccf2064eba81..f95267af47cd 100644 --- a/test/Conversion/tritonnvidiagpu_to_llvm.mlir +++ b/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=compute-capability=90 -reconcile-unrealized-casts | FileCheck %s +// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=81' -reconcile-unrealized-casts | FileCheck %s #shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> #smem = #ttg.shared_memory @@ -147,6 +147,23 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // ----- +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CGALayout = [[0]]}> +#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[1]]}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: async_shared_store + // CHECK: nvvm.mapa + // CHECK: nvvm.mapa + // CHECK: st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.b32 + tt.func @async_shared_store(%src: tensor<128xi32, #blocked>, %dst: !ttg.memdesc<128xi32, #shared1, #smem, mutable>, %mbarrier: !ttg.memdesc<2xi64, #shared0, #smem, mutable>) { + ttng.async_shared_store %src, %dst, %mbarrier : tensor<128xi32, #blocked> -> !ttg.memdesc<128xi32, #shared1, #smem, mutable>, !ttg.memdesc<2xi64, #shared0, #smem, mutable> + tt.return + } +} + +// ----- + // TMA copy with barrier mask zero: barrier has no CGALayout -> shared::cta #shared0_cta = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8}> diff --git a/test/TritonNvidiaGPU/invalid.mlir b/test/TritonNvidiaGPU/invalid.mlir index 67813f71b49d..de02bf697319 100644 --- a/test/TritonNvidiaGPU/invalid.mlir +++ b/test/TritonNvidiaGPU/invalid.mlir @@ -38,6 +38,33 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // ----- +#blocked_i32 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#shared_i32 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @async_shared_store_requires_cluster(%src: tensor<128xi32, #blocked_i32>, %dst: !ttg.memdesc<128xi32, #shared_i32, #smem, mutable>, %bar: !ttg.memdesc<1xi64, #shared_i32, #smem, mutable>) { + // expected-error @+1 {{requires at least two CTAs in the cluster}} + ttng.async_shared_store %src, %dst, %bar : tensor<128xi32, #blocked_i32> -> !ttg.memdesc<128xi32, #shared_i32, #smem, mutable>, !ttg.memdesc<1xi64, #shared_i32, #smem, mutable> + tt.return + } +} + +// ----- + +#blocked_f16 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CGALayout = [[0]]}> +#shared_f16 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> +#barrier_shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[1]]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @async_shared_store_requires_wide_elements(%src: tensor<128xf16, #blocked_f16>, %dst: !ttg.memdesc<128xf16, #shared_f16, #smem, mutable>, %bar: !ttg.memdesc<2xi64, #barrier_shared, #smem, mutable>) { + // expected-error @+1 {{requires 32-bit or 64-bit element types}} + ttng.async_shared_store %src, %dst, %bar : tensor<128xf16, #blocked_f16> -> !ttg.memdesc<128xf16, #shared_f16, #smem, mutable>, !ttg.memdesc<2xi64, #barrier_shared, #smem, mutable> + tt.return + } +} + +// ----- + #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> #scales = #ttg.linear<{register = [[0, 1], [0, 2], [32, 0], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}> #tmem = #ttng.tensor_memory_scales_encoding<> diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp index 5ab129d32b37..298a2aa5a5f3 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp @@ -300,6 +300,182 @@ struct LocalStoreOpConversion const NVIDIA::TargetInfo &targetInfo; }; +static std::string getAsyncSharedStoreConstraint(unsigned bitwidth) { + switch (bitwidth) { + case 32: + return "r"; + case 64: + return "l"; + default: + llvm_unreachable("unsupported st.async.shared bitwidth"); + } +} + +static Value normalizeAsyncSharedStoreValue(Value value, Type valueTy, + unsigned bitwidth, + TritonLLVMOpBuilder &b) { + Type intTy = IntegerType::get(valueTy.getContext(), bitwidth); + if (isa(valueTy)) + return b.ptrtoint(intTy, value); + if (!valueTy.isInteger()) + return b.bitcast(value, intTy); + return value; +} + +static Value mapSharedToCluster(Location loc, Value ptr, Value ctaId, + ConversionPatternRewriter &rewriter) { + auto ptrTy = cast(ptr.getType()); + assert(ptrTy.getAddressSpace() == 3 && + "st.async.shared expects a shared-memory pointer"); + return NVVM::MapaOp::create(rewriter, loc, + ptr_ty(rewriter.getContext(), 7), ptr, ctaId); +} + +static void emitAsyncSharedStore(Location loc, ArrayRef vals, Value dst, + Value mbarrier, VectorType vecTy, + ConversionPatternRewriter &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + Type elemTy = vecTy.getElementType(); + unsigned vec = vecTy.getNumElements(); + unsigned elemBitwidth = getIntOrFloatOrPtrBitWidth(elemTy); + assert((elemBitwidth == 32 || elemBitwidth == 64) && + "st.async.shared only supports 32-bit and 64-bit elements"); + + if (vec * elemBitwidth > 128) { + assert(vec % (128 / elemBitwidth) == 0); + int maxVec = 128 / elemBitwidth; + auto splitVecTy = + VectorType::get({static_cast(maxVec)}, elemTy); + for (int i = 0; i < vec / maxVec; i++) { + auto newDst = b.gep(dst.getType(), elemTy, dst, b.i32_val(i * maxVec), + LLVM::GEPNoWrapFlags::inbounds); + emitAsyncSharedStore(loc, vals.slice(i * maxVec, maxVec), newDst, + mbarrier, splitVecTy, rewriter); + } + return; + } + assert(1 <= vec && vec <= 4); + + PTXBuilder builder; + auto st = builder.create("st.async") + ->o("weak") + .o("shared::cluster") + .o("mbarrier::complete_tx::bytes") + .v(vec, /*predicate=*/vec > 1) + .b(elemBitwidth); + auto *dstOpr = builder.newAddrOperand(dst, "r"); + auto *mbarrierOpr = builder.newAddrOperand(mbarrier, "r"); + auto constraint = getAsyncSharedStoreConstraint(elemBitwidth); + PTXBuilder::Operand *valueOpr; + if (vec == 1) { + valueOpr = builder.newOperand( + normalizeAsyncSharedStoreValue(vals.front(), elemTy, elemBitwidth, b), + constraint); + } else { + SmallVector> vecVals; + vecVals.reserve(vec); + for (Value value : vals) { + vecVals.push_back({normalizeAsyncSharedStoreValue( + value, elemTy, elemBitwidth, b), + constraint}); + } + valueOpr = builder.newListOperand(vecVals); + } + st(dstOpr, valueOpr, mbarrierOpr); + builder.launch(rewriter, loc, void_ty(rewriter.getContext())); +} + +static void lowerAsyncSharedStore( + Location loc, MLIRContext *ctx, LinearLayout cvt, ArrayRef vals, + Type llvmElemTy, MemDescType dstTy, SharedMemoryObject dstMemObj, + Value mbarrierPtr, ConversionPatternRewriter &rewriter, + const NVIDIA::TargetInfo &targetInfo) { + auto removeBroadcastSrc = actionRemoveBroadcastedRegs(cvt); + if (!removeBroadcastSrc.isIdentity()) { + auto prmtCvt = removeBroadcastSrc.apply(cvt); + auto inVals = removeBroadcastSrc.apply(to_vector(vals)); + lowerAsyncSharedStore(loc, ctx, prmtCvt, inVals, llvmElemTy, dstTy, + dstMemObj, mbarrierPtr, rewriter, targetInfo); + return; + } + + auto affineOffset = dstMemObj.getShmemOffset(loc, rewriter, dstTy); + auto maskSpanAffineOffset = dstMemObj.getMaskSpanOffsets(dstTy); + std::optional maybeMaxVecElems; + SmallVector> paddingShifts; + if (triton::gpu::isPaddedEncoding(dstTy.getEncoding())) { + maybeMaxVecElems = triton::gpu::getMinInterval(dstTy.getEncoding()); + auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy); + paddingShifts = getPaddedSharedShifts(dstTy.getEncoding(), bitwidth, + /*offsetInBytes=*/true); + } + + SmallVector smemBases(dstMemObj.getBases().begin(), + dstMemObj.getBases().end()); + Value currentCTAId = targetInfo.getClusterCTAId(rewriter, loc); + auto emitSt = [&](RewriterBase &, Location storeLoc, ArrayRef values, + Value shmemAddr, int idx, VectorType vecTy, + std::optional ctaId) -> SmallVector { + Value targetCTAId = ctaId.value_or(currentCTAId); + Value dst = mapSharedToCluster(storeLoc, shmemAddr, targetCTAId, rewriter); + Value mbarrier = + mapSharedToCluster(storeLoc, mbarrierPtr, targetCTAId, rewriter); + emitAsyncSharedStore(storeLoc, + values.slice(idx, vecTy.getNumElements()), dst, + mbarrier, vecTy, rewriter); + return {}; + }; + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + lowerLdSt(loc, ctx, cvt, vals, llvmElemTy, smemBases, paddingShifts, + affineOffset, maskSpanAffineOffset, laneId, warpId, rewriter, + targetInfo, maybeMaxVecElems, emitSt); +} + +struct AsyncSharedStoreOpConversion + : public ConvertOpToLLVMPattern { + AsyncSharedStoreOpConversion(const LLVMTypeConverter &converter, + const NVIDIA::TargetInfo &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern( + converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::AsyncSharedStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!triton::nvidia_gpu::AsyncSharedStoreOp::isSupported( + targetInfo.getComputeCapability()) || + targetInfo.getPtxVersion() < 81) + return op.emitError("requires cluster-capable SM90+ and PTX 8.1+"); + + auto loc = op.getLoc(); + MemDescType dstTy = op.getDst().getType(); + RankedTensorType srcTy = op.getSrc().getType(); + Type llvmElemTy = typeConverter->convertType(srcTy.getElementType()); + auto dstMemObj = LLVM::getSharedMemoryObjectFromStruct( + loc, adaptor.getDst(), llvmElemTy, rewriter); + auto mbarrierTy = op.getMbarrier().getType(); + auto mbarrierMemObj = LLVM::getSharedMemoryObjectFromStruct( + loc, adaptor.getMbarrier(), + typeConverter->convertType(mbarrierTy.getElementType()), rewriter); + + auto regLayout = toLinearLayout(srcTy); + auto sharedLayout = isPaddedEncoding(dstTy.getEncoding()) + ? paddedLinearLayout(dstTy) + : toLinearLayout(dstTy); + auto cvt = regLayout.invertAndCompose(sharedLayout); + auto values = unpackLLElements(loc, adaptor.getSrc(), rewriter); + lowerAsyncSharedStore(loc, op.getContext(), cvt, values, llvmElemTy, + dstTy, dstMemObj, mbarrierMemObj.getBase(), rewriter, + targetInfo); + rewriter.eraseOp(op); + return success(); + } + +private: + const NVIDIA::TargetInfo &targetInfo; +}; + struct LocalAtomicScatterRMWOpConversion : public ConvertOpToLLVMPattern { public: @@ -378,6 +554,8 @@ void mlir::triton::NVIDIA::populateMemoryOpToLLVMPatterns( benefit.getBenefit() + 1); patterns.add(typeConverter, targetInfo, benefit.getBenefit() + 1); + patterns.add(typeConverter, targetInfo, + benefit.getBenefit() + 1); patterns.add(typeConverter, targetInfo, benefit.getBenefit() + 1); patterns.add(typeConverter, targetInfo, From 9ed9c470c67786afdc26ad838584a7fb3f0ba5ed Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Thu, 21 May 2026 19:13:18 -0700 Subject: [PATCH 2/6] Support packed async shared stores --- lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp | 26 ++++++- python/test/gluon/test_core.py | 29 +++++++ test/Conversion/tritonnvidiagpu_to_llvm.mlir | 16 ++++ test/TritonNvidiaGPU/invalid.mlir | 4 +- .../TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp | 77 ++++++++++++------- 5 files changed, 121 insertions(+), 31 deletions(-) diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index d7c791ae01b1..035eaf9e1864 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -27,6 +27,7 @@ #include "mlir/Support/LLVM.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" @@ -39,6 +40,7 @@ #include "triton/Tools/StrUtil.h" #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" using namespace mlir::triton::gpu; @@ -262,6 +264,7 @@ static LogicalResult verifyCompletionBarrierLayout(Operation *op, // -- AsyncSharedStoreOp -- LogicalResult AsyncSharedStoreOp::verify() { + // PTX defines weak shared::cluster st.async as UB for a one-CTA cluster. if (gpu::lookupNumCTAs(getOperation()) < 2) return emitOpError("requires at least two CTAs in the cluster"); if (!getDst().getType().getMutableMemory()) @@ -274,9 +277,26 @@ LogicalResult AsyncSharedStoreOp::verify() { if (failed(verifyCompletionBarrierLayout(getOperation(), getMbarrier()))) return failure(); - unsigned bitwidth = getSrc().getType().getElementTypeBitWidth(); - if (bitwidth != 32 && bitwidth != 64) - return emitOpError("requires 32-bit or 64-bit element types"); + auto srcTy = getSrc().getType(); + auto dstTy = getDst().getType(); + unsigned bitwidth = getIntOrFloatOrPtrBitWidth(srcTy.getElementType()); + if (bitwidth < 8 || bitwidth > 64 || !llvm::isPowerOf2_32(bitwidth)) + return emitOpError("requires 8-, 16-, 32-, or 64-bit element types"); + + auto regLayout = toLinearLayout(srcTy); + auto sharedLayout = isPaddedEncoding(dstTy.getEncoding()) + ? paddedLinearLayout(dstTy) + : toLinearLayout(dstTy); + auto cvt = regLayout.invertAndCompose(sharedLayout); + std::optional maybeMaxVecElems; + if (isPaddedEncoding(dstTy.getEncoding())) + maybeMaxVecElems = getMinInterval(dstTy.getEncoding()); + auto vectorization = + largestVectorisation(getContext(), cvt, bitwidth, maybeMaxVecElems); + unsigned elemsPerVec = vectorization.first; + if (elemsPerVec * bitwidth < 32) + return emitOpError("requires a layout vectorizing stores to at least 32 " + "bits"); return success(); } diff --git a/python/test/gluon/test_core.py b/python/test/gluon/test_core.py index 3be2782c5846..58ce0f76aa70 100644 --- a/python/test/gluon/test_core.py +++ b/python/test/gluon/test_core.py @@ -161,6 +161,24 @@ def async_shared_store_kernel(out, BLOCK: ttgl.constexpr): ttgl.store(out + offsets, result) +@gluon.jit +def async_shared_store_f16_kernel(out, BLOCK: ttgl.constexpr): + layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0], cga_layout=[[0]]) + shared_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[0], cga_layout=[[0]]) + + offsets = ttgl.arange(0, BLOCK, layout=layout) + values = offsets.to(ttgl.float16) + smem = ttgl.allocate_shared_memory(ttgl.float16, [BLOCK], shared_layout) + bar = mbarrier.allocate_mbarrier() + mbarrier.init(bar, count=1) + mbarrier.expect(bar, smem.nbytes_per_cta) + hopper.async_store(smem, values, bar) + mbarrier.wait(bar, phase=0, deps=[smem]) + result = smem.load(layout) + mbarrier.invalidate(bar) + ttgl.store(out + offsets, result) + + @pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper") def test_async_shared_store(): block = 128 @@ -172,6 +190,17 @@ def test_async_shared_store(): torch.testing.assert_close(out, torch.arange(block, device="cuda", dtype=torch.int32)) +@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper") +def test_async_shared_store_packed_f16(): + block = 256 + out = torch.empty((block,), device="cuda", dtype=torch.float16) + + compiled = async_shared_store_f16_kernel[(1, )](out, block, num_warps=4, num_ctas=2) + + assert "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.b32" in compiled.asm["ptx"] + torch.testing.assert_close(out, torch.arange(block, device="cuda", dtype=torch.float16)) + + @gluon.jit def tma_kernel(desc): layout: ttgl.constexpr = ttgl.BlockedLayout([1, 2], [4, 8], [4, 1], [1, 0]) diff --git a/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/test/Conversion/tritonnvidiagpu_to_llvm.mlir index f95267af47cd..c516e76ecfe2 100644 --- a/test/Conversion/tritonnvidiagpu_to_llvm.mlir +++ b/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -164,6 +164,22 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32} { // ----- +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CGALayout = [[0]]}> +#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[1]]}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: async_shared_store_f16 + // CHECK: llvm.bitcast {{.*}} : vector<2xf16> to i32 + // CHECK: st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.b32 + tt.func @async_shared_store_f16(%src: tensor<256xf16, #blocked>, %dst: !ttg.memdesc<256xf16, #shared1, #smem, mutable>, %mbarrier: !ttg.memdesc<2xi64, #shared0, #smem, mutable>) { + ttng.async_shared_store %src, %dst, %mbarrier : tensor<256xf16, #blocked> -> !ttg.memdesc<256xf16, #shared1, #smem, mutable>, !ttg.memdesc<2xi64, #shared0, #smem, mutable> + tt.return + } +} + +// ----- + // TMA copy with barrier mask zero: barrier has no CGALayout -> shared::cta #shared0_cta = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8}> diff --git a/test/TritonNvidiaGPU/invalid.mlir b/test/TritonNvidiaGPU/invalid.mlir index de02bf697319..4c4b8f0b3d7d 100644 --- a/test/TritonNvidiaGPU/invalid.mlir +++ b/test/TritonNvidiaGPU/invalid.mlir @@ -56,8 +56,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #barrier_shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[1]]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @async_shared_store_requires_wide_elements(%src: tensor<128xf16, #blocked_f16>, %dst: !ttg.memdesc<128xf16, #shared_f16, #smem, mutable>, %bar: !ttg.memdesc<2xi64, #barrier_shared, #smem, mutable>) { - // expected-error @+1 {{requires 32-bit or 64-bit element types}} + tt.func public @async_shared_store_requires_packable_layout(%src: tensor<128xf16, #blocked_f16>, %dst: !ttg.memdesc<128xf16, #shared_f16, #smem, mutable>, %bar: !ttg.memdesc<2xi64, #barrier_shared, #smem, mutable>) { + // expected-error @+1 {{requires a layout vectorizing stores to at least 32 bits}} ttng.async_shared_store %src, %dst, %bar : tensor<128xf16, #blocked_f16> -> !ttg.memdesc<128xf16, #shared_f16, #smem, mutable>, !ttg.memdesc<2xi64, #barrier_shared, #smem, mutable> tt.return } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp index 298a2aa5a5f3..c61607be9c02 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp @@ -322,13 +322,28 @@ static Value normalizeAsyncSharedStoreValue(Value value, Type valueTy, return value; } +static Value packAsyncSharedStoreValue(Location loc, ArrayRef values, + Type elemTy, unsigned storeBitwidth, + TritonLLVMOpBuilder &b, + ConversionPatternRewriter &rewriter) { + unsigned elemBitwidth = getIntOrFloatOrPtrBitWidth(elemTy); + assert(values.size() == storeBitwidth / elemBitwidth); + if (values.size() == 1) + return normalizeAsyncSharedStoreValue(values.front(), elemTy, storeBitwidth, + b); + + Value packed = packLLVector(loc, values, rewriter); + return b.bitcast(packed, + IntegerType::get(elemTy.getContext(), storeBitwidth)); +} + static Value mapSharedToCluster(Location loc, Value ptr, Value ctaId, ConversionPatternRewriter &rewriter) { auto ptrTy = cast(ptr.getType()); assert(ptrTy.getAddressSpace() == 3 && "st.async.shared expects a shared-memory pointer"); - return NVVM::MapaOp::create(rewriter, loc, - ptr_ty(rewriter.getContext(), 7), ptr, ctaId); + return NVVM::MapaOp::create(rewriter, loc, ptr_ty(rewriter.getContext(), 7), + ptr, ctaId); } static void emitAsyncSharedStore(Location loc, ArrayRef vals, Value dst, @@ -338,14 +353,14 @@ static void emitAsyncSharedStore(Location loc, ArrayRef vals, Value dst, Type elemTy = vecTy.getElementType(); unsigned vec = vecTy.getNumElements(); unsigned elemBitwidth = getIntOrFloatOrPtrBitWidth(elemTy); - assert((elemBitwidth == 32 || elemBitwidth == 64) && - "st.async.shared only supports 32-bit and 64-bit elements"); + assert((elemBitwidth == 8 || elemBitwidth == 16 || elemBitwidth == 32 || + elemBitwidth == 64) && + "st.async.shared only supports packable elements"); if (vec * elemBitwidth > 128) { assert(vec % (128 / elemBitwidth) == 0); int maxVec = 128 / elemBitwidth; - auto splitVecTy = - VectorType::get({static_cast(maxVec)}, elemTy); + auto splitVecTy = VectorType::get({static_cast(maxVec)}, elemTy); for (int i = 0; i < vec / maxVec; i++) { auto newDst = b.gep(dst.getType(), elemTy, dst, b.i32_val(i * maxVec), LLVM::GEPNoWrapFlags::inbounds); @@ -354,29 +369,38 @@ static void emitAsyncSharedStore(Location loc, ArrayRef vals, Value dst, } return; } - assert(1 <= vec && vec <= 4); + assert(vec * elemBitwidth >= 32 && + "st.async.shared requires at least a 32-bit store"); + + unsigned storeBitwidth = elemBitwidth < 32 ? 32u : elemBitwidth; + unsigned elemsPerStore = storeBitwidth / elemBitwidth; + assert(vec % elemsPerStore == 0); + unsigned storeVec = vec / elemsPerStore; + assert(1 <= storeVec && storeVec <= 4); PTXBuilder builder; auto st = builder.create("st.async") ->o("weak") .o("shared::cluster") .o("mbarrier::complete_tx::bytes") - .v(vec, /*predicate=*/vec > 1) - .b(elemBitwidth); + .v(storeVec, /*predicate=*/storeVec > 1) + .b(storeBitwidth); auto *dstOpr = builder.newAddrOperand(dst, "r"); auto *mbarrierOpr = builder.newAddrOperand(mbarrier, "r"); - auto constraint = getAsyncSharedStoreConstraint(elemBitwidth); + auto constraint = getAsyncSharedStoreConstraint(storeBitwidth); PTXBuilder::Operand *valueOpr; - if (vec == 1) { + if (storeVec == 1) { valueOpr = builder.newOperand( - normalizeAsyncSharedStoreValue(vals.front(), elemTy, elemBitwidth, b), + packAsyncSharedStoreValue(loc, vals.slice(0, elemsPerStore), elemTy, + storeBitwidth, b, rewriter), constraint); } else { SmallVector> vecVals; - vecVals.reserve(vec); - for (Value value : vals) { - vecVals.push_back({normalizeAsyncSharedStoreValue( - value, elemTy, elemBitwidth, b), + vecVals.reserve(storeVec); + for (unsigned i = 0; i < storeVec; i++) { + vecVals.push_back({packAsyncSharedStoreValue( + loc, vals.slice(i * elemsPerStore, elemsPerStore), + elemTy, storeBitwidth, b, rewriter), constraint}); } valueOpr = builder.newListOperand(vecVals); @@ -385,11 +409,13 @@ static void emitAsyncSharedStore(Location loc, ArrayRef vals, Value dst, builder.launch(rewriter, loc, void_ty(rewriter.getContext())); } -static void lowerAsyncSharedStore( - Location loc, MLIRContext *ctx, LinearLayout cvt, ArrayRef vals, - Type llvmElemTy, MemDescType dstTy, SharedMemoryObject dstMemObj, - Value mbarrierPtr, ConversionPatternRewriter &rewriter, - const NVIDIA::TargetInfo &targetInfo) { +static void lowerAsyncSharedStore(Location loc, MLIRContext *ctx, + LinearLayout cvt, ArrayRef vals, + Type llvmElemTy, MemDescType dstTy, + SharedMemoryObject dstMemObj, + Value mbarrierPtr, + ConversionPatternRewriter &rewriter, + const NVIDIA::TargetInfo &targetInfo) { auto removeBroadcastSrc = actionRemoveBroadcastedRegs(cvt); if (!removeBroadcastSrc.isIdentity()) { auto prmtCvt = removeBroadcastSrc.apply(cvt); @@ -420,9 +446,8 @@ static void lowerAsyncSharedStore( Value dst = mapSharedToCluster(storeLoc, shmemAddr, targetCTAId, rewriter); Value mbarrier = mapSharedToCluster(storeLoc, mbarrierPtr, targetCTAId, rewriter); - emitAsyncSharedStore(storeLoc, - values.slice(idx, vecTy.getNumElements()), dst, - mbarrier, vecTy, rewriter); + emitAsyncSharedStore(storeLoc, values.slice(idx, vecTy.getNumElements()), + dst, mbarrier, vecTy, rewriter); return {}; }; auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); @@ -465,8 +490,8 @@ struct AsyncSharedStoreOpConversion : toLinearLayout(dstTy); auto cvt = regLayout.invertAndCompose(sharedLayout); auto values = unpackLLElements(loc, adaptor.getSrc(), rewriter); - lowerAsyncSharedStore(loc, op.getContext(), cvt, values, llvmElemTy, - dstTy, dstMemObj, mbarrierMemObj.getBase(), rewriter, + lowerAsyncSharedStore(loc, op.getContext(), cvt, values, llvmElemTy, dstTy, + dstMemObj, mbarrierMemObj.getBase(), rewriter, targetInfo); rewriter.eraseOp(op); return success(); From c75874c286694fd1a948fcb1af170e22b29bb5c9 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Mon, 25 May 2026 09:03:36 -0700 Subject: [PATCH 3/6] Add async shared store ConSan test --- .../Transforms/ConSanNVIDIA.cpp | 4 +- python/src/gluon_ir.cc | 10 ++--- python/test/gluon/test_consan.py | 41 +++++++++++++++++++ python/test/gluon/test_core.py | 4 +- 4 files changed, 50 insertions(+), 9 deletions(-) diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp index af4052a5a5b2..4ec4e43db70d 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp @@ -236,8 +236,8 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks { info->barriers.push_back( {storeOp.getBarrier(), nullptr, /*count=*/0, MemEffectsOpInfo::BarrierTrackingMode::EffectWrites, - /*txCount=*/-static_cast(tti::getMemDescLength( - storeOp.getDst()))}); + /*txCount=*/ + -static_cast(tti::getMemDescLength(storeOp.getDst()))}); info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, storeOp.getDst()); } diff --git a/python/src/gluon_ir.cc b/python/src/gluon_ir.cc index 0c54444039f1..3a3ee7d1afb1 100644 --- a/python/src/gluon_ir.cc +++ b/python/src/gluon_ir.cc @@ -691,11 +691,11 @@ void init_gluon_ir(py::module &&m) { [](GluonOpBuilder &self, Value memDesc, Value value) { self.create(value, memDesc); }) - .def("create_async_shared_store", - [](GluonOpBuilder &self, Value memDesc, Value value, - Value mbarrier) { - self.create(value, memDesc, mbarrier); - }) + .def( + "create_async_shared_store", + [](GluonOpBuilder &self, Value memDesc, Value value, Value mbarrier) { + self.create(value, memDesc, mbarrier); + }) .def("create_local_load", [](GluonOpBuilder &self, Type resultTy, Value memDesc) -> Value { return self.create(resultTy, memDesc); diff --git a/python/test/gluon/test_consan.py b/python/test/gluon/test_consan.py index 265560f86a6f..b017f1ed0d4b 100644 --- a/python/test/gluon/test_consan.py +++ b/python/test/gluon/test_consan.py @@ -234,6 +234,47 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr): kernel[(1, )](input_desc, output, FAILURE=FAILURE, num_warps=4, num_ctas=num_ctas) +@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer") +@pytest.mark.parametrize("EXPECT_DELTA", [0, 4], ids=["match", "mismatch"]) +def test_async_shared_store_expect_bytes(EXPECT_DELTA, device, run_wrapper, monkeypatch, num_ctas): + if num_ctas == 1: + pytest.skip("st.async.shared requires at least 2 CTAs") + if run_wrapper: + result = run_in_process(test_async_shared_store_expect_bytes, + (EXPECT_DELTA, device, False, monkeypatch, num_ctas)) + if EXPECT_DELTA: + assert_expected_cuda_failure(result.exc) + assert "Deadlock detected" in result.driver_stderr_output + else: + assert result.exc is None + assert result.driver_stderr_output == "" + return + + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() + + @gluon.jit + def kernel(out, EXPECT_DELTA: ttgl.constexpr): + cga_layout: ttgl.constexpr = multicast_cga_layout(ttgl.num_ctas(), 1) + layout: ttgl.constexpr = ttgl.BlockedLayout([1], [32], [4], [0], cga_layout=cga_layout) + smem_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[0], cga_layout=cga_layout) + offsets = ttgl.arange(0, XBLOCK, layout=layout) + values = offsets.to(ttgl.int32) + smem = ttgl.allocate_shared_memory(ttgl.int32, [XBLOCK], smem_layout) + bar = mbarrier.allocate_mbarrier() + mbarrier.init(bar, count=1) + mbarrier.expect(bar, smem.nbytes_per_cta + EXPECT_DELTA) + hopper.async_store(smem, values, bar) + mbarrier.wait(bar, 0, deps=[smem]) + result = smem.load(layout) + mbarrier.invalidate(bar) + ttgl.store(out + offsets, result) + + output = torch.empty((XBLOCK.value, ), device=device, dtype=torch.int32) + kernel[(1, )](output, EXPECT_DELTA=EXPECT_DELTA, num_warps=4, num_ctas=num_ctas) + + @pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer") @pytest.mark.parametrize("FAILURE", [True, False]) def test_async_tma_multicast_kernel(FAILURE, device, run_wrapper, monkeypatch, num_ctas): diff --git a/python/test/gluon/test_core.py b/python/test/gluon/test_core.py index 58ce0f76aa70..cabb9ded7aa7 100644 --- a/python/test/gluon/test_core.py +++ b/python/test/gluon/test_core.py @@ -182,7 +182,7 @@ def async_shared_store_f16_kernel(out, BLOCK: ttgl.constexpr): @pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper") def test_async_shared_store(): block = 128 - out = torch.empty((block,), device="cuda", dtype=torch.int32) + out = torch.empty((block, ), device="cuda", dtype=torch.int32) compiled = async_shared_store_kernel[(1, )](out, block, num_warps=4, num_ctas=2) @@ -193,7 +193,7 @@ def test_async_shared_store(): @pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper") def test_async_shared_store_packed_f16(): block = 256 - out = torch.empty((block,), device="cuda", dtype=torch.float16) + out = torch.empty((block, ), device="cuda", dtype=torch.float16) compiled = async_shared_store_f16_kernel[(1, )](out, block, num_warps=4, num_ctas=2) From 45f1d628f425ab63e581b10175c404c4ba32955e Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Wed, 27 May 2026 09:09:46 -0700 Subject: [PATCH 4/6] Address async shared store review comments --- lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp | 7 ------- test/Conversion/tritonnvidiagpu_to_llvm.mlir | 20 +++++++++++++++++++ .../TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp | 11 ++++++---- 3 files changed, 27 insertions(+), 11 deletions(-) diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index 035eaf9e1864..8ff7977c309f 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -259,9 +259,6 @@ Type ArriveBarrierOp::getPredicateOperandTypeLike() { return IntegerType::get(getContext(), 1); } -static LogicalResult verifyCompletionBarrierLayout(Operation *op, - Value barrier); - // -- AsyncSharedStoreOp -- LogicalResult AsyncSharedStoreOp::verify() { // PTX defines weak shared::cluster st.async as UB for a one-CTA cluster. @@ -274,14 +271,10 @@ LogicalResult AsyncSharedStoreOp::verify() { return failure(); if (failed(verifyBarrierType(*this, getMbarrier().getType()))) return failure(); - if (failed(verifyCompletionBarrierLayout(getOperation(), getMbarrier()))) - return failure(); auto srcTy = getSrc().getType(); auto dstTy = getDst().getType(); unsigned bitwidth = getIntOrFloatOrPtrBitWidth(srcTy.getElementType()); - if (bitwidth < 8 || bitwidth > 64 || !llvm::isPowerOf2_32(bitwidth)) - return emitOpError("requires 8-, 16-, 32-, or 64-bit element types"); auto regLayout = toLinearLayout(srcTy); auto sharedLayout = isPaddedEncoding(dstTy.getEncoding()) diff --git a/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/test/Conversion/tritonnvidiagpu_to_llvm.mlir index c516e76ecfe2..d9ea16d5e5ef 100644 --- a/test/Conversion/tritonnvidiagpu_to_llvm.mlir +++ b/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -164,6 +164,26 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32} { // ----- +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CGALayout = [[0]]}> +#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttng.two-ctas" = true} { + // CHECK-LABEL: async_shared_store_two_cta_barrier + // CHECK: nvvm.mapa + // CHECK: nvvm.mapa + // CHECK: llvm.ptrtoint + // CHECK: llvm.and + // CHECK: llvm.inttoptr + // CHECK: st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.b32 + tt.func @async_shared_store_two_cta_barrier(%src: tensor<128xi32, #blocked>, %dst: !ttg.memdesc<128xi32, #shared1, #smem, mutable>, %mbarrier: !ttg.memdesc<1xi64, #shared0, #smem, mutable>) { + ttng.async_shared_store %src, %dst, %mbarrier : tensor<128xi32, #blocked> -> !ttg.memdesc<128xi32, #shared1, #smem, mutable>, !ttg.memdesc<1xi64, #shared0, #smem, mutable> + tt.return + } +} + +// ----- + #blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CGALayout = [[0]]}> #shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[1]]}> #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp index c61607be9c02..b9312c1930d9 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp @@ -413,7 +413,7 @@ static void lowerAsyncSharedStore(Location loc, MLIRContext *ctx, LinearLayout cvt, ArrayRef vals, Type llvmElemTy, MemDescType dstTy, SharedMemoryObject dstMemObj, - Value mbarrierPtr, + Value mbarrierPtr, MemDescType mbarrierTy, ConversionPatternRewriter &rewriter, const NVIDIA::TargetInfo &targetInfo) { auto removeBroadcastSrc = actionRemoveBroadcastedRegs(cvt); @@ -421,7 +421,8 @@ static void lowerAsyncSharedStore(Location loc, MLIRContext *ctx, auto prmtCvt = removeBroadcastSrc.apply(cvt); auto inVals = removeBroadcastSrc.apply(to_vector(vals)); lowerAsyncSharedStore(loc, ctx, prmtCvt, inVals, llvmElemTy, dstTy, - dstMemObj, mbarrierPtr, rewriter, targetInfo); + dstMemObj, mbarrierPtr, mbarrierTy, rewriter, + targetInfo); return; } @@ -446,6 +447,8 @@ static void lowerAsyncSharedStore(Location loc, MLIRContext *ctx, Value dst = mapSharedToCluster(storeLoc, shmemAddr, targetCTAId, rewriter); Value mbarrier = mapSharedToCluster(storeLoc, mbarrierPtr, targetCTAId, rewriter); + mbarrier = LLVM::NVIDIA::getLeaderAddress(storeLoc, rewriter, mbarrier, + mbarrierTy); emitAsyncSharedStore(storeLoc, values.slice(idx, vecTy.getNumElements()), dst, mbarrier, vecTy, rewriter); return {}; @@ -491,8 +494,8 @@ struct AsyncSharedStoreOpConversion auto cvt = regLayout.invertAndCompose(sharedLayout); auto values = unpackLLElements(loc, adaptor.getSrc(), rewriter); lowerAsyncSharedStore(loc, op.getContext(), cvt, values, llvmElemTy, dstTy, - dstMemObj, mbarrierMemObj.getBase(), rewriter, - targetInfo); + dstMemObj, mbarrierMemObj.getBase(), mbarrierTy, + rewriter, targetInfo); rewriter.eraseOp(op); return success(); } From b4ea49e94a8c90ac75b54884807c425b8aecf254 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Fri, 29 May 2026 06:31:57 -0700 Subject: [PATCH 5/6] Address Gluon async store review comments --- .../TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td | 1 + .../triton/experimental/gluon/language/_semantic.py | 10 ---------- .../gluon/language/nvidia/hopper/__init__.py | 13 ++++++++++++- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td index 552aa3604817..c4d6c1d1d23d 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -416,6 +416,7 @@ def TTNG_AsyncSharedStoreOp : TTNG_Op<"async_shared_store", [ let description = [{ Store a distributed tensor into shared memory using PTX st.async.shared. The store completion decrements the transaction count of `mbarrier`. + This op requires a CTA cluster with at least two CTAs. }]; let arguments = (ins diff --git a/python/triton/experimental/gluon/language/_semantic.py b/python/triton/experimental/gluon/language/_semantic.py index 3f7f87d6c652..902da6081771 100644 --- a/python/triton/experimental/gluon/language/_semantic.py +++ b/python/triton/experimental/gluon/language/_semantic.py @@ -276,16 +276,6 @@ def shared_store(self, mem_desc, value): lambda: f"source dtype {value.dtype} and destination dtype {mem_desc.dtype} must match") self.builder.create_local_store(mem_desc.handle, value.handle) - def async_shared_store(self, mem_desc, value, mbarrier): - _check(isinstance(value, ttgl.tensor), lambda: f"expected 'value' to be a tensor, but got a {type(value)}") - _check(isinstance(mbarrier, ttgl.shared_memory_descriptor), - lambda: f"expected 'mbarrier' to be a shared_memory_descriptor, but got a {type(mbarrier)}") - _check(value.shape == mem_desc.shape, - lambda: f"source shape {value.shape} and destination shape {mem_desc.shape} must match") - _check(value.dtype == mem_desc.dtype, - lambda: f"source dtype {value.dtype} and destination dtype {mem_desc.dtype} must match") - self.builder.create_async_shared_store(mem_desc.handle, value.handle, mbarrier.handle) - def _check_int_indices_and_normalize_axis(self, mem_desc, indices, axis): _check(isinstance(indices, ttgl.tensor), lambda: f"expected 'indices' to be a tensor, but got a {type(indices)}") diff --git a/python/triton/experimental/gluon/language/nvidia/hopper/__init__.py b/python/triton/experimental/gluon/language/nvidia/hopper/__init__.py index c6171f7dab40..a76c58c9cb9f 100644 --- a/python/triton/experimental/gluon/language/nvidia/hopper/__init__.py +++ b/python/triton/experimental/gluon/language/nvidia/hopper/__init__.py @@ -21,6 +21,11 @@ ] +def _check(cond, msg_fn, category=ValueError): + if not cond: + raise category(msg_fn()) + + @_core.builtin def fence_async_shared(cluster=False, _semantic=None): """ @@ -37,13 +42,19 @@ def fence_async_shared(cluster=False, _semantic=None): def async_store(dst, value, mbarrier, _semantic=None): """ Store a tensor to shared memory asynchronously and signal an mbarrier on completion. + Requires a CTA cluster with at least two CTAs. Args: dst (shared_memory_descriptor): Destination shared memory descriptor. value (tensor): Tensor whose contents to store. mbarrier (shared_memory_descriptor): Barrier signaled when the store completes. """ - _semantic.async_shared_store(dst, value, mbarrier) + _check(isinstance(value, _core.tensor), lambda: f"expected 'value' to be a tensor, but got a {type(value)}") + _check(isinstance(mbarrier, _core.shared_memory_descriptor), + lambda: f"expected 'mbarrier' to be a shared_memory_descriptor, but got a {type(mbarrier)}") + _check(value.shape == dst.shape, lambda: f"source shape {value.shape} and destination shape {dst.shape} must match") + _check(value.dtype == dst.dtype, lambda: f"source dtype {value.dtype} and destination dtype {dst.dtype} must match") + _semantic.builder.create_async_shared_store(dst.handle, value.handle, mbarrier.handle) class warpgroup_mma_accumulator_type(_core.base_type): From d621f10875f77b403991308920734121f2dda2fe Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Fri, 29 May 2026 06:35:47 -0700 Subject: [PATCH 6/6] Drop async shared store PTX version check --- test/Conversion/tritonnvidiagpu_to_llvm.mlir | 2 +- .../nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/test/Conversion/tritonnvidiagpu_to_llvm.mlir index d9ea16d5e5ef..925bde8523c7 100644 --- a/test/Conversion/tritonnvidiagpu_to_llvm.mlir +++ b/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=81' -reconcile-unrealized-casts | FileCheck %s +// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=compute-capability=90 -reconcile-unrealized-casts | FileCheck %s #shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> #smem = #ttg.shared_memory diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp index b9312c1930d9..bd5ec285fd3b 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp @@ -472,9 +472,8 @@ struct AsyncSharedStoreOpConversion matchAndRewrite(triton::nvidia_gpu::AsyncSharedStoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (!triton::nvidia_gpu::AsyncSharedStoreOp::isSupported( - targetInfo.getComputeCapability()) || - targetInfo.getPtxVersion() < 81) - return op.emitError("requires cluster-capable SM90+ and PTX 8.1+"); + targetInfo.getComputeCapability())) + return op.emitError("requires cluster-capable SM90+"); auto loc = op.getLoc(); MemDescType dstTy = op.getDst().getType();