From 6a54bcc278ac4da97dfcf283873e3f1f2385f603 Mon Sep 17 00:00:00 2001 From: lezcano Date: Fri, 27 Mar 2026 01:00:25 +0100 Subject: [PATCH 1/3] [BACKEND] Implement multiCTA support for TMA gather We also add tighter invariants for the cga_layout part of all TMA ops --- include/triton/Dialect/Triton/IR/Dialect.h | 3 + .../TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td | 3 +- lib/Dialect/Triton/IR/Ops.cpp | 6 +- lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp | 64 +++++++- .../Transforms/ClusterBarrierInsertion.cpp | 5 + python/src/gluon_ir.cc | 15 +- python/test/gluon/test_core.py | 145 ++++++++++++++++-- .../gluon/language/nvidia/blackwell/tma.py | 17 +- test/Conversion/tma_multicast_to_llvm.mlir | 27 ++++ .../tritongpu_to_llvm_blackwell.mlir | 23 +++ test/Conversion/tritonnvidiagpu_to_llvm.mlir | 36 +++++ .../automatic-warp-specialization.mlir | 3 +- test/TritonGPU/gsan.mlir | 3 +- test/TritonGPU/loop-pipeline-blackwell.mlir | 3 +- test/TritonGPU/loop-pipeline-hopper.mlir | 3 +- test/TritonGPU/pipeline-lower-loop.mlir | 3 +- test/TritonNvidiaGPU/invalid.mlir | 79 ++++++++++ test/TritonNvidiaGPU/ops.mlir | 15 +- test/TritonNvidiaGPU/tma_lowering.mlir | 11 +- .../LoadStoreOpToLLVM.cpp | 95 ++++++++++-- 20 files changed, 496 insertions(+), 63 deletions(-) create mode 100644 test/Conversion/tma_multicast_to_llvm.mlir diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h index 765e531e7b02..d1e6cbd5b807 100644 --- a/include/triton/Dialect/Triton/IR/Dialect.h +++ b/include/triton/Dialect/Triton/IR/Dialect.h @@ -113,6 +113,9 @@ class DialectVerifyTensorLayoutInterface }; // Descriptor gather and scatter have restrictions on the tile sizes. 
+LogicalResult verifyGatherScatterResultType(Operation *op, + ShapedType resultType, + ShapedType indicesType); LogicalResult verifyGatherScatterOp(Operation *op, ShapedType blockType, ShapedType resultType, ShapedType indicesType); diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td index c8d02167f696..18f275e7f0ca 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -544,7 +544,8 @@ def TTNG_AsyncTMAGatherOp : TTNG_Op<"async_tma_gather", [ I32:$y_offset, Arg]>:$barrier, Arg]>:$result, - I1:$pred + I1:$pred, + UnitAttr:$multicast ); let assemblyFormat = [{ diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 313fec9b7b2a..de9555739679 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -1507,9 +1507,9 @@ LogicalResult GatherOp::inferReturnTypes( } // -- DescriptorGatherOp -static LogicalResult verifyGatherScatterResultType(Operation *op, - ShapedType resultType, - ShapedType indicesType) { +LogicalResult verifyGatherScatterResultType(Operation *op, + ShapedType resultType, + ShapedType indicesType) { if (indicesType.getRank() != 1) return op->emitOpError("x offsets must be a 1D tensor, but got ") << indicesType; diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index 62d43f5b3dc6..75fd6c78be7a 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -394,6 +394,55 @@ static LogicalResult verifyAsyncTMAStoreOp(Operation *op, return verifyTMAEncoding(op, desc.getType(), srcEnc); } +static LogicalResult verifyAsyncTMAGatherScatterOp(Operation *op, + ShapedType blockType, + MemDescType memDescType, + ShapedType indicesType) { + if (blockType.getRank() != 2) + return op->emitOpError("descriptor block must be a 2D tensor, but got ") + << 
blockType; + if (blockType.getShape()[0] != 1) + return op->emitOpError("descriptor block must have exactly 1 row, but got ") + << blockType; + if (failed(verifyGatherScatterResultType(op, memDescType, indicesType))) + return failure(); + + if (memDescType.getShape()[1] != blockType.getShape()[1]) + return op->emitOpError("result tensor number of columns must match block (") + << blockType.getShape()[1] << "), but got " << memDescType; + if (memDescType.getElementType() != blockType.getElementType()) + return op->emitOpError("result tensor element type must match block (") + << blockType.getElementType() << "), but got " << memDescType; + + ArrayRef allocShape = memDescType.getAllocShape(); + if (allocShape.size() < 2 || + memDescType.getShape() != allocShape.take_back(2)) + return op->emitOpError("memdesc shape must match alloc shape"); + + auto xOffsetsType = cast(indicesType); + if (xOffsetsType.getEncoding()) { + auto xCoordsLayout = triton::gpu::toLinearLayout(xOffsetsType); + auto kLane = StringAttr::get(op->getContext(), "lane"); + if (getContigPerThread(xOffsetsType).front() < 4) + return op->emitOpError( + "x offsets must have at least 4 contiguous elements per thread"); + unsigned threadsPerWarp = xCoordsLayout.getInDimSize(kLane); + if (xCoordsLayout.getFreeVariableMasks()[kLane] != (threadsPerWarp - 1)) + return op->emitOpError("x offsets must be broadcasted across each warp"); + auto kBlock = StringAttr::get(op->getContext(), "block"); + auto kDim0 = StringAttr::get(op->getContext(), "dim0"); + auto rowsCGA = getCGALayout(memDescType.getEncoding()) + .getLinearLayout() + .sublayout({kBlock}, {kDim0}); + auto xOffsetsCGA = + getCGALayout(xOffsetsType.getEncoding()).getLinearLayout(); + if (rowsCGA != xOffsetsCGA) + return op->emitOpError( + "x offsets must have the same row CGA layout as the memdesc"); + } + return success(); +} + // Helper to determine if the descriptor type is for im2col mode static bool isIm2ColDescriptor(Type descType) { return 
isa(descType); @@ -546,9 +595,12 @@ LogicalResult AsyncTMAGatherOp::verify() { // `tile::gather4` does not support fp4_padded operands. if (isFp4Padded(getResult().getType().getEncoding())) return emitOpError("does not support fp4_padded operands"); - return verifyGatherScatterOp(*this, - getDesc().getType().getSignlessBlockType(), - resultType, getXOffsets().getType()); + if (getMulticast() && !hasCGABroadcast(resultType)) + return emitOpError( + "multicast requires the shared layout to broadcast across CTAs"); + return verifyAsyncTMAGatherScatterOp( + *this, getDesc().getType().getSignlessBlockType(), resultType, + getXOffsets().getType()); } Value AsyncTMAGatherOp::getPredicateOperand() { return getPred(); } @@ -566,9 +618,9 @@ LogicalResult AsyncTMAScatterOp::verify() { auto srcType = getSrc().getType(); if (failed(verifyAsyncTMAStoreOp(*this, getDesc(), srcType))) return failure(); - return verifyGatherScatterOp(*this, - getDesc().getType().getSignlessBlockType(), - srcType, getXOffsets().getType()); + return verifyAsyncTMAGatherScatterOp( + *this, getDesc().getType().getSignlessBlockType(), srcType, + getXOffsets().getType()); } // -- TCGen5MMAOp -- diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/ClusterBarrierInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/ClusterBarrierInsertion.cpp index 58c5b734ac6f..afabbf693d2d 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/ClusterBarrierInsertion.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/ClusterBarrierInsertion.cpp @@ -47,6 +47,8 @@ static bool isDistributedMultiCTAOp(Operation *op, bool isRead) { return ttng::getModuleTwoCTAs(op); } else if (auto tma = dyn_cast(op)) { return tma.getMulticast(); + } else if (auto tma = dyn_cast(op)) { + return tma.getMulticast(); } return false; } @@ -109,6 +111,9 @@ usesTrackedBarrierInCrossCTAConsumerOp(Operation *op, if (auto tma = dyn_cast(op)) { return tma.getMulticast() && aliasesTracked(tma.getBarrier()); } + if (auto tma = dyn_cast(op)) { + return 
tma.getMulticast() && aliasesTracked(tma.getBarrier()); + } return false; } diff --git a/python/src/gluon_ir.cc b/python/src/gluon_ir.cc index 1bcf493ddbd3..86da5d99da91 100644 --- a/python/src/gluon_ir.cc +++ b/python/src/gluon_ir.cc @@ -943,12 +943,15 @@ void init_gluon_ir(py::module &&m) { [](GluonOpBuilder &self, int pendings) { self.create(pendings); }) - .def("create_async_tma_gather", - [](GluonOpBuilder &self, Value descPtr, Value xOffsets, - Value yOffset, Value barrier, Value result, Value pred) { - self.create(descPtr, xOffsets, yOffset, - barrier, result, pred); - }) + .def( + "create_async_tma_gather", + [](GluonOpBuilder &self, Value descPtr, Value xOffsets, Value yOffset, + Value barrier, Value result, Value pred, bool multicast) { + multicast &= + ttng::hasCGABroadcast(cast(result.getType())); + self.create( + descPtr, xOffsets, yOffset, barrier, result, pred, multicast); + }) .def("create_async_tma_scatter", [](GluonOpBuilder &self, Value descPtr, Value xOffsets, Value yOffset, Value src) { diff --git a/python/test/gluon/test_core.py b/python/test/gluon/test_core.py index 169a417fb2fe..ad2b8d39f8f1 100644 --- a/python/test/gluon/test_core.py +++ b/python/test/gluon/test_core.py @@ -27,6 +27,7 @@ from triton.experimental.gluon.language.nvidia.ampere import async_copy, mma_v2 from triton.experimental.gluon.language.nvidia.hopper import tma, mbarrier, fence_async_shared from triton.experimental.gluon.language.nvidia import hopper +from triton.experimental.gluon.language.nvidia.blackwell import tma as blackwell_tma from triton.experimental.gluon.language.amd.cdna4 import async_copy as cdna4_async_copy from triton.experimental.gluon.language.extra import libdevice from triton.experimental.gluon.language.nvidia.blackwell import ( @@ -262,6 +263,87 @@ def test_tma_multicast_copy(ctas_per_cga): torch.testing.assert_close(out, inp, atol=0, rtol=0) +@gluon.jit +def tma_gather_scatter_kernel(in_desc, gather_out_desc, scatter_out_desc, gather_idx_ptr, 
scatter_idx_ptr, + BLOCK_M: ttgl.constexpr, x_offsets_layout: ttgl.constexpr): + smem = ttgl.allocate_shared_memory(in_desc.dtype, [BLOCK_M, gather_out_desc.block_shape[1]], gather_out_desc.layout) + + bar = mbarrier.allocate_mbarrier() + mbarrier.init(bar, count=1) + + gather_offsets = ttgl.load(gather_idx_ptr + ttgl.arange(0, BLOCK_M, layout=x_offsets_layout)) + mbarrier.expect(bar, blackwell_tma.nbytes_per_cta_gather(in_desc, gather_offsets)) + blackwell_tma.async_gather(in_desc, gather_offsets, 0, bar, smem, multicast=True) + mbarrier.wait(bar, phase=0, deps=[smem]) + + mbarrier.invalidate(bar) + + scatter_offsets = ttgl.load(scatter_idx_ptr + ttgl.arange(0, BLOCK_M, layout=x_offsets_layout)) + tma.async_copy_shared_to_global(gather_out_desc, [0, 0], smem) + blackwell_tma.async_scatter(scatter_out_desc, scatter_offsets, 0, smem) + tma.store_wait(0) + + smem._keep_alive() + + +def get_split_dim(cga_layout, dim): + return 1 << sum(basis[dim] != 0 for basis in cga_layout) + + +@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell") +@pytest.mark.parametrize("cga_layout", [ + [[1, 0]], + [[0, 0], [1, 0]], + [[1, 0], [0, 0]], + [[1, 0], [2, 0]], +]) +def test_tma_gather_scatter_multi_cta(cga_layout): + cga_split_num = [get_split_dim(cga_layout, dim) for dim in range(2)] + + BLOCK_M = 32 * cga_split_num[0] + BLOCK_N = 128 * cga_split_num[1] + + inp = torch.arange(BLOCK_M * BLOCK_N, dtype=torch.float16, device="cuda").reshape(BLOCK_M, BLOCK_N) + gather_idx = torch.arange(BLOCK_M - 1, -1, -1, dtype=torch.int32, device="cuda") + scatter_idx = (torch.arange(0, BLOCK_M, dtype=torch.int32, device="cuda") + 1) % BLOCK_M + gather_out = torch.empty_like(inp) + scatter_out = torch.zeros_like(inp) + + layout = ttgl.NVMMASharedLayout.get_default_for( + [BLOCK_M, BLOCK_N], + ttgl.float16, + cga_layout=cga_layout, + ) + in_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(inp, [1, BLOCK_N // cga_split_num[1]], layout) + gather_out_desc = 
gluon.nvidia.hopper.TensorDescriptor.from_tensor(gather_out, [BLOCK_M, BLOCK_N], layout) + scatter_out_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(scatter_out, [1, BLOCK_N // cga_split_num[1]], + layout) + + offset_layout = ttgl.BlockedLayout([4, 1], [1, 32], [4, 1], [0, 1], cga_layout=cga_layout) + x_offsets_layout = ttgl.SliceLayout(1, offset_layout) + + num_ctas = 1 << len(cga_layout) + compiled = tma_gather_scatter_kernel[(1, )]( + in_desc, + gather_out_desc, + scatter_out_desc, + gather_idx, + scatter_idx, + BLOCK_M, + x_offsets_layout, + num_warps=4, + num_ctas=num_ctas, + ) + + expected_gather = inp[gather_idx.to(torch.int64)] + expected_scatter = torch.zeros_like(inp) + expected_scatter[scatter_idx.to(torch.int64)] = expected_gather + expect_multicast = any(all(coord == 0 for coord in basis) for basis in cga_layout) + assert (".multicast::cluster" in compiled.asm["ptx"]) == expect_multicast + torch.testing.assert_close(gather_out, expected_gather, atol=0, rtol=0) + torch.testing.assert_close(scatter_out, expected_scatter, atol=0, rtol=0) + + @gluon.jit def tcgen05_mma_multicast_commit_kernel(a_desc, b_desc, out_ptrs, BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr, acc_tmem_layout: ttgl.constexpr, blocked_c: ttgl.constexpr): @@ -655,11 +737,13 @@ def test_warpgroup_mma(ASYNC): @gluon.jit -def tma_mma_shared_inputs_kernel(a_desc, b_desc, out_ptr, BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr, - BLOCK_K: ttgl.constexpr, NUM_K_TILES: ttgl.constexpr, block_layout_c: ttgl.constexpr, +def tma_mma_shared_inputs_kernel(a_desc, b_desc, out_ptr, out_desc, gather_idx_ptr, scatter_idx_ptr, + BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr, BLOCK_K: ttgl.constexpr, + NUM_K_TILES: ttgl.constexpr, block_layout_c: ttgl.constexpr, acc_layout: ttgl.constexpr, acc_tmem_layout: ttgl.constexpr, - use_tcgen05: ttgl.constexpr, multicast: ttgl.constexpr): - smem_a = ttgl.allocate_shared_memory(a_desc.dtype, a_desc.block_shape, a_desc.layout) + use_tcgen05: 
ttgl.constexpr, multicast: ttgl.constexpr, + use_gather_scatter: ttgl.constexpr): + smem_a = ttgl.allocate_shared_memory(a_desc.dtype, [BLOCK_M, BLOCK_K], a_desc.layout) smem_b = ttgl.allocate_shared_memory(b_desc.dtype, b_desc.block_shape, b_desc.layout) two_ctas: ttgl.constexpr = isinstance(acc_tmem_layout, TensorMemoryLayout) and acc_tmem_layout.two_ctas @@ -680,9 +764,19 @@ def tma_mma_shared_inputs_kernel(a_desc, b_desc, out_ptr, BLOCK_M: ttgl.constexp else: acc = ttgl.zeros([BLOCK_M, BLOCK_N], dtype=ttgl.float32, layout=acc_layout) + if use_gather_scatter: + gather_offsets_layout: ttgl.constexpr = ttgl.SliceLayout( + 1, ttgl.BlockedLayout([4, 1], [1, 32], [ttgl.num_warps(), 1], [0, 1], cga_layout=a_desc.layout.cga_layout)) + gather_offsets = ttgl.load(gather_idx_ptr + ttgl.arange(0, BLOCK_M, layout=gather_offsets_layout)) + for k in range(NUM_K_TILES): - mbarrier.expect(tma_bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(a_desc, [0, k * BLOCK_K], tma_bar, smem_a, multicast=multicast) + a_nbytes_per_cta: ttgl.constexpr = blackwell_tma.nbytes_per_cta_gather( + a_desc, gather_offsets) if use_gather_scatter else a_desc.nbytes_per_cta + mbarrier.expect(tma_bar, a_nbytes_per_cta + b_desc.nbytes_per_cta) + if use_gather_scatter: + blackwell_tma.async_gather(a_desc, gather_offsets, k * BLOCK_K, tma_bar, smem_a, multicast=multicast) + else: + tma.async_copy_global_to_shared(a_desc, [0, k * BLOCK_K], tma_bar, smem_a, multicast=multicast) tma.async_copy_global_to_shared(b_desc, [k * BLOCK_K, 0], tma_bar, smem_b, multicast=multicast) mbarrier.wait(tma_bar, phase=phase_tma, deps=[smem_a, smem_b]) phase_tma ^= 1 @@ -705,9 +799,19 @@ def tma_mma_shared_inputs_kernel(a_desc, b_desc, out_ptr, BLOCK_M: ttgl.constexp acc = acc_tmem.load() acc = ttgl.convert_layout(acc, block_layout_c) - offs_m = ttgl.arange(0, BLOCK_M)[:, None] - offs_n = ttgl.arange(0, BLOCK_N)[None, :] - ttgl.store(out_ptr + offs_m * BLOCK_N + offs_n, acc) + if 
use_gather_scatter: + scatter_offsets_layout: ttgl.constexpr = ttgl.SliceLayout( + 1, ttgl.BlockedLayout([4, 1], [1, 32], [ttgl.num_warps(), 1], [0, 1], + cga_layout=out_desc.layout.cga_layout)) + scatter_offsets = ttgl.load(scatter_idx_ptr + ttgl.arange(0, BLOCK_M, layout=scatter_offsets_layout)) + acc_smem = ttgl.allocate_shared_memory(out_desc.dtype, [BLOCK_M, BLOCK_N], out_desc.layout, acc) + blackwell_tma.async_scatter(out_desc, scatter_offsets, 0, acc_smem) + tma.store_wait(0) + acc_smem._keep_alive() + else: + offs_m = ttgl.arange(0, BLOCK_M)[:, None] + offs_n = ttgl.arange(0, BLOCK_N)[None, :] + ttgl.store(out_ptr + offs_m * BLOCK_N + offs_n, acc) @pytest.mark.skipif(not (is_hopper() or is_blackwell()), reason="Requires Hopper or Blackwell") @@ -716,7 +820,8 @@ def tma_mma_shared_inputs_kernel(a_desc, b_desc, out_ptr, BLOCK_M: ttgl.constexp @pytest.mark.parametrize("ctas_per_cga", [[1, 1], [2, 1], [4, 4]]) @pytest.mark.parametrize("two_ctas", [False, True] if is_blackwell() else [False]) @pytest.mark.parametrize("multicast", [False, True]) -def test_tma_mma_shared_inputs(warps, reps, ctas_per_cga, two_ctas, multicast): +@pytest.mark.parametrize("use_gather_scatter", [False, True] if is_blackwell() else [False]) +def test_tma_mma_shared_inputs(warps, reps, ctas_per_cga, two_ctas, multicast, use_gather_scatter): bitwidth = 16 acc_dtype = torch.float32 @@ -788,19 +893,26 @@ def cast(x, dtype): gluon_dtype = ttgl.float16 shared_layout_a = ttgl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_K], gluon_dtype, cga_layout=cga_layout_a) shared_layout_b = ttgl.NVMMASharedLayout.get_default_for([BLOCK_K, BLOCK_N], gluon_dtype, cga_layout=cga_layout_b) + shared_layout_c = ttgl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_N], ttgl.float32, cga_layout=cga_layout_c) assert shared_layout_a.swizzle_byte_width != 0 assert shared_layout_b.swizzle_byte_width != 0 - a_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(a, [BLOCK_M, BLOCK_K], shared_layout_a) + a_desc 
= gluon.nvidia.hopper.TensorDescriptor.from_tensor(a, [1 if use_gather_scatter else BLOCK_M, BLOCK_K], + shared_layout_a) b_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(b, [BLOCK_K, BLOCK_N], shared_layout_b) - num_warps = warps[0] * warps[1] num_ctas = ctas_per_cga[0] * ctas_per_cga[1] + out_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(out, [1, BLOCK_N], shared_layout_c) + gather_idx = torch.arange(BLOCK_M - 1, -1, -1, dtype=torch.int32, device=device) + scatter_idx = (torch.arange(0, BLOCK_M, dtype=torch.int32, device=device) + 1) % BLOCK_M try: tma_mma_shared_inputs_kernel[(1, )]( a_desc, b_desc, out, + out_desc, + gather_idx, + scatter_idx, BLOCK_M, BLOCK_N, BLOCK_K, @@ -810,6 +922,7 @@ def cast(x, dtype): acc_tmem_layout, is_blackwell(), multicast=multicast, + use_gather_scatter=use_gather_scatter, num_warps=num_warps, num_ctas=num_ctas, ) @@ -819,7 +932,13 @@ def cast(x, dtype): try: allow_tf32 = torch.backends.cuda.matmul.allow_tf32 torch.backends.cuda.matmul.allow_tf32 = True - ref = torch.matmul(a.to(torch.float32), b.to(torch.float32)) + if use_gather_scatter: + matmul = torch.matmul(a[gather_idx.to(torch.int64)].to(torch.float32), b.to(torch.float32)) + ref = torch.empty_like(matmul) + # Correct as scatter_idx is a permutation! 
+ ref[scatter_idx.to(torch.int64)] = matmul + else: + ref = torch.matmul(a.to(torch.float32), b.to(torch.float32)) finally: torch.backends.cuda.matmul.allow_tf32 = allow_tf32 diff --git a/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py b/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py index 01adc772005c..3bef1c0f1bf9 100644 --- a/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py +++ b/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py @@ -1,4 +1,5 @@ import triton.experimental.gluon.language._core as ttgl +from triton.experimental.gluon._runtime import constexpr_function from triton.experimental.gluon.language._core import builtin from triton.experimental.gluon.language.nvidia.hopper.tma import ( async_copy_global_to_shared, @@ -19,11 +20,21 @@ "tensor_descriptor", "tensor_descriptor_type", "make_tensor_descriptor", + "nbytes_per_cta_gather", ] +@constexpr_function +def nbytes_per_cta_gather(desc, offsets): + num_splits = 1 + for basis in offsets.type.layout.parent.cga_layout: + if basis != [0, 0]: + num_splits *= 2 + return offsets.shape[0] * desc.block_shape[1] * (desc.dtype.primitive_bitwidth // 8) // num_splits + + @builtin -def async_gather(tensor_desc, x_offsets, y_offset, barrier, result, pred=True, _semantic=None): +def async_gather(tensor_desc, x_offsets, y_offset, barrier, result, pred=True, multicast=False, _semantic=None): """ Asynchronously gather elements from global memory to shared memory using TMA. @@ -34,14 +45,16 @@ def async_gather(tensor_desc, x_offsets, y_offset, barrier, result, pred=True, _ barrier (shared_memory_descriptor): Barrier that will be signaled when the operation is complete. result (tensor_memory_descriptor): Result shared memory, must have NVMMASharedLayout. pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True. + multicast (bool): Request a multicast gather (emitted as `.multicast::cluster`). The flag only takes effect when the layout of `result` broadcasts across CTAs; otherwise it is silently dropped. Defaults to False.
""" if _semantic.builder.options.enable_iisan: _emit_alignment_check(tensor_desc, (y_offset, ), "async_gather", "y_offset", _semantic=_semantic) pred = _semantic.to_tensor(pred) y_offset = _semantic.to_tensor(y_offset) + multicast = ttgl._unwrap_if_constexpr(multicast) _semantic.builder.create_async_tma_gather(tensor_desc.handle, x_offsets.handle, y_offset.handle, barrier.handle, - result.handle, pred.handle) + result.handle, pred.handle, multicast) def _emit_scatter_nonnegative_check(x_offsets, y_offset, _semantic=None): diff --git a/test/Conversion/tma_multicast_to_llvm.mlir b/test/Conversion/tma_multicast_to_llvm.mlir new file mode 100644 index 000000000000..4e2e0f91c4ef --- /dev/null +++ b/test/Conversion/tma_multicast_to_llvm.mlir @@ -0,0 +1,27 @@ +// RUN: triton-opt %s --convert-triton-gpu-to-llvm --convert-nv-gpu-to-llvm | mlir-translate -mlir-to-llvmir | opt -S -O1 | FileCheck %s + +#blocked_bcast = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = [[0, 0]]}> +#shared_bar_bcast = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> +#shared_gather_bcast = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CGALayout = [[0, 0]]}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { + +// CHECK-LABEL: @tma_gather_multicast +tt.func public @tma_gather_multicast(%arg0: !tt.tensordesc<1x128xbf16, #shared_gather_bcast>, %arg1: !ttg.memdesc<1xi64, #shared_bar_bcast, #smem, mutable>, %arg2: tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked_bcast}>>, %arg3: i32, %arg4: !ttg.memdesc<32x128xbf16, #shared_gather_bcast, #smem, mutable>, %arg5: i1) { + // CHECK: [[BAR:%.*]] = extractvalue {{.*}} %1, 0 + // CHECK: [[BAR_INT:%.*]] = ptrtoint ptr addrspace(3) [[BAR]] to i64 + // CHECK: [[LEADER_BAR_INT:%.*]] = and i64 [[BAR_INT]], + // CHECK: 
[[LEADER_BAR:%.*]] = inttoptr i64 [[LEADER_BAR_INT]] to ptr addrspace(3) + // CHECK: [[ELECT:%.*]] = tail call { i32, i1 } @llvm.nvvm.elect.sync + // CHECK: [[ELECT_PRED:%.*]] = extractvalue { i32, i1 } [[ELECT]], 1 + // CHECK: [[PRED:%.*]] = and i1 {{.*}}, [[ELECT_PRED]] + // CHECK: "@$0 cp.async.bulk.tensor.2d.tile::gather4.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$1], [$2, {$3, $4, $5, $6, $7}], [$8], $9;", "b,r,l,r,r,r,r,r,r,h" + // CHECK-SAME: (i1 [[PRED]], ptr addrspace(3) {{.*}}, ptr nonnull %0, i32 %3, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr addrspace(3) [[LEADER_BAR]], i32 {{(%[0-9]+|3)}}) + ttng.async_tma_gather %arg0[%arg2, %arg3] %arg4, %arg1, %arg5 {multicast} : !tt.tensordesc<1x128xbf16, #shared_gather_bcast>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked_bcast}>>, i32, !ttg.memdesc<1xi64, #shared_bar_bcast, #smem, mutable>, !ttg.memdesc<32x128xbf16, #shared_gather_bcast, #smem, mutable>, i1 + + // CHECK: ret void + tt.return +} + +} diff --git a/test/Conversion/tritongpu_to_llvm_blackwell.mlir b/test/Conversion/tritongpu_to_llvm_blackwell.mlir index 95be9cb20c57..e94d12d4cf35 100644 --- a/test/Conversion/tritongpu_to_llvm_blackwell.mlir +++ b/test/Conversion/tritongpu_to_llvm_blackwell.mlir @@ -473,7 +473,30 @@ tt.func public @tmem_copy_2d_2cta(%src: !ttg.memdesc<128x32xi8, #shared, #ttg.sh // CHECK-NOT: tcgen05.commit ttng.tmem_copy %src, %dst : !ttg.memdesc<128x32xi8, #shared, #ttg.shared_memory>, !ttg.memdesc<128x32xi8, #tmem_scales, #ttng.tensor_memory, mutable> tt.return + } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = [[0, 0]]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CGALayout = [[0, 0]]}> +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttng.two-ctas" = true} { + // CHECK-LABEL: @tma_scatter_broadcast_two_ctas + // 
CHECK: %[[CTA:.+]] = nvg.cluster_id + // CHECK: %[[MASK:.+]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[CTA_IN_GROUP:.+]] = llvm.and %[[CTA]], %[[MASK]] : i32 + // CHECK: %[[ZERO:.+]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[IS_LEADER:.+]] = llvm.icmp "eq" %[[CTA_IN_GROUP]], %[[ZERO]] : i32 + // CHECK: %[[WARP_PRED:.+]] = llvm.icmp "eq" {{.*}} : i32 + // CHECK: %[[LEADER_WARP_PRED:.+]] = llvm.and %[[IS_LEADER]], %[[WARP_PRED]] : i1 + // CHECK: %[[ELECT:.+]] = nvvm.elect.sync -> i1 + // CHECK: %[[PRED:.+]] = llvm.and %[[LEADER_WARP_PRED]], %[[ELECT]] : i1 + // CHECK: @$0 cp.async.bulk.tensor.2d.tile::scatter4.global.shared::cta.bulk_group + // CHECK-SAME: "b,l,r,r,r,r,r,r" %[[PRED]], + tt.func @tma_scatter_broadcast_two_ctas(%desc: !tt.tensordesc<1x128xbf16, #shared>, %x_offsets: tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, %y_offset: i32, %src: !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable>) { + ttng.async_tma_scatter %desc[%x_offsets, %y_offset] %src : !tt.tensordesc<1x128xbf16, #shared>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable> + tt.return + } } // ----- diff --git a/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/test/Conversion/tritonnvidiagpu_to_llvm.mlir index 654c53e23ecb..348789571b2f 100644 --- a/test/Conversion/tritonnvidiagpu_to_llvm.mlir +++ b/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -184,6 +184,23 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttng.tw // ----- +#blocked_offsets = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = [[0, 1]]}> +#shared0_cluster = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> +#shared1_split = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CGALayout = [[1, 0]]}> +#smem = #ttg.shared_memory +module attributes 
{"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: tma_gather_barrier_mask_nonzero + // Gather uses shared::cluster when barrier mask is non-zero + // CHECK: cp.async.bulk.tensor.2d.tile::gather4.shared::cluster.global.mbarrier::complete_tx::bytes + // CHECK-NOT: cp.async.bulk.tensor.2d.tile::gather4.shared::cta.global.mbarrier + tt.func @tma_gather_barrier_mask_nonzero(%tma: !tt.tensordesc<1x128xbf16, #shared1_split>, %alloc: !ttg.memdesc<32x128xbf16, #shared1_split, #smem, mutable>, %x_offsets: tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked_offsets}>>, %y: i32, %barrier: !ttg.memdesc<1xi64, #shared0_cluster, #smem>, %pred: i1) { + ttng.async_tma_gather %tma[%x_offsets, %y] %alloc, %barrier, %pred : !tt.tensordesc<1x128xbf16, #shared1_split>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked_offsets}>>, i32, !ttg.memdesc<1xi64, #shared0_cluster, #smem>, !ttg.memdesc<32x128xbf16, #shared1_split, #smem, mutable>, i1 + tt.return + } +} + +// ----- + #shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 32}> #smem = #ttg.shared_memory @@ -305,6 +322,25 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // ----- +#shared1_broadcast = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, CGALayout = [[0, 0]]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: tma_copy_local_to_global_broadcast + // CHECK: elect.sync + // CHECK: nvg.cluster_id + // CHECK: llvm.and + // CHECK: llvm.icmp "eq" + // CHECK: "@$0 cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$1, {$2, $3}], [$4];", "b,l,r,r,r" {{.*}} : (i1, !llvm.ptr, i32, i32, !llvm.ptr<3>) -> !llvm.void + // CHECK-NOT: shared::cluster + // CHECK: nvvm.cp.async.bulk.commit.group + tt.func @tma_copy_local_to_global_broadcast(%tma: 
!tt.tensordesc<128x128xf32, #shared1_broadcast>, %alloc: !ttg.memdesc<128x128xf32, #shared1_broadcast, #smem>, %x: i32) { + ttng.async_tma_copy_local_to_global %tma[%x, %x] %alloc : !tt.tensordesc<128x128xf32, #shared1_broadcast>, !ttg.memdesc<128x128xf32, #shared1_broadcast, #smem> + tt.return + } +} + +// ----- + module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: async_tma_store_wait // CHECK: nvvm.cp.async.bulk.wait_group 0 {read} diff --git a/test/TritonGPU/automatic-warp-specialization.mlir b/test/TritonGPU/automatic-warp-specialization.mlir index 41e94686a542..5bebd300155f 100644 --- a/test/TritonGPU/automatic-warp-specialization.mlir +++ b/test/TritonGPU/automatic-warp-specialization.mlir @@ -2,7 +2,8 @@ // RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -tritongpu-hoist-tmem-alloc -tritongpu-assign-latencies -tritongpu-schedule-loops -tritongpu-automatic-warp-specialization=num-stages=2 -tritongpu-pipeline | FileCheck %s --check-prefix=CHECK --check-prefix=PIPELINE // RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -tritongpu-hoist-tmem-alloc -tritongpu-assign-latencies -tritongpu-schedule-loops -tritongpu-automatic-warp-specialization=num-stages=2 -tritongpu-pipeline -tritongpu-optimize-partition-warps | FileCheck %s --check-prefix=OPT -#indices_layout = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#indices_layout_parent = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#indices_layout = #ttg.slice<{dim = 0, parent = #indices_layout_parent}> #acc_layout = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> #oper_layout = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> #b_layout = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 
0]}> diff --git a/test/TritonGPU/gsan.mlir b/test/TritonGPU/gsan.mlir index 4dc84dff6f16..de173896f21a 100644 --- a/test/TritonGPU/gsan.mlir +++ b/test/TritonGPU/gsan.mlir @@ -63,7 +63,8 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} // ----- -#blocked_rows = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked_rows_parent = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked_rows = #ttg.slice<{dim = 0, parent = #blocked_rows_parent}> #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}> #bar = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> #smem = #ttg.shared_memory diff --git a/test/TritonGPU/loop-pipeline-blackwell.mlir b/test/TritonGPU/loop-pipeline-blackwell.mlir index f52158eed27f..023d0285f9bd 100644 --- a/test/TritonGPU/loop-pipeline-blackwell.mlir +++ b/test/TritonGPU/loop-pipeline-blackwell.mlir @@ -168,7 +168,8 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, ttg.targ // ----- #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1_parent = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked1 = #ttg.slice<{dim = 0, parent = #blocked1_parent}> #blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> #blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index 
f556c82e6452..19d37d429ba2 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -472,7 +472,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // ----- #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1_parent = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked1 = #ttg.slice<{dim = 0, parent = #blocked1_parent}> #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tma_scatter_pipeline diff --git a/test/TritonGPU/pipeline-lower-loop.mlir b/test/TritonGPU/pipeline-lower-loop.mlir index a0837216c1fe..8e38f0c4281a 100644 --- a/test/TritonGPU/pipeline-lower-loop.mlir +++ b/test/TritonGPU/pipeline-lower-loop.mlir @@ -779,7 +779,8 @@ tt.func @tma_load_lowering(%lb : index, %ub : index, %step : index, // ----- #A = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -#offsets = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#offsets_parent = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#offsets = #ttg.slice<{dim = 0, parent = #offsets_parent}> #nvmma_128 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { diff --git a/test/TritonNvidiaGPU/invalid.mlir b/test/TritonNvidiaGPU/invalid.mlir index eb7939473ec4..892ecf80dabd 100644 --- a/test/TritonNvidiaGPU/invalid.mlir +++ 
b/test/TritonNvidiaGPU/invalid.mlir @@ -654,6 +654,66 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // ----- +#blocked_broadcast = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CGALayout = [[0]]}> +#nvmma_no_broadcast = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CGALayout = [[1, 0]]}> +#shared_bar = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @async_tma_gather_multicast_requires_broadcast(%arg0: !tt.tensordesc<1x128xf16, #nvmma_no_broadcast>) { + %true = arith.constant true + %c0_i32 = arith.constant 0 : i32 + %x_offsets = arith.constant dense<0> : tensor<32xi32, #blocked_broadcast> + %bar = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared_bar, #smem, mutable> + %result = ttg.local_alloc : () -> !ttg.memdesc<32x128xf16, #nvmma_no_broadcast, #smem, mutable> + // expected-error @below {{multicast requires the shared layout to broadcast across CTAs}} + ttng.async_tma_gather %arg0[%x_offsets, %c0_i32] %result, %bar, %true {multicast} : !tt.tensordesc<1x128xf16, #nvmma_no_broadcast>, tensor<32xi32, #blocked_broadcast>, i32, !ttg.memdesc<1xi64, #shared_bar, #smem, mutable>, !ttg.memdesc<32x128xf16, #nvmma_no_broadcast, #smem, mutable>, i1 + tt.return + } +} + +// ----- + +#blocked_broadcast_parent = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = [[0, 0], [0, 0]]}> +#blocked_broadcast = #ttg.slice<{dim = 0, parent = #blocked_broadcast_parent}> +#nvmma_partial_broadcast = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CGALayout = [[1, 0], [0, 0]]}> +#shared_bar = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 
1, order = [0], CGALayout = [[0], [0]]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @async_tma_gather_multicast_requires_matching_x_offset_cga(%arg0: !tt.tensordesc<1x128xf16, #nvmma_partial_broadcast>) { + %true = arith.constant true + %c0_i32 = arith.constant 0 : i32 + %x_offsets = arith.constant dense<0> : tensor<32xi32, #blocked_broadcast> + %bar = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared_bar, #smem, mutable> + %result = ttg.local_alloc : () -> !ttg.memdesc<32x128xf16, #nvmma_partial_broadcast, #smem, mutable> + // expected-error @below {{x offsets must have the same row CGA layout as the memdesc}} + ttng.async_tma_gather %arg0[%x_offsets, %c0_i32] %result, %bar, %true {multicast} : !tt.tensordesc<1x128xf16, #nvmma_partial_broadcast>, tensor<32xi32, #blocked_broadcast>, i32, !ttg.memdesc<1xi64, #shared_bar, #smem, mutable>, !ttg.memdesc<32x128xf16, #nvmma_partial_broadcast, #smem, mutable>, i1 + tt.return + } +} + +// ----- + +#blocked_split_parent = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = [[0, 1]]}> +#blocked_split = #ttg.slice<{dim = 0, parent = #blocked_split_parent}> +#nvmma_broadcast = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CGALayout = [[0, 0]]}> +#shared_bar = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @async_tma_gather_multicast_requires_uniform_x_offsets(%arg0: !tt.tensordesc<1x128xf16, #nvmma_broadcast>) { + %true = arith.constant true + %c0_i32 = arith.constant 0 : i32 + %x_offsets = arith.constant dense<0> : tensor<32xi32, #blocked_split> + %bar = 
ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared_bar, #smem, mutable> + %result = ttg.local_alloc : () -> !ttg.memdesc<32x128xf16, #nvmma_broadcast, #smem, mutable> + // expected-error @below {{x offsets must have the same row CGA layout as the memdesc}} + ttng.async_tma_gather %arg0[%x_offsets, %c0_i32] %result, %bar, %true {multicast} : !tt.tensordesc<1x128xf16, #nvmma_broadcast>, tensor<32xi32, #blocked_split>, i32, !ttg.memdesc<1xi64, #shared_bar, #smem, mutable>, !ttg.memdesc<32x128xf16, #nvmma_broadcast, #smem, mutable>, i1 + tt.return + } + +} + +// ----- + // Test invalid TensorDescIm2ColType: rank-3 blockType (must be rank-2) module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // expected-error @below {{TensorDescIm2ColType requires rank-2 shape, got rank 3}} @@ -774,6 +834,25 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // ----- +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared_bar = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @async_tma_gather_requires_legal_x_offsets(%arg0: !tt.tensordesc<1x128xf16, #shared>) { + %true = arith.constant true + %c0_i32 = arith.constant 0 : i32 + %x_offsets = arith.constant dense<0> : tensor<32xi32, #blocked> + %bar = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared_bar, #smem, mutable> + %result = ttg.local_alloc : () -> !ttg.memdesc<32x128xf16, #shared, #smem, mutable> + // expected-error @below {{x offsets must have at least 4 contiguous elements per thread}} + ttng.async_tma_gather %arg0[%x_offsets, %c0_i32] %result, %bar, %true : !tt.tensordesc<1x128xf16, #shared>, tensor<32xi32, 
#blocked>, i32, !ttg.memdesc<1xi64, #shared_bar, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared, #smem, mutable>, i1 + tt.return + } +} + +// ----- + #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 64}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { diff --git a/test/TritonNvidiaGPU/ops.mlir b/test/TritonNvidiaGPU/ops.mlir index 84c8e31af7bd..cc3d48f19979 100644 --- a/test/TritonNvidiaGPU/ops.mlir +++ b/test/TritonNvidiaGPU/ops.mlir @@ -7,7 +7,8 @@ #tmem_int32 = #ttng.tensor_memory_encoding #tmem_scales = #ttng.tensor_memory_scales_encoding<> -#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#offsets = #ttg.slice<{dim = 0, parent = #blocked}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [0, 1]}> #scales = #ttg.linear<{register = [[0, 1], [0, 2], [32, 0], [64, 0], [0, 4]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}> @@ -67,12 +68,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-SAME: [[BAR:%arg[0-9]+]]: // CHECK-SAME: [[RESULT:%arg[0-9]+]]: // CHECK-SAME: [[PRED:%arg[0-9]+]]: - tt.func @async_tma_gather(%desc: !tt.tensordesc<1x128xbf16, #shared>, %x_offsets: tensor<32xi32, #blocked>, %y_offset: i32, + tt.func @async_tma_gather(%desc: !tt.tensordesc<1x128xbf16, #shared>, %x_offsets: tensor<32xi32, #offsets>, %y_offset: i32, %bar: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>, %result: !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable>, %pred: i1) { - // CHECK-NEXT: ttng.async_tma_gather [[DESC]][[[X_OFFSETS]], [[Y_OFFSET]]] [[RESULT]], [[BAR]], [[PRED]] : 
!tt.tensordesc<1x128xbf16, #shared>, tensor<32xi32, #blocked>, i32, !ttg.memdesc<1xi64, #shared2, #smem, mutable>, !ttg.memdesc<32x128xbf16, #shared, #smem, mutable>, i1 - ttng.async_tma_gather %desc[%x_offsets, %y_offset] %result, %bar, %pred : !tt.tensordesc<1x128xbf16, #shared>, tensor<32xi32, #blocked>, i32, !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>, !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable>, i1 + // CHECK-NEXT: ttng.async_tma_gather [[DESC]][[[X_OFFSETS]], [[Y_OFFSET]]] [[RESULT]], [[BAR]], [[PRED]] : !tt.tensordesc<1x128xbf16, #shared>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<1xi64, #shared2, #smem, mutable>, !ttg.memdesc<32x128xbf16, #shared, #smem, mutable>, i1 + ttng.async_tma_gather %desc[%x_offsets, %y_offset] %result, %bar, %pred : !tt.tensordesc<1x128xbf16, #shared>, tensor<32xi32, #offsets>, i32, !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>, !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable>, i1 tt.return } @@ -81,10 +82,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-SAME: [[X_OFFSETS:%arg[0-9]+]]: // CHECK-SAME: [[Y_OFFSET:%arg[0-9]+]]: // CHECK-SAME: [[SRC:%arg[0-9]+]]: - tt.func @async_tma_scatter(%desc: !tt.tensordesc<1x128xbf16, #shared>, %x_offsets: tensor<32xi32, #blocked>, %y_offset: i32, + tt.func @async_tma_scatter(%desc: !tt.tensordesc<1x128xbf16, #shared>, %x_offsets: tensor<32xi32, #offsets>, %y_offset: i32, %src: !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable>) { - // CHECK-NEXT: ttng.async_tma_scatter [[DESC]][[[X_OFFSETS]], [[Y_OFFSET]]] [[SRC]] : !tt.tensordesc<1x128xbf16, #shared>, tensor<32xi32, #blocked>, i32, !ttg.memdesc<32x128xbf16, #shared, #smem, mutable> - ttng.async_tma_scatter %desc[%x_offsets, %y_offset] %src : !tt.tensordesc<1x128xbf16, #shared>, tensor<32xi32, #blocked>, i32, !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable> + // CHECK-NEXT: 
ttng.async_tma_scatter [[DESC]][[[X_OFFSETS]], [[Y_OFFSET]]] [[SRC]] : !tt.tensordesc<1x128xbf16, #shared>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<32x128xbf16, #shared, #smem, mutable> + ttng.async_tma_scatter %desc[%x_offsets, %y_offset] %src : !tt.tensordesc<1x128xbf16, #shared>, tensor<32xi32, #offsets>, i32, !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable> tt.return } diff --git a/test/TritonNvidiaGPU/tma_lowering.mlir b/test/TritonNvidiaGPU/tma_lowering.mlir index ef87b0d65317..2ceb3fb14d31 100644 --- a/test/TritonNvidiaGPU/tma_lowering.mlir +++ b/test/TritonNvidiaGPU/tma_lowering.mlir @@ -70,14 +70,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // ----- -#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#offsets = #ttg.slice<{dim = 0, parent = #blocked}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> #nvmma_128 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-LABEL: @tma_gather -tt.func @tma_gather(%arg0: !tt.tensordesc<1x128xbf16, #nvmma_128>, %arg1: tensor<32xi32, #blocked>, %arg2: i32) -> tensor<32x128xbf16, #blocked1> { +tt.func @tma_gather(%arg0: !tt.tensordesc<1x128xbf16, #nvmma_128>, %arg1: tensor<32xi32, #offsets>, %arg2: i32) -> tensor<32x128xbf16, #blocked1> { // CHECK: [[RESULT:%.*]] = ttg.local_alloc // CHECK: [[BARRIER:%.*]] = ttg.local_alloc // CHECK: ttng.init_barrier [[BARRIER]] @@ -85,18 +86,18 @@ tt.func @tma_gather(%arg0: !tt.tensordesc<1x128xbf16, #nvmma_128>, %arg1: tensor // CHECK: ttng.wait_barrier [[BARRIER]] // CHECK: ttng.inval_barrier [[BARRIER]] // CHECK: 
[[OUT:%.*]] = ttg.local_load [[RESULT]] - %0 = tt.descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc<1x128xbf16, #nvmma_128>, tensor<32xi32, #blocked>, i32) -> tensor<32x128xbf16, #blocked1> + %0 = tt.descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc<1x128xbf16, #nvmma_128>, tensor<32xi32, #offsets>, i32) -> tensor<32x128xbf16, #blocked1> // CHECK: return [[OUT]] tt.return %0 : tensor<32x128xbf16, #blocked1> } // CHECK-LABEL: @tma_scatter -tt.func @tma_scatter(%arg0: !tt.tensordesc<1x128xbf16, #nvmma_128>, %arg1: tensor<32xi32, #blocked>, %arg2: i32, %arg3: tensor<32x128xbf16, #blocked1>) { +tt.func @tma_scatter(%arg0: !tt.tensordesc<1x128xbf16, #nvmma_128>, %arg1: tensor<32xi32, #offsets>, %arg2: i32, %arg3: tensor<32x128xbf16, #blocked1>) { // CHECK-NEXT: [[SRC:%.*]] = ttg.local_alloc %arg3 // CHECK-NEXT: ttng.fence_async_shared {bCluster = false} // CHECK-NEXT: ttng.async_tma_scatter %arg0[%arg1, %arg2] [[SRC]] // CHECK-NEXT: ttng.async_tma_store_wait - tt.descriptor_scatter %arg0[%arg1, %arg2], %arg3 : !tt.tensordesc<1x128xbf16, #nvmma_128>, tensor<32xi32, #blocked>, i32, tensor<32x128xbf16, #blocked1> + tt.descriptor_scatter %arg0[%arg1, %arg2], %arg3 : !tt.tensordesc<1x128xbf16, #nvmma_128>, tensor<32xi32, #offsets>, i32, tensor<32x128xbf16, #blocked1> tt.return } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index ab64045830e5..4b4bd0b5417b 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1101,6 +1101,7 @@ static LinearLayout getMsgToPackedOffsetLayout(ttg::MemDescType ty, ttg::TMAMode mode) { auto ctx = ty.getContext(); auto kMsg = str_attr("msg"); + auto kBlock = str_attr("block"); auto shapePerCTA = ttg::getShapePerCTA(ty); int rank = shapePerCTA.size(); auto blockShape = ttng::getTMABlockShape(ty, /*packedSize=*/true, mode); @@ -1210,9 
+1211,9 @@ struct AsyncTMACopyGlobalToLocalOpConversion uint32_t barrierMask = toLinearLayout(barrierTy).getFreeVariableMasks().lookup(kBlock); - // We emit a cluster-level barrier when the barrier mask is set. - bool clusterBarrier = barrierMask != 0; - if (clusterBarrier) { + // Use a cross-CTA mbarrier pointer when the barrier mask is set. + bool crossCTABarrier = barrierMask != 0; + if (crossCTABarrier) { barrierPtr = LLVM::NVIDIA::getLeaderAddress(loc, rewriter, barrierPtr, barrierTy); } @@ -1240,7 +1241,7 @@ struct AsyncTMACopyGlobalToLocalOpConversion Value copyIdxVal = b.add(warpID, b.i32_val(copyIdx)); Value shMemOffset = applyLinearLayout(loc, rewriter, msgToShared, - {{kMsg, copyIdxVal}, {kBlock, zero}})[0] + {{kMsg, copyIdxVal}, {kBlock, ctaId}})[0] .second; Value shMemPtr = b.gep(elemPtrTy, llvmElemTy, dstBase, shMemOffset); SmallVector operands = { @@ -1249,7 +1250,7 @@ struct AsyncTMACopyGlobalToLocalOpConversion ptxBuilderTMA.newOperand(adaptor.getDesc(), "l")}; std::string tmaInst = "@$0 cp.async.bulk.tensor." + std::to_string(rank) + "d." + ctaGroup + - "shared::" + ((clusterBarrier || multicast) ? "cluster" : "cta") + + "shared::" + ((crossCTABarrier || multicast) ? "cluster" : "cta") + ".global" + (isIm2Col ? ".im2col" : "") + ".mbarrier::complete_tx::bytes"; if (multicast) @@ -1344,6 +1345,14 @@ LogicalResult convertTMAStoreLikeOp(Operation *op, auto numCopies = msgToOffset.getInDimSize(kMsg); auto zero = b.i32_val(0); auto ctaId = nvgpu::ClusterCTAIdOp::create(rewriter, loc); + uint32_t maskCGABroadcast = smemLayout.getFreeVariableMasks().lookup(kBlock); + if (maskCGABroadcast != 0) { + // Stores and reductions operate from CTA-local shared memory, so if the + // source tile is broadcast across CTAs only the lead CTA should issue the + // TMA message. 
+ Value ctaIdInGroup = b.and_(ctaId, b.i32_val(maskCGABroadcast)); + pred = b.and_(pred, b.icmp_eq(ctaIdInGroup, zero)); + } for (int copyIdx = 0; copyIdx < numCopies; copyIdx += numWarps) { int numWarpsToCopy = std::min(numCopies - copyIdx, numWarps); @@ -1356,7 +1365,7 @@ LogicalResult convertTMAStoreLikeOp(Operation *op, Value copyIdxVal = b.add(warpID, b.i32_val(copyIdx)); Value shMemOffset = applyLinearLayout(loc, rewriter, msgToShared, - {{kMsg, copyIdxVal}, {kBlock, zero}})[0] + {{kMsg, copyIdxVal}, {kBlock, ctaId}})[0] .second; Value shMemPtr = b.gep(elemPtrTy, llvmElemTy, dstBase, shMemOffset); SmallVector operands = { @@ -1527,6 +1536,8 @@ static LogicalResult iterateGatherScatterIndices( LinearLayout msgToCol = LinearLayout::strided1D(numMessagesPerRow, msgSize, kMsg, kDim1); LinearLayout msgLayout = xCoordsLayout * msgToCol; + LinearLayout msgToOffset = + getMsgToPackedOffsetLayout(smemType, ttg::TMAMode::Tiled); // `gather4` will put the segments of the 4 rows consecutively in // shared memory. However, if the 4 rows are smaller than the shared memory @@ -1545,6 +1556,10 @@ static LogicalResult iterateGatherScatterIndices( Value warpId = mlir::triton::gpu::WarpIdOp::create(rewriter, loc); Value blockId = nvgpu::ClusterCTAIdOp::create(rewriter, loc); + auto ctaOffsets = applyLinearLayout( + loc, rewriter, msgToOffset, {{kMsg, b.i32_val(0)}, {kBlock, blockId}}); + assert(ctaOffsets.size() == 2 && ctaOffsets.back().first == kDim1); + yOffsetValue = b.add(yOffsetValue, ctaOffsets.back().second); // Mask out warps with redundant x offsets. 
pred = b.and_(pred, @@ -1598,19 +1613,55 @@ LogicalResult AsyncTMAGatherOpConversion::matchAndRewrite( triton::nvidia_gpu::AsyncTMAGatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); LLVM::LLVMVoidType voidTy = void_ty(op->getContext()); + auto barrierTy = op.getBarrier().getType(); auto barrierMemObj = LLVM::getSharedMemoryObjectFromStruct( loc, adaptor.getBarrier(), - typeConverter->convertType(op.getBarrier().getType().getElementType()), - rewriter); + typeConverter->convertType(barrierTy.getElementType()), rewriter); + + auto kBlock = StringAttr::get(op.getContext(), "block"); + bool multicast = op.getMulticast(); + Value pred = adaptor.getPred(); + Value multicastMask; + if (multicast) { + uint32_t maskCGABroadcast = ttg::toLinearLayout(op.getResult().getType()) + .getFreeVariableMasks() + .lookup(kBlock); + multicastMask = + LLVM::NVIDIA::createTMAMulticastMask(loc, rewriter, maskCGABroadcast); + Value ctaId = nvgpu::ClusterCTAIdOp::create(rewriter, loc); + Value ctaIdInGroup = b.and_(ctaId, b.i32_val(maskCGABroadcast)); + pred = b.and_(pred, b.icmp_eq(ctaIdInGroup, b.i32_val(0))); + } + + Value barrierPtr = barrierMemObj.getBase(); + bool crossCTABarrier = + toLinearLayout(barrierTy).getFreeVariableMasks().lookup(kBlock) != 0; + if (crossCTABarrier) { + barrierPtr = + LLVM::NVIDIA::getLeaderAddress(loc, rewriter, barrierPtr, barrierTy); + } + + std::string ctaGroup; + if (getModuleTwoCTAs(op)) + ctaGroup = ".cta_group::2"; // Callback to generate the gather4 instruction. 
auto callback = [&](Value pred, Value shMemPtr, Value yOffset, ArrayRef xOffsets) { - std::string tmaInst = "@$0 cp.async.bulk.tensor.2d.tile::gather4.shared" - "::cta.global.mbarrier::complete_tx::bytes " - "[$1], [$2, {$3, $4, $5, $6, $7}], [$8];"; + std::string tmaInst = "@$0 cp.async.bulk.tensor.2d.tile::gather4"; + tmaInst += ctaGroup; + tmaInst += ".shared::"; + tmaInst += (crossCTABarrier || multicast) ? "cluster" : "cta"; + tmaInst += ".global.mbarrier::complete_tx::bytes"; + if (multicast) + tmaInst += ".multicast::cluster"; + tmaInst += " [$1], [$2, {$3, $4, $5, $6, $7}], [$8]"; + if (multicast) + tmaInst += ", $9"; + tmaInst += ";"; PTXBuilder ptxBuilder; SmallVector operands{ @@ -1623,7 +1674,9 @@ LogicalResult AsyncTMAGatherOpConversion::matchAndRewrite( }; for (Value xOffset : xOffsets) operands.push_back(ptxBuilder.newOperand(xOffset, "r")); - operands.push_back(ptxBuilder.newOperand(barrierMemObj.getBase(), "r")); + operands.push_back(ptxBuilder.newOperand(barrierPtr, "r")); + if (multicast) + operands.push_back(ptxBuilder.newOperand(multicastMask, "h")); auto &tma = *ptxBuilder.create(tmaInst); tma(operands, /*attachOnlyMLIRArgs=*/true); @@ -1633,7 +1686,7 @@ LogicalResult AsyncTMAGatherOpConversion::matchAndRewrite( if (failed(iterateGatherScatterIndices( op, rewriter, *getTypeConverter(), op.getXOffsets(), op.getResult(), adaptor.getResult(), adaptor.getXOffsets(), adaptor.getYOffset(), - adaptor.getPred(), callback))) + pred, callback))) return failure(); rewriter.eraseOp(op); @@ -1655,6 +1708,18 @@ LogicalResult AsyncTMAScatterOpConversion::matchAndRewrite( Location loc = op.getLoc(); auto b = TritonLLVMOpBuilder(loc, rewriter); LLVM::LLVMVoidType voidTy = void_ty(op->getContext()); + auto kBlock = StringAttr::get(op.getContext(), "block"); + Value pred = b.true_val(); + uint32_t maskCGABroadcast = ttg::toLinearLayout(op.getSrc().getType()) + .getFreeVariableMasks() + .lookup(kBlock); + if (maskCGABroadcast != 0) { + // `scatter4` only reads 
from the current CTA's shared memory, so if the + // source is broadcast across CTAs only the lead CTA should issue it. + Value ctaId = nvgpu::ClusterCTAIdOp::create(rewriter, loc); + Value ctaIdInGroup = b.and_(ctaId, b.i32_val(maskCGABroadcast)); + pred = b.icmp_eq(ctaIdInGroup, b.i32_val(0)); + } // Callback to generate the scatter4 instruction. auto callback = [&](Value pred, Value shMemPtr, Value yOffset, @@ -1682,8 +1747,8 @@ LogicalResult AsyncTMAScatterOpConversion::matchAndRewrite( if (failed(iterateGatherScatterIndices( op, rewriter, *getTypeConverter(), op.getXOffsets(), op.getSrc(), - adaptor.getSrc(), adaptor.getXOffsets(), adaptor.getYOffset(), - /*pred=*/b.true_val(), callback))) + adaptor.getSrc(), adaptor.getXOffsets(), adaptor.getYOffset(), pred, + callback))) return failure(); // TODO: Separate the syncronizations operations into separate TTGIR ops to From 8523533d1dfc1dd1009b93f6b7e048e2a7c3357e Mon Sep 17 00:00:00 2001 From: lezcano Date: Wed, 15 Apr 2026 18:09:34 +0200 Subject: [PATCH 2/3] better computation --- python/test/gluon/test_core.py | 6 ++--- .../experimental/gluon/language/_core.py | 23 ++++++++++++++++++- .../gluon/language/nvidia/blackwell/tma.py | 11 --------- 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/python/test/gluon/test_core.py b/python/test/gluon/test_core.py index ad2b8d39f8f1..168493b1a654 100644 --- a/python/test/gluon/test_core.py +++ b/python/test/gluon/test_core.py @@ -272,7 +272,7 @@ def tma_gather_scatter_kernel(in_desc, gather_out_desc, scatter_out_desc, gather mbarrier.init(bar, count=1) gather_offsets = ttgl.load(gather_idx_ptr + ttgl.arange(0, BLOCK_M, layout=x_offsets_layout)) - mbarrier.expect(bar, blackwell_tma.nbytes_per_cta_gather(in_desc, gather_offsets)) + mbarrier.expect(bar, smem.nbytes_per_cta) blackwell_tma.async_gather(in_desc, gather_offsets, 0, bar, smem, multicast=True) mbarrier.wait(bar, phase=0, deps=[smem]) @@ -770,9 +770,7 @@ def tma_mma_shared_inputs_kernel(a_desc, 
b_desc, out_ptr, out_desc, gather_idx_p gather_offsets = ttgl.load(gather_idx_ptr + ttgl.arange(0, BLOCK_M, layout=gather_offsets_layout)) for k in range(NUM_K_TILES): - a_nbytes_per_cta: ttgl.constexpr = blackwell_tma.nbytes_per_cta_gather( - a_desc, gather_offsets) if use_gather_scatter else a_desc.nbytes_per_cta - mbarrier.expect(tma_bar, a_nbytes_per_cta + b_desc.nbytes_per_cta) + mbarrier.expect(tma_bar, smem_a.nbytes_per_cta + smem_b.nbytes_per_cta) if use_gather_scatter: blackwell_tma.async_gather(a_desc, gather_offsets, k * BLOCK_K, tma_bar, smem_a, multicast=multicast) else: diff --git a/python/triton/experimental/gluon/language/_core.py b/python/triton/experimental/gluon/language/_core.py index 3979492bc278..db8196fd0c96 100644 --- a/python/triton/experimental/gluon/language/_core.py +++ b/python/triton/experimental/gluon/language/_core.py @@ -9,7 +9,8 @@ from triton._C.libtriton.gluon_ir import GluonOpBuilder from ._semantic import GluonSemantic -from ._layouts import SharedLayout, DistributedLayout, BlockedLayout, DotOperandLayout, AutoLayout, CoalescedLayout +from ._layouts import (SharedLayout, DistributedLayout, BlockedLayout, DotOperandLayout, AutoLayout, CoalescedLayout, + SharedLinearLayout, _get_shape_per_cta) from triton._C.libtriton import ir import triton.language.core as tl_core from triton.language.core import ( @@ -207,6 +208,22 @@ def _flatten_ir_types(self, builder: GluonOpBuilder, out: List[ir.type]) -> None def __str__(self) -> str: return f"shared_memory_descriptor<{self.element_ty}, {self.shape}, {self.layout}, {self.alloc_shape}>" + @property + def nbytes_per_cta(self) -> int: + cga_layout = getattr(self.layout, "cga_layout", []) + if isinstance(self.layout, SharedLinearLayout): + cga_layout = [] + dim_bases = [0] * len(self.shape) + for basis in self.layout.block_bases: + cga_basis = [0] * len(self.shape) + for dim, value in enumerate(basis): + if value != 0: + cga_basis[dim] = 1 << dim_bases[dim] + dim_bases[dim] += 1 + 
cga_layout.append(cga_basis) + shape_per_cta = _get_shape_per_cta(self.shape, cga_layout) + return math.prod(shape_per_cta) * self.element_ty.primitive_bitwidth // 8 + def __eq__(self, other) -> bool: return (type(self) is type(other) and self.shape == other.shape and self.layout == other.layout and self.alloc_shape == other.alloc_shape) @@ -251,6 +268,10 @@ def rank(self): def numel(self) -> int: return math.prod(self.shape) + @property + def nbytes_per_cta(self) -> int: + return self.type.nbytes_per_cta + @property def layout(self): return self.type.layout diff --git a/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py b/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py index 3bef1c0f1bf9..8bed8c74b83a 100644 --- a/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py +++ b/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py @@ -1,5 +1,4 @@ import triton.experimental.gluon.language._core as ttgl -from triton.experimental.gluon._runtime import constexpr_function from triton.experimental.gluon.language._core import builtin from triton.experimental.gluon.language.nvidia.hopper.tma import ( async_copy_global_to_shared, @@ -20,19 +19,9 @@ "tensor_descriptor", "tensor_descriptor_type", "make_tensor_descriptor", - "nbytes_per_cta_gather", ] -@constexpr_function -def nbytes_per_cta_gather(desc, offsets): - num_splits = 1 - for basis in offsets.type.layout.parent.cga_layout: - if basis != [0, 0]: - num_splits *= 2 - return offsets.shape[0] * desc.block_shape[1] * (desc.dtype.primitive_bitwidth // 8) // num_splits - - @builtin def async_gather(tensor_desc, x_offsets, y_offset, barrier, result, pred=True, multicast=False, _semantic=None): """ From 052a53ed62ae96a96df23725e4ef3d509c4f233b Mon Sep 17 00:00:00 2001 From: lezcano Date: Thu, 16 Apr 2026 14:36:58 +0200 Subject: [PATCH 3/3] more explicit --- python/triton/experimental/gluon/language/_core.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/python/triton/experimental/gluon/language/_core.py b/python/triton/experimental/gluon/language/_core.py index db8196fd0c96..eb1e1960afa9 100644 --- a/python/triton/experimental/gluon/language/_core.py +++ b/python/triton/experimental/gluon/language/_core.py @@ -210,7 +210,6 @@ def __str__(self) -> str: @property def nbytes_per_cta(self) -> int: - cga_layout = getattr(self.layout, "cga_layout", []) if isinstance(self.layout, SharedLinearLayout): cga_layout = [] dim_bases = [0] * len(self.shape) @@ -221,6 +220,8 @@ def nbytes_per_cta(self) -> int: cga_basis[dim] = 1 << dim_bases[dim] dim_bases[dim] += 1 cga_layout.append(cga_basis) + else: + cga_layout = self.layout.cga_layout shape_per_cta = _get_shape_per_cta(self.shape, cga_layout) return math.prod(shape_per_cta) * self.element_ty.primitive_bitwidth // 8