diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td index 006361d9fb47..c4d6c1d1d23d 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -409,6 +409,37 @@ def TTNG_AsyncCopyMbarrierArriveOp : TTNG_Op<"async_copy_mbarrier_arrive", [ } +def TTNG_AsyncSharedStoreOp : TTNG_Op<"async_shared_store", [ + DeclareOpInterfaceMethods]> { + let summary = "store a distributed tensor to shared memory asynchronously"; + + let description = [{ + Store a distributed tensor into shared memory using PTX st.async.shared. + The store completion decrements the transaction count of `mbarrier`. + This op requires a CTA cluster with at least two CTAs. + }]; + + let arguments = (ins + TT_Tensor:$src, + Arg]>:$dst, + Arg]>:$mbarrier + ); + + let assemblyFormat = [{ + $src `,` $dst `,` $mbarrier attr-dict `:` type($src) `->` + qualified(type($dst)) `,` qualified(type($mbarrier)) + }]; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + static bool isSupported(int computeCapability) { + return ::mlir::triton::nvidia_gpu::TargetFeatures(computeCapability) + .supportClusterOps(); + } + }]; +} + + def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local", [ AttrSizedOperandSegments, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, TMALoadLikeOpInterface]> { diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index 7d80e0427ed1..8ff7977c309f 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -27,6 +27,7 @@ #include "mlir/Support/LLVM.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include 
"triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" @@ -39,6 +40,7 @@ #include "triton/Tools/StrUtil.h" #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" using namespace mlir::triton::gpu; @@ -257,6 +259,44 @@ Type ArriveBarrierOp::getPredicateOperandTypeLike() { return IntegerType::get(getContext(), 1); } +// -- AsyncSharedStoreOp -- +LogicalResult AsyncSharedStoreOp::verify() { + // PTX defines weak shared::cluster st.async as UB for a one-CTA cluster. + if (gpu::lookupNumCTAs(getOperation()) < 2) + return emitOpError("requires at least two CTAs in the cluster"); + if (!getDst().getType().getMutableMemory()) + return emitOpError("cannot store into immutable memory"); + if (failed(triton::gpu::verifyMemoryOpTypes(*this, getSrc().getType(), + getDst().getType()))) + return failure(); + if (failed(verifyBarrierType(*this, getMbarrier().getType()))) + return failure(); + + auto srcTy = getSrc().getType(); + auto dstTy = getDst().getType(); + unsigned bitwidth = getIntOrFloatOrPtrBitWidth(srcTy.getElementType()); + + auto regLayout = toLinearLayout(srcTy); + auto sharedLayout = isPaddedEncoding(dstTy.getEncoding()) + ? 
paddedLinearLayout(dstTy) + : toLinearLayout(dstTy); + auto cvt = regLayout.invertAndCompose(sharedLayout); + std::optional maybeMaxVecElems; + if (isPaddedEncoding(dstTy.getEncoding())) + maybeMaxVecElems = getMinInterval(dstTy.getEncoding()); + auto vectorization = + largestVectorisation(getContext(), cvt, bitwidth, maybeMaxVecElems); + unsigned elemsPerVec = vectorization.first; + if (elemsPerVec * bitwidth < 32) + return emitOpError("requires a layout vectorizing stores to at least 32 " + "bits"); + return success(); +} + +TypedValue AsyncSharedStoreOp::getBarrier() { + return getMbarrier(); +} + // -- FenceMBarrierInitReleaseClusterOp -- LogicalResult FenceMBarrierInitReleaseClusterOp::verify() { int numCTAs = triton::gpu::lookupNumCTAs(getOperation()); diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp index e8c6080fd710..4ec4e43db70d 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp @@ -230,6 +230,17 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks { info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, loadOp.getResult()); } + if (auto storeOp = dyn_cast(op)) { + info.emplace(); + info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; + info->barriers.push_back( + {storeOp.getBarrier(), nullptr, /*count=*/0, + MemEffectsOpInfo::BarrierTrackingMode::EffectWrites, + /*txCount=*/ + -static_cast(tti::getMemDescLength(storeOp.getDst()))}); + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, + storeOp.getDst()); + } if (auto tryCancelOp = dyn_cast(op)) { info.emplace(); info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; diff --git a/python/src/gluon_ir.cc b/python/src/gluon_ir.cc index 4d98119243a2..3a3ee7d1afb1 100644 --- a/python/src/gluon_ir.cc +++ b/python/src/gluon_ir.cc @@ -691,6 +691,11 @@ void init_gluon_ir(py::module &&m) { [](GluonOpBuilder 
@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
@pytest.mark.parametrize("EXPECT_DELTA", [0, 4], ids=["match", "mismatch"])
def test_async_shared_store_expect_bytes(EXPECT_DELTA, device, run_wrapper, monkeypatch, num_ctas):
    """ConSan test for the byte accounting of ``hopper.async_store``.

    The kernel arms the barrier with ``smem.nbytes_per_cta + EXPECT_DELTA``
    bytes and then issues one async store. With ``EXPECT_DELTA == 0`` the run
    must finish without error and with empty driver stderr; with a non-zero
    delta the test expects a CUDA failure and "Deadlock detected" in stderr.
    """
    if num_ctas == 1:
        pytest.skip("st.async.shared requires at least 2 CTAs")
    if run_wrapper:
        # Outer invocation: re-run this test in a separate process (with
        # run_wrapper=False) and only inspect the captured result here.
        result = run_in_process(test_async_shared_store_expect_bytes,
                                (EXPECT_DELTA, device, False, monkeypatch, num_ctas))
        if EXPECT_DELTA:
            assert_expected_cuda_failure(result.exc)
            assert "Deadlock detected" in result.driver_stderr_output
        else:
            assert result.exc is None
            assert result.driver_stderr_output == ""
        return

    # Inner invocation: enable the concurrency sanitizer and blocking launches
    # before compiling and launching the kernel.
    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
    monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
    knobs.refresh_knobs()

    @gluon.jit
    def kernel(out, EXPECT_DELTA: ttgl.constexpr):
        cga_layout: ttgl.constexpr = multicast_cga_layout(ttgl.num_ctas(), 1)
        layout: ttgl.constexpr = ttgl.BlockedLayout([1], [32], [4], [0], cga_layout=cga_layout)
        smem_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[0], cga_layout=cga_layout)
        offsets = ttgl.arange(0, XBLOCK, layout=layout)
        values = offsets.to(ttgl.int32)
        smem = ttgl.allocate_shared_memory(ttgl.int32, [XBLOCK], smem_layout)
        bar = mbarrier.allocate_mbarrier()
        mbarrier.init(bar, count=1)
        # A non-zero EXPECT_DELTA arms the barrier with more bytes than the
        # single store below is expected to account for.
        mbarrier.expect(bar, smem.nbytes_per_cta + EXPECT_DELTA)
        hopper.async_store(smem, values, bar)
        mbarrier.wait(bar, 0, deps=[smem])
        result = smem.load(layout)
        mbarrier.invalidate(bar)
        ttgl.store(out + offsets, result)

    output = torch.empty((XBLOCK.value, ), device=device, dtype=torch.int32)
    kernel[(1, )](output, EXPECT_DELTA=EXPECT_DELTA, num_warps=4, num_ctas=num_ctas)
@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper")
def test_async_shared_store():
    """Run the int32 async-store kernel on a 2-CTA launch and check PTX and output."""
    block = 128
    out = torch.empty((block, ), device="cuda", dtype=torch.int32)

    # num_ctas=2: the op is documented to require a cluster of at least two CTAs.
    compiled = async_shared_store_kernel[(1, )](out, block, num_warps=4, num_ctas=2)

    # The generated PTX must contain the st.async form that completes mbarrier
    # transaction bytes.
    assert "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes" in compiled.asm["ptx"]
    # The kernel stores arange(BLOCK) through shared memory, so the output is
    # 0..block-1.
    torch.testing.assert_close(out, torch.arange(block, device="cuda", dtype=torch.int32))
@pytest.mark.parametrize("target", [HOPPER_TARGET, BLACKWELL_TARGET])
def test_async_shared_store(target):
    """Frontend-only check: parsing the kernel must emit `ttng.async_shared_store`."""
    # num_ctas=2 because the op's verifier rejects modules with fewer than two CTAs.
    mod = run_parser(async_shared_store_kernel, *make_args(num_ctas=2), target=target)
    assert "ttng.async_shared_store" in anonymize_ir(mod.str_nodebug())
@_core.builtin
def async_store(dst, value, mbarrier, _semantic=None):
    """
    Store a tensor to shared memory asynchronously and signal an mbarrier on completion.
    Requires a CTA cluster with at least two CTAs.

    Args:
        dst (shared_memory_descriptor): Destination shared memory descriptor.
        value (tensor): Tensor whose contents to store.
        mbarrier (shared_memory_descriptor): Barrier signaled when the store completes.

    Raises:
        ValueError: If an argument has the wrong kind, or if the shape or dtype
            of `value` does not match `dst`.
    """
    _check(isinstance(value, _core.tensor), lambda: f"expected 'value' to be a tensor, but got a {type(value)}")
    _check(isinstance(mbarrier, _core.shared_memory_descriptor),
           lambda: f"expected 'mbarrier' to be a shared_memory_descriptor, but got a {type(mbarrier)}")
    # Validate `dst` before touching dst.shape / dst.dtype / dst.handle: without
    # this, a tensor passed as `dst` would satisfy the shape/dtype checks and be
    # handed to the builder as if it were a memory descriptor.
    _check(isinstance(dst, _core.shared_memory_descriptor),
           lambda: f"expected 'dst' to be a shared_memory_descriptor, but got a {type(dst)}")
    _check(value.shape == dst.shape, lambda: f"source shape {value.shape} and destination shape {dst.shape} must match")
    _check(value.dtype == dst.dtype, lambda: f"source dtype {value.dtype} and destination dtype {dst.dtype} must match")
    _semantic.builder.create_async_shared_store(dst.handle, value.handle, mbarrier.handle)
// Two-CTA variant: the module sets "ttng.two-ctas" and the barrier memdesc is a
// single i64 (1xi64). Besides the two nvvm.mapa ops, the expected lowering
// contains a ptrtoint / and / inttoptr sequence before the st.async
// instruction.
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CGALayout = [[0]]}>
#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}>
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttng.two-ctas" = true} {
  // CHECK-LABEL: async_shared_store_two_cta_barrier
  // CHECK: nvvm.mapa
  // CHECK: nvvm.mapa
  // CHECK: llvm.ptrtoint
  // CHECK: llvm.and
  // CHECK: llvm.inttoptr
  // CHECK: st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.b32
  tt.func @async_shared_store_two_cta_barrier(%src: tensor<128xi32, #blocked>, %dst: !ttg.memdesc<128xi32, #shared1, #smem, mutable>, %mbarrier: !ttg.memdesc<1xi64, #shared0, #smem, mutable>) {
    ttng.async_shared_store %src, %dst, %mbarrier : tensor<128xi32, #blocked> -> !ttg.memdesc<128xi32, #shared1, #smem, mutable>, !ttg.memdesc<1xi64, #shared0, #smem, mutable>
    tt.return
  }
}
// Negative test: the module declares a single CTA ("ttg.num-ctas" = 1), so the
// op verifier must reject ttng.async_shared_store with the cluster-size
// diagnostic.
#blocked_i32 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#shared_i32 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @async_shared_store_requires_cluster(%src: tensor<128xi32, #blocked_i32>, %dst: !ttg.memdesc<128xi32, #shared_i32, #smem, mutable>, %bar: !ttg.memdesc<1xi64, #shared_i32, #smem, mutable>) {
    // expected-error @+1 {{requires at least two CTAs in the cluster}}
    ttng.async_shared_store %src, %dst, %bar : tensor<128xi32, #blocked_i32> -> !ttg.memdesc<128xi32, #shared_i32, #smem, mutable>, !ttg.memdesc<1xi64, #shared_i32, #smem, mutable>
    tt.return
  }
}
ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @async_shared_store_requires_packable_layout(%src: tensor<128xf16, #blocked_f16>, %dst: !ttg.memdesc<128xf16, #shared_f16, #smem, mutable>, %bar: !ttg.memdesc<2xi64, #barrier_shared, #smem, mutable>) { + // expected-error @+1 {{requires a layout vectorizing stores to at least 32 bits}} + ttng.async_shared_store %src, %dst, %bar : tensor<128xf16, #blocked_f16> -> !ttg.memdesc<128xf16, #shared_f16, #smem, mutable>, !ttg.memdesc<2xi64, #barrier_shared, #smem, mutable> + tt.return + } +} + +// ----- + #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> #scales = #ttg.linear<{register = [[0, 1], [0, 2], [32, 0], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}> #tmem = #ttng.tensor_memory_scales_encoding<> diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp index 5ab129d32b37..bd5ec285fd3b 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp @@ -300,6 +300,209 @@ struct LocalStoreOpConversion const NVIDIA::TargetInfo &targetInfo; }; +static std::string getAsyncSharedStoreConstraint(unsigned bitwidth) { + switch (bitwidth) { + case 32: + return "r"; + case 64: + return "l"; + default: + llvm_unreachable("unsupported st.async.shared bitwidth"); + } +} + +static Value normalizeAsyncSharedStoreValue(Value value, Type valueTy, + unsigned bitwidth, + TritonLLVMOpBuilder &b) { + Type intTy = IntegerType::get(valueTy.getContext(), bitwidth); + if (isa(valueTy)) + return b.ptrtoint(intTy, value); + if (!valueTy.isInteger()) + return b.bitcast(value, intTy); + return value; +} + +static Value packAsyncSharedStoreValue(Location loc, ArrayRef values, + Type elemTy, unsigned storeBitwidth, + TritonLLVMOpBuilder &b, + ConversionPatternRewriter 
&rewriter) { + unsigned elemBitwidth = getIntOrFloatOrPtrBitWidth(elemTy); + assert(values.size() == storeBitwidth / elemBitwidth); + if (values.size() == 1) + return normalizeAsyncSharedStoreValue(values.front(), elemTy, storeBitwidth, + b); + + Value packed = packLLVector(loc, values, rewriter); + return b.bitcast(packed, + IntegerType::get(elemTy.getContext(), storeBitwidth)); +} + +static Value mapSharedToCluster(Location loc, Value ptr, Value ctaId, + ConversionPatternRewriter &rewriter) { + auto ptrTy = cast(ptr.getType()); + assert(ptrTy.getAddressSpace() == 3 && + "st.async.shared expects a shared-memory pointer"); + return NVVM::MapaOp::create(rewriter, loc, ptr_ty(rewriter.getContext(), 7), + ptr, ctaId); +} + +static void emitAsyncSharedStore(Location loc, ArrayRef vals, Value dst, + Value mbarrier, VectorType vecTy, + ConversionPatternRewriter &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + Type elemTy = vecTy.getElementType(); + unsigned vec = vecTy.getNumElements(); + unsigned elemBitwidth = getIntOrFloatOrPtrBitWidth(elemTy); + assert((elemBitwidth == 8 || elemBitwidth == 16 || elemBitwidth == 32 || + elemBitwidth == 64) && + "st.async.shared only supports packable elements"); + + if (vec * elemBitwidth > 128) { + assert(vec % (128 / elemBitwidth) == 0); + int maxVec = 128 / elemBitwidth; + auto splitVecTy = VectorType::get({static_cast(maxVec)}, elemTy); + for (int i = 0; i < vec / maxVec; i++) { + auto newDst = b.gep(dst.getType(), elemTy, dst, b.i32_val(i * maxVec), + LLVM::GEPNoWrapFlags::inbounds); + emitAsyncSharedStore(loc, vals.slice(i * maxVec, maxVec), newDst, + mbarrier, splitVecTy, rewriter); + } + return; + } + assert(vec * elemBitwidth >= 32 && + "st.async.shared requires at least a 32-bit store"); + + unsigned storeBitwidth = elemBitwidth < 32 ? 
32u : elemBitwidth; + unsigned elemsPerStore = storeBitwidth / elemBitwidth; + assert(vec % elemsPerStore == 0); + unsigned storeVec = vec / elemsPerStore; + assert(1 <= storeVec && storeVec <= 4); + + PTXBuilder builder; + auto st = builder.create("st.async") + ->o("weak") + .o("shared::cluster") + .o("mbarrier::complete_tx::bytes") + .v(storeVec, /*predicate=*/storeVec > 1) + .b(storeBitwidth); + auto *dstOpr = builder.newAddrOperand(dst, "r"); + auto *mbarrierOpr = builder.newAddrOperand(mbarrier, "r"); + auto constraint = getAsyncSharedStoreConstraint(storeBitwidth); + PTXBuilder::Operand *valueOpr; + if (storeVec == 1) { + valueOpr = builder.newOperand( + packAsyncSharedStoreValue(loc, vals.slice(0, elemsPerStore), elemTy, + storeBitwidth, b, rewriter), + constraint); + } else { + SmallVector> vecVals; + vecVals.reserve(storeVec); + for (unsigned i = 0; i < storeVec; i++) { + vecVals.push_back({packAsyncSharedStoreValue( + loc, vals.slice(i * elemsPerStore, elemsPerStore), + elemTy, storeBitwidth, b, rewriter), + constraint}); + } + valueOpr = builder.newListOperand(vecVals); + } + st(dstOpr, valueOpr, mbarrierOpr); + builder.launch(rewriter, loc, void_ty(rewriter.getContext())); +} + +static void lowerAsyncSharedStore(Location loc, MLIRContext *ctx, + LinearLayout cvt, ArrayRef vals, + Type llvmElemTy, MemDescType dstTy, + SharedMemoryObject dstMemObj, + Value mbarrierPtr, MemDescType mbarrierTy, + ConversionPatternRewriter &rewriter, + const NVIDIA::TargetInfo &targetInfo) { + auto removeBroadcastSrc = actionRemoveBroadcastedRegs(cvt); + if (!removeBroadcastSrc.isIdentity()) { + auto prmtCvt = removeBroadcastSrc.apply(cvt); + auto inVals = removeBroadcastSrc.apply(to_vector(vals)); + lowerAsyncSharedStore(loc, ctx, prmtCvt, inVals, llvmElemTy, dstTy, + dstMemObj, mbarrierPtr, mbarrierTy, rewriter, + targetInfo); + return; + } + + auto affineOffset = dstMemObj.getShmemOffset(loc, rewriter, dstTy); + auto maskSpanAffineOffset = 
dstMemObj.getMaskSpanOffsets(dstTy); + std::optional maybeMaxVecElems; + SmallVector> paddingShifts; + if (triton::gpu::isPaddedEncoding(dstTy.getEncoding())) { + maybeMaxVecElems = triton::gpu::getMinInterval(dstTy.getEncoding()); + auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy); + paddingShifts = getPaddedSharedShifts(dstTy.getEncoding(), bitwidth, + /*offsetInBytes=*/true); + } + + SmallVector smemBases(dstMemObj.getBases().begin(), + dstMemObj.getBases().end()); + Value currentCTAId = targetInfo.getClusterCTAId(rewriter, loc); + auto emitSt = [&](RewriterBase &, Location storeLoc, ArrayRef values, + Value shmemAddr, int idx, VectorType vecTy, + std::optional ctaId) -> SmallVector { + Value targetCTAId = ctaId.value_or(currentCTAId); + Value dst = mapSharedToCluster(storeLoc, shmemAddr, targetCTAId, rewriter); + Value mbarrier = + mapSharedToCluster(storeLoc, mbarrierPtr, targetCTAId, rewriter); + mbarrier = LLVM::NVIDIA::getLeaderAddress(storeLoc, rewriter, mbarrier, + mbarrierTy); + emitAsyncSharedStore(storeLoc, values.slice(idx, vecTy.getNumElements()), + dst, mbarrier, vecTy, rewriter); + return {}; + }; + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + lowerLdSt(loc, ctx, cvt, vals, llvmElemTy, smemBases, paddingShifts, + affineOffset, maskSpanAffineOffset, laneId, warpId, rewriter, + targetInfo, maybeMaxVecElems, emitSt); +} + +struct AsyncSharedStoreOpConversion + : public ConvertOpToLLVMPattern { + AsyncSharedStoreOpConversion(const LLVMTypeConverter &converter, + const NVIDIA::TargetInfo &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern( + converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::AsyncSharedStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!triton::nvidia_gpu::AsyncSharedStoreOp::isSupported( + targetInfo.getComputeCapability())) + return op.emitError("requires cluster-capable SM90+"); + + auto loc = 
op.getLoc(); + MemDescType dstTy = op.getDst().getType(); + RankedTensorType srcTy = op.getSrc().getType(); + Type llvmElemTy = typeConverter->convertType(srcTy.getElementType()); + auto dstMemObj = LLVM::getSharedMemoryObjectFromStruct( + loc, adaptor.getDst(), llvmElemTy, rewriter); + auto mbarrierTy = op.getMbarrier().getType(); + auto mbarrierMemObj = LLVM::getSharedMemoryObjectFromStruct( + loc, adaptor.getMbarrier(), + typeConverter->convertType(mbarrierTy.getElementType()), rewriter); + + auto regLayout = toLinearLayout(srcTy); + auto sharedLayout = isPaddedEncoding(dstTy.getEncoding()) + ? paddedLinearLayout(dstTy) + : toLinearLayout(dstTy); + auto cvt = regLayout.invertAndCompose(sharedLayout); + auto values = unpackLLElements(loc, adaptor.getSrc(), rewriter); + lowerAsyncSharedStore(loc, op.getContext(), cvt, values, llvmElemTy, dstTy, + dstMemObj, mbarrierMemObj.getBase(), mbarrierTy, + rewriter, targetInfo); + rewriter.eraseOp(op); + return success(); + } + +private: + const NVIDIA::TargetInfo &targetInfo; +}; + struct LocalAtomicScatterRMWOpConversion : public ConvertOpToLLVMPattern { public: @@ -378,6 +581,8 @@ void mlir::triton::NVIDIA::populateMemoryOpToLLVMPatterns( benefit.getBenefit() + 1); patterns.add(typeConverter, targetInfo, benefit.getBenefit() + 1); + patterns.add(typeConverter, targetInfo, + benefit.getBenefit() + 1); patterns.add(typeConverter, targetInfo, benefit.getBenefit() + 1); patterns.add(typeConverter, targetInfo,