diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index 5ca0be5663b9..2c99559c196b 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -928,8 +928,7 @@ ValueRange TCGen5MMAOp::getCompletionBarrierPreds() { static void appendMulticastDesc(SmallVectorImpl &descs, TypedValue desc) { - if (isa(desc.getType().getEncoding())) - descs.push_back(desc); + descs.push_back(desc); } SmallVector TCGen5MMAOp::getCompletionDescs() { diff --git a/python/examples/gluon/03-matmul-multicta.py b/python/examples/gluon/03-matmul-multicta.py index bae3a0737a95..3ad464525952 100644 --- a/python/examples/gluon/03-matmul-multicta.py +++ b/python/examples/gluon/03-matmul-multicta.py @@ -483,8 +483,16 @@ def _matmul_kernel( dtype: gl.constexpr = a_desc.dtype a_bufs = gl.allocate_shared_memory(dtype, [STAGES] + a_desc.block_shape, a_desc.layout) b_bufs = gl.allocate_shared_memory(dtype, [STAGES] + b_desc.block_shape, b_desc.layout) + tmem_layout: gl.constexpr = TensorMemoryLayout( + [BLOCK_SIZE_M, BLOCK_N // get_split_dim(CGA_LAYOUT, 1)], + col_stride=1, + cga_layout=CGA_LAYOUT, + two_ctas=TWO_CTAS, + ) + acc_bufs = allocate_tensor_memory(gl.float32, [ACC_STAGES, BLOCK_M, BLOCK_N], tmem_layout) # Number of CTAs that will arrive on the barrier from a tcgen05_commit after an MMA instruction - mma_barrier_count: gl.constexpr = tcgen05_mma_barrier_count([a_bufs.index(0), b_bufs.index(0)], multicast=True) + mma_barrier_count: gl.constexpr = tcgen05_mma_barrier_count([a_bufs.index(0), b_bufs.index(0)], multicast=True, + two_ctas=acc_bufs.index(0).type.layout.two_ctas) # Equiv. consumed_barrier. Barrier TCGEN05 MMA -> Load TMA load_empty_bars = mbarrier.allocate_mbarrier(batch=STAGES) @@ -494,13 +502,6 @@ def _matmul_kernel( mbarrier.init(load_empty_bars.index(i), count=mma_barrier_count) mbarrier.init(load_ready_bars.index(i), count=1) - tmem_layout: gl.constexpr = TensorMemoryLayout( - [BLOCK_SIZE_M, BLOCK_N // get_split_dim(CGA_LAYOUT, 1)], - col_stride=1, - cga_layout=CGA_LAYOUT, - two_ctas=TWO_CTAS, - ) - acc_bufs = allocate_tensor_memory(gl.float32, [ACC_STAGES, BLOCK_M, BLOCK_N], tmem_layout) # Equiv. store_done_barrier. Barrier Store TMA -> TCGEN05 MMA acc_empty_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES, two_ctas=TWO_CTAS) # Equiv. mma_done_barrier. Barrier TCGEN05 MMA -> Store TMA diff --git a/python/examples/gluon/04-2cta-block-scale-matmul.py b/python/examples/gluon/04-2cta-block-scale-matmul.py index 85babf241da5..cc074801d909 100644 --- a/python/examples/gluon/04-2cta-block-scale-matmul.py +++ b/python/examples/gluon/04-2cta-block-scale-matmul.py @@ -29,6 +29,7 @@ clc, tcgen05_copy, tcgen05_commit, + tcgen05_mma_barrier_count, tcgen05_mma_scaled, mbarrier, tma, @@ -235,7 +236,7 @@ def unswizzle_scales_shared_memory(smem, BLOCK_MN: gl.constexpr, BLOCK_K: gl.con @gluon.jit -def async_mma_scaled_impl(a_smem, b_smem, a_scale_smem, b_scale_smem, acc_tmem, use_acc, pred): +def async_mma_scaled_impl(a_smem, b_smem, a_scale_smem, b_scale_smem, acc_tmem, mma_bar, use_acc, pred): A_ELEM_PER_BYTE: gl.constexpr = 2 if a_smem.dtype == gl.uint8 else 1 BLOCK_M: gl.constexpr = a_smem.shape[0] BLOCK_N: gl.constexpr = b_smem.shape[0] @@ -259,7 +260,7 @@ def async_mma_scaled_impl(a_smem, b_smem, a_scale_smem, b_scale_smem, acc_tmem, a_format: gl.constexpr = "e2m1" if a_smem.dtype == gl.uint8 else "e4m3" b_format: gl.constexpr = "e2m1" if b_smem.dtype == gl.uint8 else "e4m3" tcgen05_mma_scaled(a_smem, b_smem.permute((1, 0)), acc_tmem, a_scale_tmem, b_scale_tmem, a_format, b_format, - use_acc=use_acc, pred=pred) + use_acc=use_acc, pred=pred, multicast=True, mbarriers=[mma_bar]) # This helper function computes all the load indexing and issues the async loads @@ -272,7 +273,7 @@ def async_mma_scaled_impl(a_smem, b_smem, a_scale_smem, b_scale_smem, acc_tmem, # clean, as pipelining can get messy. @gluon.jit def issue_loads(producer, pid_m, pid_n, k, a_desc, b_desc, a_scale_desc, b_scale_desc, a_bufs, b_bufs, a_scale_bufs, - b_scale_bufs, bars, pred, multicast_b_scale: gl.constexpr = False): + b_scale_bufs, bars, pred): A_ELEM_PER_BYTE: gl.constexpr = 2 if a_desc.dtype == gl.uint8 else 1 B_ELEM_PER_BYTE: gl.constexpr = 2 if b_desc.dtype == gl.uint8 else 1 BLOCK_M: gl.constexpr = a_desc.block_shape[0] @@ -297,11 +298,12 @@ def issue_loads(producer, pid_m, pid_n, k, a_desc, b_desc, a_scale_desc, b_scale mbarrier.expect( bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta + a_scale_desc.nbytes_per_cta + b_scale_desc.nbytes_per_cta, pred) - tma.async_load(a_desc, [off_m, off_k_a], bar, a_bufs.index(index), pred) - tma.async_load(b_desc, [off_n, off_k_b], bar, b_bufs.index(index), pred) - tma.async_load(a_scale_desc, [0, off_m_a_scale, off_k_a_scale, 0, 0], bar, a_scale_bufs.index(index), pred) + tma.async_load(a_desc, [off_m, off_k_a], bar, a_bufs.index(index), pred, multicast=True) + tma.async_load(b_desc, [off_n, off_k_b], bar, b_bufs.index(index), pred, multicast=True) + tma.async_load(a_scale_desc, [0, off_m_a_scale, off_k_a_scale, 0, 0], bar, a_scale_bufs.index(index), pred, + multicast=True) tma.async_load(b_scale_desc, [0, off_n_b_scale, off_k_b_scale, 0, 0], bar, b_scale_bufs.index(index), pred, - multicast=multicast_b_scale) + multicast=True) return producer.next(pred) @@ -310,8 +312,7 @@ def issue_mma(consumer, c_bars, a_bufs, b_bufs, a_scale_bufs, b_scale_bufs, prod c_index = consumer.index mbarrier.wait(c_bars.index(c_index), consumer.phase, pred) async_mma_scaled_impl(a_bufs.index(c_index), b_bufs.index(c_index), a_scale_bufs.index(c_index), - b_scale_bufs.index(c_index), acc_tmem, use_acc, pred) - tcgen05_commit(p_bars.index(producer.index), pred) + b_scale_bufs.index(c_index), acc_tmem, p_bars.index(producer.index), use_acc, pred) return consumer.next(pred), producer.next(pred) @@ -496,7 +497,7 @@ def mma_scaled_load_partition(p): mbarrier.wait(p.load_empty_bars.index(state.index), state.phase) state = issue_loads(state, scheduler.pid_m, scheduler.pid_n, k, p.a_desc, p.b_desc, p.a_scale_desc, p.b_scale_desc, p.a_bufs, p.b_bufs, p.a_scale_bufs, p.b_scale_bufs, p.load_ready_bars, - pred=True, multicast_b_scale=gl.num_ctas() > 1) + pred=True) scheduler = scheduler.step(i) i += 1 @@ -620,11 +621,16 @@ def mma_scaled_warp_specialized_kernel(a_desc, b_desc, c_desc, a_scale_desc, b_s tmem_layout: gl.constexpr = TensorMemoryLayout([BLOCK_M_PER_CTA, BLOCK_N], col_stride=1, cga_layout=CGA_LAYOUT, two_ctas=TWO_CTAS) + acc_bufs = allocate_tensor_memory(gl.float32, [num_acc_buffers, BLOCK_M, BLOCK_N], tmem_layout) + + mma_barrier_count: gl.constexpr = tcgen05_mma_barrier_count( + [a_bufs.index(0), b_bufs.index(0), + a_scale_bufs.index(0), b_scale_bufs.index(0)], multicast=True, two_ctas=acc_bufs.index(0).type.layout.two_ctas) load_empty_bars = mbarrier.allocate_mbarrier(batch=num_buffers) load_ready_bars = mbarrier.allocate_mbarrier(batch=num_buffers, two_ctas=TWO_CTAS) for i in gl.static_range(num_buffers): - mbarrier.init(load_empty_bars.index(i), count=1) + mbarrier.init(load_empty_bars.index(i), count=mma_barrier_count) mbarrier.init(load_ready_bars.index(i), count=1) acc_empty_bars = mbarrier.allocate_mbarrier(batch=num_acc_buffers, two_ctas=TWO_CTAS) @@ -646,7 +652,6 @@ def mma_scaled_warp_specialized_kernel(a_desc, b_desc, c_desc, a_scale_desc, b_s clc_result_buffers = gl.allocate_shared_memory(gl.int64, [clc_barriers.shape[0], 2], clc_layout) clc_planar_pid_buffers = gl.allocate_shared_memory(gl.int64, [clc_barriers.shape[0], 1], clc_layout) - acc_bufs = allocate_tensor_memory(gl.float32, [num_acc_buffers, BLOCK_M, BLOCK_N], tmem_layout) p = PartitionArgs(a_desc, b_desc, c_desc, a_scale_desc, b_scale_desc, a_bufs, b_bufs, a_scale_bufs, b_scale_bufs, load_empty_bars, load_ready_bars, acc_bufs, acc_empty_bars, acc_ready_bars, clc_result_buffers, clc_barriers, clc_planar_pid_buffers, clc_planar_ready_bars, clc_consumed_bars, GRID_MINOR_DIM, diff --git a/python/test/gluon/test_consan.py b/python/test/gluon/test_consan.py index 5e9baee683cc..1e28aaeebdf6 100644 --- a/python/test/gluon/test_consan.py +++ b/python/test/gluon/test_consan.py @@ -1159,12 +1159,13 @@ def kernel(a_desc, b_desc, FAILURE: ttgl.constexpr): ttgl.NVMMASharedLayout.get_default_for([XBLOCK, block_n], ttgl.float16, cga_layout=mma_cga_layout(ttgl.num_ctas(), 1, True)), ) + acc = blackwell.allocate_tensor_memory(ttgl.float32, [block_m, block_n], acc_layout) tma_bar = mbarrier.allocate_mbarrier(two_ctas=True) mbarrier.init(tma_bar, count=1) mma_bar = mbarrier.allocate_mbarrier() - mma_bar_count: ttgl.constexpr = blackwell.tcgen05_mma_barrier_count([smemA, smemB], True) + mma_bar_count: ttgl.constexpr = blackwell.tcgen05_mma_barrier_count([smemA, smemB], True, + acc.type.layout.two_ctas) mbarrier.init(mma_bar, count=mma_bar_count) - acc = blackwell.allocate_tensor_memory(ttgl.float32, [block_m, block_n], acc_layout) phase_tma = 0 phase_mma = 0 diff --git a/python/test/gluon/test_core.py b/python/test/gluon/test_core.py index d762c4e2c93c..cdcf54a2f2de 100644 --- a/python/test/gluon/test_core.py +++ b/python/test/gluon/test_core.py @@ -349,11 +349,12 @@ def tcgen05_mma_multicast_commit_kernel(a_desc, b_desc, out_ptrs, BLOCK_M: ttgl. acc_tmem_layout: ttgl.constexpr, blocked_c: ttgl.constexpr): smem_a = ttgl.allocate_shared_memory(a_desc.dtype, a_desc.block_shape, a_desc.layout) smem_b = ttgl.allocate_shared_memory(b_desc.dtype, b_desc.block_shape, b_desc.layout) + acc_tmem = allocate_tensor_memory(ttgl.float32, [BLOCK_M, BLOCK_N], acc_tmem_layout) tma_bar = mbarrier.allocate_mbarrier(two_ctas=acc_tmem_layout.two_ctas) mbarrier.init(tma_bar, count=1) mma_bar = mbarrier.allocate_mbarrier() - mbarrier.init(mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], True)) + mbarrier.init(mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], True, acc_tmem.type.layout.two_ctas)) mbarrier.expect(tma_bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta) tma.async_load(a_desc, [0, 0], tma_bar, smem_a, multicast=True) @@ -361,7 +362,6 @@ def tcgen05_mma_multicast_commit_kernel(a_desc, b_desc, out_ptrs, BLOCK_M: ttgl. mbarrier.wait(tma_bar, phase=0, deps=[smem_a, smem_b]) mbarrier.invalidate(tma_bar) - acc_tmem = allocate_tensor_memory(ttgl.float32, [BLOCK_M, BLOCK_N], acc_tmem_layout) # If it's not in a loop we don't striclty need multicast=True, but we add it to exercise the path in the test tcgen05_mma(smem_a, smem_b, acc_tmem, use_acc=False, multicast=True, mbarriers=[mma_bar]) mbarrier.wait(mma_bar, phase=0, deps=[smem_a, smem_b]) @@ -482,7 +482,7 @@ def tcgen05_mma_scaled_direct_multicast_kernel(a_desc, b_desc, out_ptr, BLOCK_M: tma_bar = mbarrier.allocate_mbarrier(two_ctas=True) mbarrier.init(tma_bar, count=1) mma_bar = mbarrier.allocate_mbarrier() - mbarrier.init(mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], True)) + mbarrier.init(mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], True, acc.type.layout.two_ctas)) phase_tma = 0 phase_mma = 0 @@ -753,14 +753,15 @@ def tma_mma_shared_inputs_kernel(a_desc, b_desc, out_ptr, out_desc, gather_idx_p phase_tma = 0 if use_tcgen05: - mma_bar = mbarrier.allocate_mbarrier() - phase_mma = 0 - mbarrier.init(mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], multicast)) acc_tmem = allocate_tensor_memory( element_ty=ttgl.float32, shape=[BLOCK_M, BLOCK_N], layout=acc_tmem_layout, ) + mma_bar = mbarrier.allocate_mbarrier() + phase_mma = 0 + mbarrier.init(mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], multicast, + acc_tmem.type.layout.two_ctas)) else: acc = ttgl.zeros([BLOCK_M, BLOCK_N], dtype=ttgl.float32, layout=acc_layout) @@ -3943,7 +3944,8 @@ def mma_scaled_tcgen05_copy_kernel(a_desc, b_desc, c_desc, a_scale_desc, b_scale mbarrier.init(tma_bar, count=1) mma_bar = mbarrier.allocate_mbarrier() - mma_bar_count: ttgl.constexpr = tcgen05_mma_barrier_count([a_smem, b_smem], multicast=multicast) + mma_bar_count: ttgl.constexpr = tcgen05_mma_barrier_count([a_smem, b_smem], multicast=multicast, + two_ctas=acc_tmem.type.layout.two_ctas) mbarrier.init(mma_bar, count=mma_bar_count) phase_tma = 0 diff --git a/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py b/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py index a9981d49580e..dd3d52502809 100644 --- a/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py +++ b/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py @@ -541,41 +541,37 @@ def tcgen05_mma_scaled(a, b, acc, a_scale, b_scale, a_type, b_type, *, use_acc=T @constexpr_function -def tcgen05_mma_barrier_count(smems, multicast): +def tcgen05_mma_barrier_count(smems, multicast, two_ctas): """ Calculate the number of CTAs that will commit the tcgen05 MMA instruction. Args: smems (Sequence[shared_memory_descriptor]): Shared memory descriptors used in the tcgen05 instruction. multicast (bool): Whether the tcgen05 instruction is multicast. + two_ctas (bool): Whether the tcgen05 instruction uses cta_group::2. Returns: int: The number of CTAs that will commit the tcgen05 MMA instruction. """ - assert 0 <= len(smems) <= 2, "tcgen05_mma_barrier_count supports 0, 1, or 2 smem descriptors" + assert 0 <= len(smems) <= 4, "tcgen05_mma_barrier_count supports 0 to 4 descriptors" if not smems or not multicast: return 1 def basis_is_zero(basis): return all(b == 0 for b in basis) - def num_broadcast_bits(smem): - return sum(basis_is_zero(basis) for basis in smem.layout.cga_layout) - - if len(smems) == 1: - return 2**num_broadcast_bits(smems[0]) - - assert len(smems) == 2 - num_broadcast_bits_a = num_broadcast_bits(smems[0]) - num_broadcast_bits_b = num_broadcast_bits(smems[1]) - # Assert that for every basis, at least one of them is non-zero - # so that the inclusion-exclusion principle below works - # This can be generalised if needed by substracting below 2**size_intersection - for i in range(len(smems[0].layout.cga_layout)): - assert not basis_is_zero(smems[0].layout.cga_layout[i]) or not basis_is_zero(smems[1].layout.cga_layout[i]) - - # Inclusion-exclusion - num_cta_commits = 2**num_broadcast_bits_a + 2**num_broadcast_bits_b - 1 + num_cta_bits = len(smems[0].layout.cga_layout) + for desc in smems[1:]: + assert len(desc.layout.cga_layout) == num_cta_bits + + num_cta_commits = 0 + for cta in range(1 << num_cta_bits): + if two_ctas and cta & 1: + continue + for desc in smems: + if all(basis_is_zero(basis) or not (cta & (1 << i)) for i, basis in enumerate(desc.layout.cga_layout)): + num_cta_commits += 1 + break return num_cta_commits diff --git a/python/tutorials/gluon/14-multicta.py b/python/tutorials/gluon/14-multicta.py index f4002417b4c8..dff94d271735 100644 --- a/python/tutorials/gluon/14-multicta.py +++ b/python/tutorials/gluon/14-multicta.py @@ -548,12 +548,13 @@ def tma_tcgen05_kernel(a_desc, b_desc, out_desc, NUM_K_TILES: gl.constexpr, acc_ smem_a = gl.allocate_shared_memory(a_desc.dtype, a_desc.block_shape, a_desc.layout) smem_b = gl.allocate_shared_memory(b_desc.dtype, b_desc.block_shape, b_desc.layout) + acc_tmem = allocate_tensor_memory(gl.float32, [block_m, block_n], acc_tmem_layout) tma_bar = mbarrier.allocate_mbarrier(two_ctas=True) mma_bar = mbarrier.allocate_mbarrier() mbarrier.init(tma_bar, count=1) - mbarrier.init(mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], multicast=True)) - - acc_tmem = allocate_tensor_memory(gl.float32, [block_m, block_n], acc_tmem_layout) + mbarrier.init( + mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], multicast=True, + two_ctas=acc_tmem.type.layout.two_ctas)) phase_tma = 0 phase_mma = 0 @@ -972,14 +973,6 @@ def matmul_multicta_kernel( dtype: gl.constexpr = a_desc.dtype a_bufs = gl.allocate_shared_memory(dtype, [STAGES] + a_desc.block_shape, a_desc.layout) b_bufs = gl.allocate_shared_memory(dtype, [STAGES] + b_desc.block_shape, b_desc.layout) - mma_barrier_count: gl.constexpr = tcgen05_mma_barrier_count([a_bufs.index(0), b_bufs.index(0)], multicast=True) - - load_empty_bars = mbarrier.allocate_mbarrier(batch=STAGES) - load_ready_bars = mbarrier.allocate_mbarrier(batch=STAGES, two_ctas=two_ctas) - for i in gl.static_range(STAGES): - mbarrier.init(load_empty_bars.index(i), count=mma_barrier_count) - mbarrier.init(load_ready_bars.index(i), count=1) - tmem_layout: gl.constexpr = TensorMemoryLayout( [BLOCK_SIZE_M, block_n // get_split_dim(CGA_LAYOUT, 1)], col_stride=1, @@ -987,6 +980,15 @@ def matmul_multicta_kernel( two_ctas=two_ctas, ) acc_bufs = allocate_tensor_memory(gl.float32, [ACC_STAGES, block_m, block_n], tmem_layout) + mma_barrier_count: gl.constexpr = tcgen05_mma_barrier_count([a_bufs.index(0), b_bufs.index(0)], multicast=True, + two_ctas=acc_bufs.index(0).type.layout.two_ctas) + + load_empty_bars = mbarrier.allocate_mbarrier(batch=STAGES) + load_ready_bars = mbarrier.allocate_mbarrier(batch=STAGES, two_ctas=two_ctas) + for i in gl.static_range(STAGES): + mbarrier.init(load_empty_bars.index(i), count=mma_barrier_count) + mbarrier.init(load_ready_bars.index(i), count=1) + acc_empty_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES, two_ctas=two_ctas) acc_ready_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES) for i in gl.static_range(ACC_STAGES):