Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -928,8 +928,7 @@ ValueRange TCGen5MMAOp::getCompletionBarrierPreds() {

static void appendMulticastDesc(SmallVectorImpl<Value> &descs,
TypedValue<MemDescType> desc) {
if (isa<SharedEncodingTrait>(desc.getType().getEncoding()))
descs.push_back(desc);
descs.push_back(desc);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a LIT test to ensure there is no regression?

}

SmallVector<Value> TCGen5MMAOp::getCompletionDescs() {
Expand Down
17 changes: 9 additions & 8 deletions python/examples/gluon/03-matmul-multicta.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,16 @@ def _matmul_kernel(
dtype: gl.constexpr = a_desc.dtype
a_bufs = gl.allocate_shared_memory(dtype, [STAGES] + a_desc.block_shape, a_desc.layout)
b_bufs = gl.allocate_shared_memory(dtype, [STAGES] + b_desc.block_shape, b_desc.layout)
tmem_layout: gl.constexpr = TensorMemoryLayout(
[BLOCK_SIZE_M, BLOCK_N // get_split_dim(CGA_LAYOUT, 1)],
col_stride=1,
cga_layout=CGA_LAYOUT,
two_ctas=TWO_CTAS,
)
acc_bufs = allocate_tensor_memory(gl.float32, [ACC_STAGES, BLOCK_M, BLOCK_N], tmem_layout)
# Number of CTAs that will arrive on the barrier from a tcgen05_commit after an MMA instruction
mma_barrier_count: gl.constexpr = tcgen05_mma_barrier_count([a_bufs.index(0), b_bufs.index(0)], multicast=True)
mma_barrier_count: gl.constexpr = tcgen05_mma_barrier_count([a_bufs.index(0), b_bufs.index(0)], multicast=True,
two_ctas=acc_bufs.index(0).type.layout.two_ctas)

# Equiv. consumed_barrier. Barrier TCGEN05 MMA -> Load TMA
load_empty_bars = mbarrier.allocate_mbarrier(batch=STAGES)
Expand All @@ -494,13 +502,6 @@ def _matmul_kernel(
mbarrier.init(load_empty_bars.index(i), count=mma_barrier_count)
mbarrier.init(load_ready_bars.index(i), count=1)

tmem_layout: gl.constexpr = TensorMemoryLayout(
[BLOCK_SIZE_M, BLOCK_N // get_split_dim(CGA_LAYOUT, 1)],
col_stride=1,
cga_layout=CGA_LAYOUT,
two_ctas=TWO_CTAS,
)
acc_bufs = allocate_tensor_memory(gl.float32, [ACC_STAGES, BLOCK_M, BLOCK_N], tmem_layout)
# Equiv. store_done_barrier. Barrier Store TMA -> TCGEN05 MMA
acc_empty_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES, two_ctas=TWO_CTAS)
# Equiv. mma_done_barrier. Barrier TCGEN05 MMA -> Store TMA
Expand Down
29 changes: 17 additions & 12 deletions python/examples/gluon/04-2cta-block-scale-matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
clc,
tcgen05_copy,
tcgen05_commit,
tcgen05_mma_barrier_count,
tcgen05_mma_scaled,
mbarrier,
tma,
Expand Down Expand Up @@ -235,7 +236,7 @@ def unswizzle_scales_shared_memory(smem, BLOCK_MN: gl.constexpr, BLOCK_K: gl.con


@gluon.jit
def async_mma_scaled_impl(a_smem, b_smem, a_scale_smem, b_scale_smem, acc_tmem, use_acc, pred):
def async_mma_scaled_impl(a_smem, b_smem, a_scale_smem, b_scale_smem, acc_tmem, mma_bar, use_acc, pred):
A_ELEM_PER_BYTE: gl.constexpr = 2 if a_smem.dtype == gl.uint8 else 1
BLOCK_M: gl.constexpr = a_smem.shape[0]
BLOCK_N: gl.constexpr = b_smem.shape[0]
Expand All @@ -259,7 +260,7 @@ def async_mma_scaled_impl(a_smem, b_smem, a_scale_smem, b_scale_smem, acc_tmem,
a_format: gl.constexpr = "e2m1" if a_smem.dtype == gl.uint8 else "e4m3"
b_format: gl.constexpr = "e2m1" if b_smem.dtype == gl.uint8 else "e4m3"
tcgen05_mma_scaled(a_smem, b_smem.permute((1, 0)), acc_tmem, a_scale_tmem, b_scale_tmem, a_format, b_format,
use_acc=use_acc, pred=pred)
use_acc=use_acc, pred=pred, multicast=True, mbarriers=[mma_bar])


# This helper function computes all the load indexing and issues the async loads
Expand All @@ -272,7 +273,7 @@ def async_mma_scaled_impl(a_smem, b_smem, a_scale_smem, b_scale_smem, acc_tmem,
# clean, as pipelining can get messy.
@gluon.jit
def issue_loads(producer, pid_m, pid_n, k, a_desc, b_desc, a_scale_desc, b_scale_desc, a_bufs, b_bufs, a_scale_bufs,
b_scale_bufs, bars, pred, multicast_b_scale: gl.constexpr = False):
b_scale_bufs, bars, pred):
A_ELEM_PER_BYTE: gl.constexpr = 2 if a_desc.dtype == gl.uint8 else 1
B_ELEM_PER_BYTE: gl.constexpr = 2 if b_desc.dtype == gl.uint8 else 1
BLOCK_M: gl.constexpr = a_desc.block_shape[0]
Expand All @@ -297,11 +298,12 @@ def issue_loads(producer, pid_m, pid_n, k, a_desc, b_desc, a_scale_desc, b_scale
mbarrier.expect(
bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta + a_scale_desc.nbytes_per_cta + b_scale_desc.nbytes_per_cta,
pred)
tma.async_load(a_desc, [off_m, off_k_a], bar, a_bufs.index(index), pred)
tma.async_load(b_desc, [off_n, off_k_b], bar, b_bufs.index(index), pred)
tma.async_load(a_scale_desc, [0, off_m_a_scale, off_k_a_scale, 0, 0], bar, a_scale_bufs.index(index), pred)
tma.async_load(a_desc, [off_m, off_k_a], bar, a_bufs.index(index), pred, multicast=True)
tma.async_load(b_desc, [off_n, off_k_b], bar, b_bufs.index(index), pred, multicast=True)
tma.async_load(a_scale_desc, [0, off_m_a_scale, off_k_a_scale, 0, 0], bar, a_scale_bufs.index(index), pred,
multicast=True)
Comment thread
adstraw marked this conversation as resolved.
tma.async_load(b_scale_desc, [0, off_n_b_scale, off_k_b_scale, 0, 0], bar, b_scale_bufs.index(index), pred,
multicast=multicast_b_scale)
multicast=True)
return producer.next(pred)


Expand All @@ -310,8 +312,7 @@ def issue_mma(consumer, c_bars, a_bufs, b_bufs, a_scale_bufs, b_scale_bufs, prod
c_index = consumer.index
mbarrier.wait(c_bars.index(c_index), consumer.phase, pred)
async_mma_scaled_impl(a_bufs.index(c_index), b_bufs.index(c_index), a_scale_bufs.index(c_index),
b_scale_bufs.index(c_index), acc_tmem, use_acc, pred)
tcgen05_commit(p_bars.index(producer.index), pred)
b_scale_bufs.index(c_index), acc_tmem, p_bars.index(producer.index), use_acc, pred)
return consumer.next(pred), producer.next(pred)


Expand Down Expand Up @@ -496,7 +497,7 @@ def mma_scaled_load_partition(p):
mbarrier.wait(p.load_empty_bars.index(state.index), state.phase)
state = issue_loads(state, scheduler.pid_m, scheduler.pid_n, k, p.a_desc, p.b_desc, p.a_scale_desc,
p.b_scale_desc, p.a_bufs, p.b_bufs, p.a_scale_bufs, p.b_scale_bufs, p.load_ready_bars,
pred=True, multicast_b_scale=gl.num_ctas() > 1)
pred=True)
scheduler = scheduler.step(i)
i += 1

Expand Down Expand Up @@ -620,11 +621,16 @@ def mma_scaled_warp_specialized_kernel(a_desc, b_desc, c_desc, a_scale_desc, b_s

tmem_layout: gl.constexpr = TensorMemoryLayout([BLOCK_M_PER_CTA, BLOCK_N], col_stride=1, cga_layout=CGA_LAYOUT,
two_ctas=TWO_CTAS)
acc_bufs = allocate_tensor_memory(gl.float32, [num_acc_buffers, BLOCK_M, BLOCK_N], tmem_layout)

mma_barrier_count: gl.constexpr = tcgen05_mma_barrier_count(
[a_bufs.index(0), b_bufs.index(0),
a_scale_bufs.index(0), b_scale_bufs.index(0)], multicast=True, two_ctas=acc_bufs.index(0).type.layout.two_ctas)

load_empty_bars = mbarrier.allocate_mbarrier(batch=num_buffers)
load_ready_bars = mbarrier.allocate_mbarrier(batch=num_buffers, two_ctas=TWO_CTAS)
for i in gl.static_range(num_buffers):
mbarrier.init(load_empty_bars.index(i), count=1)
mbarrier.init(load_empty_bars.index(i), count=mma_barrier_count)
mbarrier.init(load_ready_bars.index(i), count=1)

acc_empty_bars = mbarrier.allocate_mbarrier(batch=num_acc_buffers, two_ctas=TWO_CTAS)
Expand All @@ -646,7 +652,6 @@ def mma_scaled_warp_specialized_kernel(a_desc, b_desc, c_desc, a_scale_desc, b_s
clc_result_buffers = gl.allocate_shared_memory(gl.int64, [clc_barriers.shape[0], 2], clc_layout)
clc_planar_pid_buffers = gl.allocate_shared_memory(gl.int64, [clc_barriers.shape[0], 1], clc_layout)

acc_bufs = allocate_tensor_memory(gl.float32, [num_acc_buffers, BLOCK_M, BLOCK_N], tmem_layout)
p = PartitionArgs(a_desc, b_desc, c_desc, a_scale_desc, b_scale_desc, a_bufs, b_bufs, a_scale_bufs, b_scale_bufs,
load_empty_bars, load_ready_bars, acc_bufs, acc_empty_bars, acc_ready_bars, clc_result_buffers,
clc_barriers, clc_planar_pid_buffers, clc_planar_ready_bars, clc_consumed_bars, GRID_MINOR_DIM,
Expand Down
5 changes: 3 additions & 2 deletions python/test/gluon/test_consan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1159,12 +1159,13 @@ def kernel(a_desc, b_desc, FAILURE: ttgl.constexpr):
ttgl.NVMMASharedLayout.get_default_for([XBLOCK, block_n], ttgl.float16,
cga_layout=mma_cga_layout(ttgl.num_ctas(), 1, True)),
)
acc = blackwell.allocate_tensor_memory(ttgl.float32, [block_m, block_n], acc_layout)
tma_bar = mbarrier.allocate_mbarrier(two_ctas=True)
mbarrier.init(tma_bar, count=1)
mma_bar = mbarrier.allocate_mbarrier()
mma_bar_count: ttgl.constexpr = blackwell.tcgen05_mma_barrier_count([smemA, smemB], True)
mma_bar_count: ttgl.constexpr = blackwell.tcgen05_mma_barrier_count([smemA, smemB], True,
acc.type.layout.two_ctas)
mbarrier.init(mma_bar, count=mma_bar_count)
acc = blackwell.allocate_tensor_memory(ttgl.float32, [block_m, block_n], acc_layout)

phase_tma = 0
phase_mma = 0
Expand Down
16 changes: 9 additions & 7 deletions python/test/gluon/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,19 +349,19 @@ def tcgen05_mma_multicast_commit_kernel(a_desc, b_desc, out_ptrs, BLOCK_M: ttgl.
acc_tmem_layout: ttgl.constexpr, blocked_c: ttgl.constexpr):
smem_a = ttgl.allocate_shared_memory(a_desc.dtype, a_desc.block_shape, a_desc.layout)
smem_b = ttgl.allocate_shared_memory(b_desc.dtype, b_desc.block_shape, b_desc.layout)
acc_tmem = allocate_tensor_memory(ttgl.float32, [BLOCK_M, BLOCK_N], acc_tmem_layout)

tma_bar = mbarrier.allocate_mbarrier(two_ctas=acc_tmem_layout.two_ctas)
mbarrier.init(tma_bar, count=1)
mma_bar = mbarrier.allocate_mbarrier()
mbarrier.init(mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], True))
mbarrier.init(mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], True, acc_tmem.type.layout.two_ctas))

mbarrier.expect(tma_bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta)
tma.async_load(a_desc, [0, 0], tma_bar, smem_a, multicast=True)
tma.async_load(b_desc, [0, 0], tma_bar, smem_b, multicast=True)
mbarrier.wait(tma_bar, phase=0, deps=[smem_a, smem_b])
mbarrier.invalidate(tma_bar)

acc_tmem = allocate_tensor_memory(ttgl.float32, [BLOCK_M, BLOCK_N], acc_tmem_layout)
# If it's not in a loop we don't strictly need multicast=True, but we add it to exercise the path in the test
tcgen05_mma(smem_a, smem_b, acc_tmem, use_acc=False, multicast=True, mbarriers=[mma_bar])
mbarrier.wait(mma_bar, phase=0, deps=[smem_a, smem_b])
Expand Down Expand Up @@ -482,7 +482,7 @@ def tcgen05_mma_scaled_direct_multicast_kernel(a_desc, b_desc, out_ptr, BLOCK_M:
tma_bar = mbarrier.allocate_mbarrier(two_ctas=True)
mbarrier.init(tma_bar, count=1)
mma_bar = mbarrier.allocate_mbarrier()
mbarrier.init(mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], True))
mbarrier.init(mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], True, acc.type.layout.two_ctas))

phase_tma = 0
phase_mma = 0
Expand Down Expand Up @@ -753,14 +753,15 @@ def tma_mma_shared_inputs_kernel(a_desc, b_desc, out_ptr, out_desc, gather_idx_p
phase_tma = 0

if use_tcgen05:
mma_bar = mbarrier.allocate_mbarrier()
phase_mma = 0
mbarrier.init(mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], multicast))
acc_tmem = allocate_tensor_memory(
element_ty=ttgl.float32,
shape=[BLOCK_M, BLOCK_N],
layout=acc_tmem_layout,
)
mma_bar = mbarrier.allocate_mbarrier()
phase_mma = 0
mbarrier.init(mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], multicast,
acc_tmem.type.layout.two_ctas))
else:
acc = ttgl.zeros([BLOCK_M, BLOCK_N], dtype=ttgl.float32, layout=acc_layout)

Expand Down Expand Up @@ -3943,7 +3944,8 @@ def mma_scaled_tcgen05_copy_kernel(a_desc, b_desc, c_desc, a_scale_desc, b_scale
mbarrier.init(tma_bar, count=1)

mma_bar = mbarrier.allocate_mbarrier()
mma_bar_count: ttgl.constexpr = tcgen05_mma_barrier_count([a_smem, b_smem], multicast=multicast)
mma_bar_count: ttgl.constexpr = tcgen05_mma_barrier_count([a_smem, b_smem], multicast=multicast,
two_ctas=acc_tmem.type.layout.two_ctas)
mbarrier.init(mma_bar, count=mma_bar_count)

phase_tma = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -541,41 +541,37 @@ def tcgen05_mma_scaled(a, b, acc, a_scale, b_scale, a_type, b_type, *, use_acc=T


@constexpr_function
def tcgen05_mma_barrier_count(smems, multicast):
def tcgen05_mma_barrier_count(smems, multicast, two_ctas):
Comment thread
lezcano marked this conversation as resolved.
"""
Calculate the number of CTAs that will commit the tcgen05 MMA instruction.

Args:
smems (Sequence[shared_memory_descriptor]): Shared memory descriptors used in the tcgen05 instruction.
multicast (bool): Whether the tcgen05 instruction is multicast.
two_ctas (bool): Whether the tcgen05 instruction uses cta_group::2.

Returns:
int: The number of CTAs that will commit the tcgen05 MMA instruction.
"""
assert 0 <= len(smems) <= 2, "tcgen05_mma_barrier_count supports 0, 1, or 2 smem descriptors"
assert 0 <= len(smems) <= 4, "tcgen05_mma_barrier_count supports 0 to 4 descriptors"
if not smems or not multicast:
return 1

def basis_is_zero(basis):
return all(b == 0 for b in basis)

def num_broadcast_bits(smem):
return sum(basis_is_zero(basis) for basis in smem.layout.cga_layout)

if len(smems) == 1:
return 2**num_broadcast_bits(smems[0])

assert len(smems) == 2
num_broadcast_bits_a = num_broadcast_bits(smems[0])
num_broadcast_bits_b = num_broadcast_bits(smems[1])
# Assert that for every basis, at least one of them is non-zero
# so that the inclusion-exclusion principle below works
# This can be generalised if needed by substracting below 2**size_intersection
for i in range(len(smems[0].layout.cga_layout)):
assert not basis_is_zero(smems[0].layout.cga_layout[i]) or not basis_is_zero(smems[1].layout.cga_layout[i])

# Inclusion-exclusion
num_cta_commits = 2**num_broadcast_bits_a + 2**num_broadcast_bits_b - 1
num_cta_bits = len(smems[0].layout.cga_layout)
for desc in smems[1:]:
assert len(desc.layout.cga_layout) == num_cta_bits

num_cta_commits = 0
for cta in range(1 << num_cta_bits):
if two_ctas and cta & 1:
continue
for desc in smems:
if all(basis_is_zero(basis) or not (cta & (1 << i)) for i, basis in enumerate(desc.layout.cga_layout)):
num_cta_commits += 1
break
return num_cta_commits


Expand Down
24 changes: 13 additions & 11 deletions python/tutorials/gluon/14-multicta.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,12 +548,13 @@ def tma_tcgen05_kernel(a_desc, b_desc, out_desc, NUM_K_TILES: gl.constexpr, acc_
smem_a = gl.allocate_shared_memory(a_desc.dtype, a_desc.block_shape, a_desc.layout)
smem_b = gl.allocate_shared_memory(b_desc.dtype, b_desc.block_shape, b_desc.layout)

acc_tmem = allocate_tensor_memory(gl.float32, [block_m, block_n], acc_tmem_layout)
tma_bar = mbarrier.allocate_mbarrier(two_ctas=True)
mma_bar = mbarrier.allocate_mbarrier()
mbarrier.init(tma_bar, count=1)
mbarrier.init(mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], multicast=True))

acc_tmem = allocate_tensor_memory(gl.float32, [block_m, block_n], acc_tmem_layout)
mbarrier.init(
mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], multicast=True,
two_ctas=acc_tmem.type.layout.two_ctas))

phase_tma = 0
phase_mma = 0
Expand Down Expand Up @@ -972,21 +973,22 @@ def matmul_multicta_kernel(
dtype: gl.constexpr = a_desc.dtype
a_bufs = gl.allocate_shared_memory(dtype, [STAGES] + a_desc.block_shape, a_desc.layout)
b_bufs = gl.allocate_shared_memory(dtype, [STAGES] + b_desc.block_shape, b_desc.layout)
mma_barrier_count: gl.constexpr = tcgen05_mma_barrier_count([a_bufs.index(0), b_bufs.index(0)], multicast=True)

load_empty_bars = mbarrier.allocate_mbarrier(batch=STAGES)
load_ready_bars = mbarrier.allocate_mbarrier(batch=STAGES, two_ctas=two_ctas)
for i in gl.static_range(STAGES):
mbarrier.init(load_empty_bars.index(i), count=mma_barrier_count)
mbarrier.init(load_ready_bars.index(i), count=1)

tmem_layout: gl.constexpr = TensorMemoryLayout(
[BLOCK_SIZE_M, block_n // get_split_dim(CGA_LAYOUT, 1)],
col_stride=1,
cga_layout=CGA_LAYOUT,
two_ctas=two_ctas,
)
acc_bufs = allocate_tensor_memory(gl.float32, [ACC_STAGES, block_m, block_n], tmem_layout)
mma_barrier_count: gl.constexpr = tcgen05_mma_barrier_count([a_bufs.index(0), b_bufs.index(0)], multicast=True,
two_ctas=acc_bufs.index(0).type.layout.two_ctas)

load_empty_bars = mbarrier.allocate_mbarrier(batch=STAGES)
load_ready_bars = mbarrier.allocate_mbarrier(batch=STAGES, two_ctas=two_ctas)
for i in gl.static_range(STAGES):
mbarrier.init(load_empty_bars.index(i), count=mma_barrier_count)
mbarrier.init(load_ready_bars.index(i), count=1)

acc_empty_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES, two_ctas=two_ctas)
acc_ready_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES)
for i in gl.static_range(ACC_STAGES):
Expand Down
Loading