From 9a5ab0baf54187f0548249ddad61fe18bf41cdb0 Mon Sep 17 00:00:00 2001 From: lezcano Date: Fri, 1 May 2026 19:58:04 +0200 Subject: [PATCH 1/2] [MULTICTA] Fix multicast pattern for tcgen05_mma_scaled The helper we had was missing the `two_ctas` flag needed to compute the number of arrivals correctly for an MMA that goes into a multicast TMA. I changed the API across all call sites. I think it would be much better if we just had a TTNG_InitMmaBarrierOp as proposed in https://github.com/triton-lang/triton/pull/9957, as this would allow us to get actual perf with multicast and would make the API much cleaner, but for now we just have this helper. What's nice is that this bug was found using the multi-CTA consan :D --- lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp | 3 +- python/examples/gluon/03-matmul-multicta.py | 17 +++++----- .../gluon/04-2cta-block-scale-matmul.py | 29 +++++++++------- python/test/gluon/test_core.py | 16 +++++---- .../language/nvidia/blackwell/__init__.py | 34 ++++++++----------- python/tutorials/gluon/14-multicta.py | 24 +++++++------ 6 files changed, 64 insertions(+), 59 deletions(-) diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index 5ca0be5663b9..2c99559c196b 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -928,8 +928,7 @@ ValueRange TCGen5MMAOp::getCompletionBarrierPreds() { static void appendMulticastDesc(SmallVectorImpl &descs, TypedValue desc) { - if (isa(desc.getType().getEncoding())) - descs.push_back(desc); + descs.push_back(desc); } SmallVector TCGen5MMAOp::getCompletionDescs() { diff --git a/python/examples/gluon/03-matmul-multicta.py b/python/examples/gluon/03-matmul-multicta.py index bae3a0737a95..3ad464525952 100644 --- a/python/examples/gluon/03-matmul-multicta.py +++ b/python/examples/gluon/03-matmul-multicta.py @@ -483,8 +483,16 @@ def _matmul_kernel( dtype: gl.constexpr = a_desc.dtype a_bufs = gl.allocate_shared_memory(dtype, [STAGES] + 
a_desc.block_shape, a_desc.layout) b_bufs = gl.allocate_shared_memory(dtype, [STAGES] + b_desc.block_shape, b_desc.layout) + tmem_layout: gl.constexpr = TensorMemoryLayout( + [BLOCK_SIZE_M, BLOCK_N // get_split_dim(CGA_LAYOUT, 1)], + col_stride=1, + cga_layout=CGA_LAYOUT, + two_ctas=TWO_CTAS, + ) + acc_bufs = allocate_tensor_memory(gl.float32, [ACC_STAGES, BLOCK_M, BLOCK_N], tmem_layout) # Number of CTAs that will arrive on the barrier from a tcgen05_commit after an MMA instruction - mma_barrier_count: gl.constexpr = tcgen05_mma_barrier_count([a_bufs.index(0), b_bufs.index(0)], multicast=True) + mma_barrier_count: gl.constexpr = tcgen05_mma_barrier_count([a_bufs.index(0), b_bufs.index(0)], multicast=True, + two_ctas=acc_bufs.index(0).type.layout.two_ctas) # Equiv. consumed_barrier. Barrier TCGEN05 MMA -> Load TMA load_empty_bars = mbarrier.allocate_mbarrier(batch=STAGES) @@ -494,13 +502,6 @@ def _matmul_kernel( mbarrier.init(load_empty_bars.index(i), count=mma_barrier_count) mbarrier.init(load_ready_bars.index(i), count=1) - tmem_layout: gl.constexpr = TensorMemoryLayout( - [BLOCK_SIZE_M, BLOCK_N // get_split_dim(CGA_LAYOUT, 1)], - col_stride=1, - cga_layout=CGA_LAYOUT, - two_ctas=TWO_CTAS, - ) - acc_bufs = allocate_tensor_memory(gl.float32, [ACC_STAGES, BLOCK_M, BLOCK_N], tmem_layout) # Equiv. store_done_barrier. Barrier Store TMA -> TCGEN05 MMA acc_empty_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES, two_ctas=TWO_CTAS) # Equiv. mma_done_barrier. 
Barrier TCGEN05 MMA -> Store TMA diff --git a/python/examples/gluon/04-2cta-block-scale-matmul.py b/python/examples/gluon/04-2cta-block-scale-matmul.py index 85babf241da5..cc074801d909 100644 --- a/python/examples/gluon/04-2cta-block-scale-matmul.py +++ b/python/examples/gluon/04-2cta-block-scale-matmul.py @@ -29,6 +29,7 @@ clc, tcgen05_copy, tcgen05_commit, + tcgen05_mma_barrier_count, tcgen05_mma_scaled, mbarrier, tma, @@ -235,7 +236,7 @@ def unswizzle_scales_shared_memory(smem, BLOCK_MN: gl.constexpr, BLOCK_K: gl.con @gluon.jit -def async_mma_scaled_impl(a_smem, b_smem, a_scale_smem, b_scale_smem, acc_tmem, use_acc, pred): +def async_mma_scaled_impl(a_smem, b_smem, a_scale_smem, b_scale_smem, acc_tmem, mma_bar, use_acc, pred): A_ELEM_PER_BYTE: gl.constexpr = 2 if a_smem.dtype == gl.uint8 else 1 BLOCK_M: gl.constexpr = a_smem.shape[0] BLOCK_N: gl.constexpr = b_smem.shape[0] @@ -259,7 +260,7 @@ def async_mma_scaled_impl(a_smem, b_smem, a_scale_smem, b_scale_smem, acc_tmem, a_format: gl.constexpr = "e2m1" if a_smem.dtype == gl.uint8 else "e4m3" b_format: gl.constexpr = "e2m1" if b_smem.dtype == gl.uint8 else "e4m3" tcgen05_mma_scaled(a_smem, b_smem.permute((1, 0)), acc_tmem, a_scale_tmem, b_scale_tmem, a_format, b_format, - use_acc=use_acc, pred=pred) + use_acc=use_acc, pred=pred, multicast=True, mbarriers=[mma_bar]) # This helper function computes all the load indexing and issues the async loads @@ -272,7 +273,7 @@ def async_mma_scaled_impl(a_smem, b_smem, a_scale_smem, b_scale_smem, acc_tmem, # clean, as pipelining can get messy. 
@gluon.jit def issue_loads(producer, pid_m, pid_n, k, a_desc, b_desc, a_scale_desc, b_scale_desc, a_bufs, b_bufs, a_scale_bufs, - b_scale_bufs, bars, pred, multicast_b_scale: gl.constexpr = False): + b_scale_bufs, bars, pred): A_ELEM_PER_BYTE: gl.constexpr = 2 if a_desc.dtype == gl.uint8 else 1 B_ELEM_PER_BYTE: gl.constexpr = 2 if b_desc.dtype == gl.uint8 else 1 BLOCK_M: gl.constexpr = a_desc.block_shape[0] @@ -297,11 +298,12 @@ def issue_loads(producer, pid_m, pid_n, k, a_desc, b_desc, a_scale_desc, b_scale mbarrier.expect( bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta + a_scale_desc.nbytes_per_cta + b_scale_desc.nbytes_per_cta, pred) - tma.async_load(a_desc, [off_m, off_k_a], bar, a_bufs.index(index), pred) - tma.async_load(b_desc, [off_n, off_k_b], bar, b_bufs.index(index), pred) - tma.async_load(a_scale_desc, [0, off_m_a_scale, off_k_a_scale, 0, 0], bar, a_scale_bufs.index(index), pred) + tma.async_load(a_desc, [off_m, off_k_a], bar, a_bufs.index(index), pred, multicast=True) + tma.async_load(b_desc, [off_n, off_k_b], bar, b_bufs.index(index), pred, multicast=True) + tma.async_load(a_scale_desc, [0, off_m_a_scale, off_k_a_scale, 0, 0], bar, a_scale_bufs.index(index), pred, + multicast=True) tma.async_load(b_scale_desc, [0, off_n_b_scale, off_k_b_scale, 0, 0], bar, b_scale_bufs.index(index), pred, - multicast=multicast_b_scale) + multicast=True) return producer.next(pred) @@ -310,8 +312,7 @@ def issue_mma(consumer, c_bars, a_bufs, b_bufs, a_scale_bufs, b_scale_bufs, prod c_index = consumer.index mbarrier.wait(c_bars.index(c_index), consumer.phase, pred) async_mma_scaled_impl(a_bufs.index(c_index), b_bufs.index(c_index), a_scale_bufs.index(c_index), - b_scale_bufs.index(c_index), acc_tmem, use_acc, pred) - tcgen05_commit(p_bars.index(producer.index), pred) + b_scale_bufs.index(c_index), acc_tmem, p_bars.index(producer.index), use_acc, pred) return consumer.next(pred), producer.next(pred) @@ -496,7 +497,7 @@ def mma_scaled_load_partition(p): 
mbarrier.wait(p.load_empty_bars.index(state.index), state.phase) state = issue_loads(state, scheduler.pid_m, scheduler.pid_n, k, p.a_desc, p.b_desc, p.a_scale_desc, p.b_scale_desc, p.a_bufs, p.b_bufs, p.a_scale_bufs, p.b_scale_bufs, p.load_ready_bars, - pred=True, multicast_b_scale=gl.num_ctas() > 1) + pred=True) scheduler = scheduler.step(i) i += 1 @@ -620,11 +621,16 @@ def mma_scaled_warp_specialized_kernel(a_desc, b_desc, c_desc, a_scale_desc, b_s tmem_layout: gl.constexpr = TensorMemoryLayout([BLOCK_M_PER_CTA, BLOCK_N], col_stride=1, cga_layout=CGA_LAYOUT, two_ctas=TWO_CTAS) + acc_bufs = allocate_tensor_memory(gl.float32, [num_acc_buffers, BLOCK_M, BLOCK_N], tmem_layout) + + mma_barrier_count: gl.constexpr = tcgen05_mma_barrier_count( + [a_bufs.index(0), b_bufs.index(0), + a_scale_bufs.index(0), b_scale_bufs.index(0)], multicast=True, two_ctas=acc_bufs.index(0).type.layout.two_ctas) load_empty_bars = mbarrier.allocate_mbarrier(batch=num_buffers) load_ready_bars = mbarrier.allocate_mbarrier(batch=num_buffers, two_ctas=TWO_CTAS) for i in gl.static_range(num_buffers): - mbarrier.init(load_empty_bars.index(i), count=1) + mbarrier.init(load_empty_bars.index(i), count=mma_barrier_count) mbarrier.init(load_ready_bars.index(i), count=1) acc_empty_bars = mbarrier.allocate_mbarrier(batch=num_acc_buffers, two_ctas=TWO_CTAS) @@ -646,7 +652,6 @@ def mma_scaled_warp_specialized_kernel(a_desc, b_desc, c_desc, a_scale_desc, b_s clc_result_buffers = gl.allocate_shared_memory(gl.int64, [clc_barriers.shape[0], 2], clc_layout) clc_planar_pid_buffers = gl.allocate_shared_memory(gl.int64, [clc_barriers.shape[0], 1], clc_layout) - acc_bufs = allocate_tensor_memory(gl.float32, [num_acc_buffers, BLOCK_M, BLOCK_N], tmem_layout) p = PartitionArgs(a_desc, b_desc, c_desc, a_scale_desc, b_scale_desc, a_bufs, b_bufs, a_scale_bufs, b_scale_bufs, load_empty_bars, load_ready_bars, acc_bufs, acc_empty_bars, acc_ready_bars, clc_result_buffers, clc_barriers, clc_planar_pid_buffers, 
clc_planar_ready_bars, clc_consumed_bars, GRID_MINOR_DIM, diff --git a/python/test/gluon/test_core.py b/python/test/gluon/test_core.py index d762c4e2c93c..cdcf54a2f2de 100644 --- a/python/test/gluon/test_core.py +++ b/python/test/gluon/test_core.py @@ -349,11 +349,12 @@ def tcgen05_mma_multicast_commit_kernel(a_desc, b_desc, out_ptrs, BLOCK_M: ttgl. acc_tmem_layout: ttgl.constexpr, blocked_c: ttgl.constexpr): smem_a = ttgl.allocate_shared_memory(a_desc.dtype, a_desc.block_shape, a_desc.layout) smem_b = ttgl.allocate_shared_memory(b_desc.dtype, b_desc.block_shape, b_desc.layout) + acc_tmem = allocate_tensor_memory(ttgl.float32, [BLOCK_M, BLOCK_N], acc_tmem_layout) tma_bar = mbarrier.allocate_mbarrier(two_ctas=acc_tmem_layout.two_ctas) mbarrier.init(tma_bar, count=1) mma_bar = mbarrier.allocate_mbarrier() - mbarrier.init(mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], True)) + mbarrier.init(mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], True, acc_tmem.type.layout.two_ctas)) mbarrier.expect(tma_bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta) tma.async_load(a_desc, [0, 0], tma_bar, smem_a, multicast=True) @@ -361,7 +362,6 @@ def tcgen05_mma_multicast_commit_kernel(a_desc, b_desc, out_ptrs, BLOCK_M: ttgl. 
mbarrier.wait(tma_bar, phase=0, deps=[smem_a, smem_b]) mbarrier.invalidate(tma_bar) - acc_tmem = allocate_tensor_memory(ttgl.float32, [BLOCK_M, BLOCK_N], acc_tmem_layout) # If it's not in a loop we don't striclty need multicast=True, but we add it to exercise the path in the test tcgen05_mma(smem_a, smem_b, acc_tmem, use_acc=False, multicast=True, mbarriers=[mma_bar]) mbarrier.wait(mma_bar, phase=0, deps=[smem_a, smem_b]) @@ -482,7 +482,7 @@ def tcgen05_mma_scaled_direct_multicast_kernel(a_desc, b_desc, out_ptr, BLOCK_M: tma_bar = mbarrier.allocate_mbarrier(two_ctas=True) mbarrier.init(tma_bar, count=1) mma_bar = mbarrier.allocate_mbarrier() - mbarrier.init(mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], True)) + mbarrier.init(mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], True, acc.type.layout.two_ctas)) phase_tma = 0 phase_mma = 0 @@ -753,14 +753,15 @@ def tma_mma_shared_inputs_kernel(a_desc, b_desc, out_ptr, out_desc, gather_idx_p phase_tma = 0 if use_tcgen05: - mma_bar = mbarrier.allocate_mbarrier() - phase_mma = 0 - mbarrier.init(mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], multicast)) acc_tmem = allocate_tensor_memory( element_ty=ttgl.float32, shape=[BLOCK_M, BLOCK_N], layout=acc_tmem_layout, ) + mma_bar = mbarrier.allocate_mbarrier() + phase_mma = 0 + mbarrier.init(mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], multicast, + acc_tmem.type.layout.two_ctas)) else: acc = ttgl.zeros([BLOCK_M, BLOCK_N], dtype=ttgl.float32, layout=acc_layout) @@ -3943,7 +3944,8 @@ def mma_scaled_tcgen05_copy_kernel(a_desc, b_desc, c_desc, a_scale_desc, b_scale mbarrier.init(tma_bar, count=1) mma_bar = mbarrier.allocate_mbarrier() - mma_bar_count: ttgl.constexpr = tcgen05_mma_barrier_count([a_smem, b_smem], multicast=multicast) + mma_bar_count: ttgl.constexpr = tcgen05_mma_barrier_count([a_smem, b_smem], multicast=multicast, + two_ctas=acc_tmem.type.layout.two_ctas) mbarrier.init(mma_bar, count=mma_bar_count) phase_tma = 0 diff --git 
a/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py b/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py index a9981d49580e..dd3d52502809 100644 --- a/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py +++ b/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py @@ -541,41 +541,37 @@ def tcgen05_mma_scaled(a, b, acc, a_scale, b_scale, a_type, b_type, *, use_acc=T @constexpr_function -def tcgen05_mma_barrier_count(smems, multicast): +def tcgen05_mma_barrier_count(smems, multicast, two_ctas): """ Calculate the number of CTAs that will commit the tcgen05 MMA instruction. Args: smems (Sequence[shared_memory_descriptor]): Shared memory descriptors used in the tcgen05 instruction. multicast (bool): Whether the tcgen05 instruction is multicast. + two_ctas (bool): Whether the tcgen05 instruction uses cta_group::2. Returns: int: The number of CTAs that will commit the tcgen05 MMA instruction. """ - assert 0 <= len(smems) <= 2, "tcgen05_mma_barrier_count supports 0, 1, or 2 smem descriptors" + assert 0 <= len(smems) <= 4, "tcgen05_mma_barrier_count supports 0 to 4 descriptors" if not smems or not multicast: return 1 def basis_is_zero(basis): return all(b == 0 for b in basis) - def num_broadcast_bits(smem): - return sum(basis_is_zero(basis) for basis in smem.layout.cga_layout) - - if len(smems) == 1: - return 2**num_broadcast_bits(smems[0]) - - assert len(smems) == 2 - num_broadcast_bits_a = num_broadcast_bits(smems[0]) - num_broadcast_bits_b = num_broadcast_bits(smems[1]) - # Assert that for every basis, at least one of them is non-zero - # so that the inclusion-exclusion principle below works - # This can be generalised if needed by substracting below 2**size_intersection - for i in range(len(smems[0].layout.cga_layout)): - assert not basis_is_zero(smems[0].layout.cga_layout[i]) or not basis_is_zero(smems[1].layout.cga_layout[i]) - - # Inclusion-exclusion - num_cta_commits = 2**num_broadcast_bits_a 
+ 2**num_broadcast_bits_b - 1 + num_cta_bits = len(smems[0].layout.cga_layout) + for desc in smems[1:]: + assert len(desc.layout.cga_layout) == num_cta_bits + + num_cta_commits = 0 + for cta in range(1 << num_cta_bits): + if two_ctas and cta & 1: + continue + for desc in smems: + if all(basis_is_zero(basis) or not (cta & (1 << i)) for i, basis in enumerate(desc.layout.cga_layout)): + num_cta_commits += 1 + break return num_cta_commits diff --git a/python/tutorials/gluon/14-multicta.py b/python/tutorials/gluon/14-multicta.py index f4002417b4c8..dff94d271735 100644 --- a/python/tutorials/gluon/14-multicta.py +++ b/python/tutorials/gluon/14-multicta.py @@ -548,12 +548,13 @@ def tma_tcgen05_kernel(a_desc, b_desc, out_desc, NUM_K_TILES: gl.constexpr, acc_ smem_a = gl.allocate_shared_memory(a_desc.dtype, a_desc.block_shape, a_desc.layout) smem_b = gl.allocate_shared_memory(b_desc.dtype, b_desc.block_shape, b_desc.layout) + acc_tmem = allocate_tensor_memory(gl.float32, [block_m, block_n], acc_tmem_layout) tma_bar = mbarrier.allocate_mbarrier(two_ctas=True) mma_bar = mbarrier.allocate_mbarrier() mbarrier.init(tma_bar, count=1) - mbarrier.init(mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], multicast=True)) - - acc_tmem = allocate_tensor_memory(gl.float32, [block_m, block_n], acc_tmem_layout) + mbarrier.init( + mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], multicast=True, + two_ctas=acc_tmem.type.layout.two_ctas)) phase_tma = 0 phase_mma = 0 @@ -972,14 +973,6 @@ def matmul_multicta_kernel( dtype: gl.constexpr = a_desc.dtype a_bufs = gl.allocate_shared_memory(dtype, [STAGES] + a_desc.block_shape, a_desc.layout) b_bufs = gl.allocate_shared_memory(dtype, [STAGES] + b_desc.block_shape, b_desc.layout) - mma_barrier_count: gl.constexpr = tcgen05_mma_barrier_count([a_bufs.index(0), b_bufs.index(0)], multicast=True) - - load_empty_bars = mbarrier.allocate_mbarrier(batch=STAGES) - load_ready_bars = mbarrier.allocate_mbarrier(batch=STAGES, two_ctas=two_ctas) - 
for i in gl.static_range(STAGES): - mbarrier.init(load_empty_bars.index(i), count=mma_barrier_count) - mbarrier.init(load_ready_bars.index(i), count=1) - tmem_layout: gl.constexpr = TensorMemoryLayout( [BLOCK_SIZE_M, block_n // get_split_dim(CGA_LAYOUT, 1)], col_stride=1, @@ -987,6 +980,15 @@ def matmul_multicta_kernel( two_ctas=two_ctas, ) acc_bufs = allocate_tensor_memory(gl.float32, [ACC_STAGES, block_m, block_n], tmem_layout) + mma_barrier_count: gl.constexpr = tcgen05_mma_barrier_count([a_bufs.index(0), b_bufs.index(0)], multicast=True, + two_ctas=acc_bufs.index(0).type.layout.two_ctas) + + load_empty_bars = mbarrier.allocate_mbarrier(batch=STAGES) + load_ready_bars = mbarrier.allocate_mbarrier(batch=STAGES, two_ctas=two_ctas) + for i in gl.static_range(STAGES): + mbarrier.init(load_empty_bars.index(i), count=mma_barrier_count) + mbarrier.init(load_ready_bars.index(i), count=1) + acc_empty_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES, two_ctas=two_ctas) acc_ready_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES) for i in gl.static_range(ACC_STAGES): From fb4b9b379600cdc1364df0b561822f880fc81c27 Mon Sep 17 00:00:00 2001 From: lezcano Date: Fri, 1 May 2026 21:19:35 +0200 Subject: [PATCH 2/2] fix --- python/test/gluon/test_consan.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/test/gluon/test_consan.py b/python/test/gluon/test_consan.py index 5e9baee683cc..1e28aaeebdf6 100644 --- a/python/test/gluon/test_consan.py +++ b/python/test/gluon/test_consan.py @@ -1159,12 +1159,13 @@ def kernel(a_desc, b_desc, FAILURE: ttgl.constexpr): ttgl.NVMMASharedLayout.get_default_for([XBLOCK, block_n], ttgl.float16, cga_layout=mma_cga_layout(ttgl.num_ctas(), 1, True)), ) + acc = blackwell.allocate_tensor_memory(ttgl.float32, [block_m, block_n], acc_layout) tma_bar = mbarrier.allocate_mbarrier(two_ctas=True) mbarrier.init(tma_bar, count=1) mma_bar = mbarrier.allocate_mbarrier() - mma_bar_count: ttgl.constexpr = 
blackwell.tcgen05_mma_barrier_count([smemA, smemB], True) + mma_bar_count: ttgl.constexpr = blackwell.tcgen05_mma_barrier_count([smemA, smemB], True, + acc.type.layout.two_ctas) mbarrier.init(mma_bar, count=mma_bar_count) - acc = blackwell.allocate_tensor_memory(ttgl.float32, [block_m, block_n], acc_layout) phase_tma = 0 phase_mma = 0