Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/triton/Dialect/Triton/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ class DialectVerifyTensorLayoutInterface
};

// Descriptor gather and scatter have restrictions on the tile sizes.
//
// Verifies the result and index types of a gather/scatter on `op` without
// taking the descriptor block type into account. On a violation (for example
// when `indicesType` is not a 1D tensor) an op error is emitted on `op` and
// failure is returned.
LogicalResult verifyGatherScatterResultType(Operation *op,
ShapedType resultType,
ShapedType indicesType);
// Variant that additionally receives the descriptor's block type.
// NOTE(review): presumably this also checks `resultType` against `blockType`;
// the definition is not visible here, so confirm before relying on it.
LogicalResult verifyGatherScatterOp(Operation *op, ShapedType blockType,
ShapedType resultType,
ShapedType indicesType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,8 @@ def TTNG_AsyncTMAGatherOp : TTNG_Op<"async_tma_gather", [
I32:$y_offset,
Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$barrier,
Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$result,
I1:$pred
I1:$pred,
UnitAttr:$multicast
);

let assemblyFormat = [{
Expand Down
6 changes: 3 additions & 3 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1507,9 +1507,9 @@ LogicalResult GatherOp::inferReturnTypes(
}

// -- DescriptorGatherOp
static LogicalResult verifyGatherScatterResultType(Operation *op,
ShapedType resultType,
ShapedType indicesType) {
LogicalResult verifyGatherScatterResultType(Operation *op,
ShapedType resultType,
ShapedType indicesType) {
if (indicesType.getRank() != 1)
return op->emitOpError("x offsets must be a 1D tensor, but got ")
<< indicesType;
Expand Down
64 changes: 58 additions & 6 deletions lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,55 @@ static LogicalResult verifyAsyncTMAStoreOp(Operation *op,
return verifyTMAEncoding(op, desc.getType(), srcEnc);
}

// Shared verifier for the async TMA gather and scatter ops.
//
// `blockType` is the descriptor's block type, `memDescType` is the memdesc
// operand of the op and `indicesType` is the type of the x offsets. The first
// violated constraint is reported as an op error on `op` and failure is
// returned; success is returned when every check passes.
static LogicalResult verifyAsyncTMAGatherScatterOp(Operation *op,
                                                   ShapedType blockType,
                                                   MemDescType memDescType,
                                                   ShapedType indicesType) {
  // The descriptor block has to be a single row, i.e. of shape [1, N].
  if (blockType.getRank() != 2) {
    return op->emitOpError("descriptor block must be a 2D tensor, but got ")
           << blockType;
  }
  ArrayRef<int64_t> blockShape = blockType.getShape();
  if (blockShape[0] != 1) {
    return op->emitOpError("descriptor block must have exactly 1 row, but got ")
           << blockType;
  }

  // Checks that do not depend on the block type are shared with the
  // tensor-level gather/scatter ops.
  if (failed(verifyGatherScatterResultType(op, memDescType, indicesType)))
    return failure();

  // The memdesc must agree with the block on column count and element type.
  if (memDescType.getShape()[1] != blockShape[1]) {
    return op->emitOpError("result tensor number of columns must match block (")
           << blockShape[1] << "), but got " << memDescType;
  }
  if (memDescType.getElementType() != blockType.getElementType()) {
    return op->emitOpError("result tensor element type must match block (")
           << blockType.getElementType() << "), but got " << memDescType;
  }

  // The memdesc shape has to equal the two trailing dims of its alloc shape.
  ArrayRef<int64_t> allocDims = memDescType.getAllocShape();
  bool shapeMatchesAlloc = allocDims.size() >= 2 &&
                           memDescType.getShape() == allocDims.take_back(2);
  if (!shapeMatchesAlloc)
    return op->emitOpError("memdesc shape must match alloc shape");

  // The remaining constraints only apply once the offsets carry a layout.
  auto offsetsType = cast<RankedTensorType>(indicesType);
  if (!offsetsType.getEncoding())
    return success();

  auto *ctx = op->getContext();
  auto offsetsLayout = triton::gpu::toLinearLayout(offsetsType);
  auto kLane = StringAttr::get(ctx, "lane");

  if (getContigPerThread(offsetsType).front() < 4) {
    return op->emitOpError(
        "x offsets must have at least 4 contiguous elements per thread");
  }

  // Every lane bit must be a free variable of the offsets layout.
  unsigned threadsPerWarp = offsetsLayout.getInDimSize(kLane);
  auto laneFreeMask = offsetsLayout.getFreeVariableMasks()[kLane];
  if (laneFreeMask != (threadsPerWarp - 1))
    return op->emitOpError("x offsets must be broadcasted across each warp");

  // Compare how the rows (dim0) of the memdesc are distributed over blocks
  // with the CGA layout of the offsets.
  auto kBlock = StringAttr::get(ctx, "block");
  auto kDim0 = StringAttr::get(ctx, "dim0");
  auto memDescRowsCGA = getCGALayout(memDescType.getEncoding())
                            .getLinearLayout()
                            .sublayout({kBlock}, {kDim0});
  auto offsetsCGA = getCGALayout(offsetsType.getEncoding()).getLinearLayout();
  if (memDescRowsCGA != offsetsCGA) {
    return op->emitOpError(
        "x offsets must have the same row CGA layout as the memdesc");
  }
  return success();
}

// Helper to determine if the descriptor type is for im2col mode
static bool isIm2ColDescriptor(Type descType) {
return isa<TensorDescIm2ColType>(descType);
Expand Down Expand Up @@ -546,9 +595,12 @@ LogicalResult AsyncTMAGatherOp::verify() {
// `tile::gather4` does not support fp4_padded operands.
if (isFp4Padded(getResult().getType().getEncoding()))
return emitOpError("does not support fp4_padded operands");
return verifyGatherScatterOp(*this,
getDesc().getType().getSignlessBlockType(),
resultType, getXOffsets().getType());
if (getMulticast() && !hasCGABroadcast(resultType))
return emitOpError(
"multicast requires the shared layout to broadcast across CTAs");
return verifyAsyncTMAGatherScatterOp(
*this, getDesc().getType().getSignlessBlockType(), resultType,
getXOffsets().getType());
}

Value AsyncTMAGatherOp::getPredicateOperand() { return getPred(); }
Expand All @@ -566,9 +618,9 @@ LogicalResult AsyncTMAScatterOp::verify() {
auto srcType = getSrc().getType();
if (failed(verifyAsyncTMAStoreOp(*this, getDesc(), srcType)))
return failure();
return verifyGatherScatterOp(*this,
getDesc().getType().getSignlessBlockType(),
srcType, getXOffsets().getType());
return verifyAsyncTMAGatherScatterOp(
*this, getDesc().getType().getSignlessBlockType(), srcType,
getXOffsets().getType());
}

// -- TCGen5MMAOp --
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ static bool isDistributedMultiCTAOp(Operation *op, bool isRead) {
return ttng::getModuleTwoCTAs(op);
} else if (auto tma = dyn_cast<ttng::AsyncTMACopyGlobalToLocalOp>(op)) {
return tma.getMulticast();
} else if (auto tma = dyn_cast<ttng::AsyncTMAGatherOp>(op)) {
return tma.getMulticast();
}
return false;
}
Expand Down Expand Up @@ -109,6 +111,9 @@ usesTrackedBarrierInCrossCTAConsumerOp(Operation *op,
if (auto tma = dyn_cast<ttng::AsyncTMACopyGlobalToLocalOp>(op)) {
return tma.getMulticast() && aliasesTracked(tma.getBarrier());
}
if (auto tma = dyn_cast<ttng::AsyncTMAGatherOp>(op)) {
return tma.getMulticast() && aliasesTracked(tma.getBarrier());
}
return false;
}

Expand Down
15 changes: 9 additions & 6 deletions python/src/gluon_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -943,12 +943,15 @@ void init_gluon_ir(py::module &&m) {
[](GluonOpBuilder &self, int pendings) {
self.create<ttng::TMAStoreWaitOp>(pendings);
})
.def("create_async_tma_gather",
[](GluonOpBuilder &self, Value descPtr, Value xOffsets,
Value yOffset, Value barrier, Value result, Value pred) {
self.create<ttng::AsyncTMAGatherOp>(descPtr, xOffsets, yOffset,
barrier, result, pred);
})
.def(
"create_async_tma_gather",
[](GluonOpBuilder &self, Value descPtr, Value xOffsets, Value yOffset,
Value barrier, Value result, Value pred, bool multicast) {
multicast &=
ttng::hasCGABroadcast(cast<ttg::MemDescType>(result.getType()));
self.create<ttng::AsyncTMAGatherOp>(
descPtr, xOffsets, yOffset, barrier, result, pred, multicast);
})
.def("create_async_tma_scatter",
[](GluonOpBuilder &self, Value descPtr, Value xOffsets,
Value yOffset, Value src) {
Expand Down
143 changes: 130 additions & 13 deletions python/test/gluon/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from triton.experimental.gluon.language.nvidia.ampere import async_copy, mma_v2
from triton.experimental.gluon.language.nvidia.hopper import tma, mbarrier, fence_async_shared
from triton.experimental.gluon.language.nvidia import hopper
from triton.experimental.gluon.language.nvidia.blackwell import tma as blackwell_tma
from triton.experimental.gluon.language.amd.cdna4 import async_copy as cdna4_async_copy
from triton.experimental.gluon.language.extra import libdevice
from triton.experimental.gluon.language.nvidia.blackwell import (
Expand Down Expand Up @@ -262,6 +263,87 @@ def test_tma_multicast_copy(ctas_per_cga):
torch.testing.assert_close(out, inp, atol=0, rtol=0)


@gluon.jit
def tma_gather_scatter_kernel(in_desc, gather_out_desc, scatter_out_desc, gather_idx_ptr, scatter_idx_ptr,
                              BLOCK_M: ttgl.constexpr, x_offsets_layout: ttgl.constexpr):
    # Gathers BLOCK_M rows of `in_desc` into shared memory, then writes that
    # tile out twice: once as a plain TMA store through `gather_out_desc` and
    # once as a TMA scatter through `scatter_out_desc`.
    #
    # `gather_idx_ptr` / `scatter_idx_ptr` point to BLOCK_M row indices each;
    # both are loaded with `x_offsets_layout`.
    smem = ttgl.allocate_shared_memory(in_desc.dtype, [BLOCK_M, gather_out_desc.block_shape[1]], gather_out_desc.layout)

    bar = mbarrier.allocate_mbarrier()
    mbarrier.init(bar, count=1)

    # Gather: the barrier is armed with the per-CTA byte count of `smem`
    # before the multicast gather is issued, then waited on.
    gather_offsets = ttgl.load(gather_idx_ptr + ttgl.arange(0, BLOCK_M, layout=x_offsets_layout))
    mbarrier.expect(bar, smem.nbytes_per_cta)
    blackwell_tma.async_gather(in_desc, gather_offsets, 0, bar, smem, multicast=True)
    mbarrier.wait(bar, phase=0, deps=[smem])

    mbarrier.invalidate(bar)

    # Write-back: both stores read the same shared-memory tile, and
    # `store_wait(0)` is issued after both of them.
    scatter_offsets = ttgl.load(scatter_idx_ptr + ttgl.arange(0, BLOCK_M, layout=x_offsets_layout))
    tma.async_copy_shared_to_global(gather_out_desc, [0, 0], smem)
    blackwell_tma.async_scatter(scatter_out_desc, scatter_offsets, 0, smem)
    tma.store_wait(0)

    smem._keep_alive()


def get_split_dim(cga_layout, dim):
    """Return ``2 ** k`` for the ``k`` bases of ``cga_layout`` that touch ``dim``.

    ``cga_layout`` is a sequence of bases, each indexable by ``dim``. A basis
    counts when its entry at ``dim`` is non-zero, so an empty layout, or one
    whose bases are all zero along ``dim``, yields 1.
    """
    return 1 << sum(basis[dim] != 0 for basis in cga_layout)


@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
@pytest.mark.parametrize("cga_layout", [
    [[1, 0]],
    [[0, 0], [1, 0]],
    [[1, 0], [0, 0]],
    [[1, 0], [2, 0]],
])
def test_tma_gather_scatter_multi_cta(cga_layout):
    """Run the gather/scatter kernel on ``2 ** len(cga_layout)`` CTAs.

    Rows are gathered in reversed order and scattered rotated by one row. Both
    outputs are compared exactly against a torch reference. The PTX must
    contain ``.multicast::cluster`` exactly when ``cga_layout`` has an
    all-zero basis.
    """
    cga_split_num = [get_split_dim(cga_layout, dim) for dim in range(2)]

    # Block sizes scale with the split factor along each dimension.
    BLOCK_M = 32 * cga_split_num[0]
    BLOCK_N = 128 * cga_split_num[1]

    inp = torch.arange(BLOCK_M * BLOCK_N, dtype=torch.float16, device="cuda").reshape(BLOCK_M, BLOCK_N)
    gather_idx = torch.arange(BLOCK_M - 1, -1, -1, dtype=torch.int32, device="cuda")
    scatter_idx = (torch.arange(0, BLOCK_M, dtype=torch.int32, device="cuda") + 1) % BLOCK_M
    gather_out = torch.empty_like(inp)
    scatter_out = torch.zeros_like(inp)

    layout = ttgl.NVMMASharedLayout.get_default_for(
        [BLOCK_M, BLOCK_N],
        ttgl.float16,
        cga_layout=cga_layout,
    )
    # The gather and scatter descriptors use single-row blocks; the plain store
    # descriptor covers the whole tile.
    in_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(inp, [1, BLOCK_N // cga_split_num[1]], layout)
    gather_out_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(gather_out, [BLOCK_M, BLOCK_N], layout)
    scatter_out_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(scatter_out, [1, BLOCK_N // cga_split_num[1]],
                                                                        layout)

    offset_layout = ttgl.BlockedLayout([4, 1], [1, 32], [4, 1], [0, 1], cga_layout=cga_layout)
    x_offsets_layout = ttgl.SliceLayout(1, offset_layout)

    num_ctas = 1 << len(cga_layout)
    compiled = tma_gather_scatter_kernel[(1, )](
        in_desc,
        gather_out_desc,
        scatter_out_desc,
        gather_idx,
        scatter_idx,
        BLOCK_M,
        x_offsets_layout,
        num_warps=4,
        num_ctas=num_ctas,
    )

    expected_gather = inp[gather_idx.to(torch.int64)]
    expected_scatter = torch.zeros_like(inp)
    expected_scatter[scatter_idx.to(torch.int64)] = expected_gather
    # An all-zero basis is the case in which the multicast PTX form is expected.
    expect_multicast = any(all(coord == 0 for coord in basis) for basis in cga_layout)
    assert (".multicast::cluster" in compiled.asm["ptx"]) == expect_multicast
    torch.testing.assert_close(gather_out, expected_gather, atol=0, rtol=0)
    torch.testing.assert_close(scatter_out, expected_scatter, atol=0, rtol=0)


@gluon.jit
def tcgen05_mma_multicast_commit_kernel(a_desc, b_desc, out_ptrs, BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr,
acc_tmem_layout: ttgl.constexpr, blocked_c: ttgl.constexpr):
Expand Down Expand Up @@ -655,11 +737,13 @@ def test_warpgroup_mma(ASYNC):


@gluon.jit
def tma_mma_shared_inputs_kernel(a_desc, b_desc, out_ptr, BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr,
BLOCK_K: ttgl.constexpr, NUM_K_TILES: ttgl.constexpr, block_layout_c: ttgl.constexpr,
def tma_mma_shared_inputs_kernel(a_desc, b_desc, out_ptr, out_desc, gather_idx_ptr, scatter_idx_ptr,
BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr, BLOCK_K: ttgl.constexpr,
NUM_K_TILES: ttgl.constexpr, block_layout_c: ttgl.constexpr,
acc_layout: ttgl.constexpr, acc_tmem_layout: ttgl.constexpr,
use_tcgen05: ttgl.constexpr, multicast: ttgl.constexpr):
smem_a = ttgl.allocate_shared_memory(a_desc.dtype, a_desc.block_shape, a_desc.layout)
use_tcgen05: ttgl.constexpr, multicast: ttgl.constexpr,
use_gather_scatter: ttgl.constexpr):
smem_a = ttgl.allocate_shared_memory(a_desc.dtype, [BLOCK_M, BLOCK_K], a_desc.layout)
smem_b = ttgl.allocate_shared_memory(b_desc.dtype, b_desc.block_shape, b_desc.layout)

two_ctas: ttgl.constexpr = isinstance(acc_tmem_layout, TensorMemoryLayout) and acc_tmem_layout.two_ctas
Expand All @@ -680,9 +764,17 @@ def tma_mma_shared_inputs_kernel(a_desc, b_desc, out_ptr, BLOCK_M: ttgl.constexp
else:
acc = ttgl.zeros([BLOCK_M, BLOCK_N], dtype=ttgl.float32, layout=acc_layout)

if use_gather_scatter:
gather_offsets_layout: ttgl.constexpr = ttgl.SliceLayout(
1, ttgl.BlockedLayout([4, 1], [1, 32], [ttgl.num_warps(), 1], [0, 1], cga_layout=a_desc.layout.cga_layout))
gather_offsets = ttgl.load(gather_idx_ptr + ttgl.arange(0, BLOCK_M, layout=gather_offsets_layout))

for k in range(NUM_K_TILES):
mbarrier.expect(tma_bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(a_desc, [0, k * BLOCK_K], tma_bar, smem_a, multicast=multicast)
mbarrier.expect(tma_bar, smem_a.nbytes_per_cta + smem_b.nbytes_per_cta)
if use_gather_scatter:
blackwell_tma.async_gather(a_desc, gather_offsets, k * BLOCK_K, tma_bar, smem_a, multicast=multicast)
else:
tma.async_copy_global_to_shared(a_desc, [0, k * BLOCK_K], tma_bar, smem_a, multicast=multicast)
tma.async_copy_global_to_shared(b_desc, [k * BLOCK_K, 0], tma_bar, smem_b, multicast=multicast)
mbarrier.wait(tma_bar, phase=phase_tma, deps=[smem_a, smem_b])
phase_tma ^= 1
Expand All @@ -705,9 +797,19 @@ def tma_mma_shared_inputs_kernel(a_desc, b_desc, out_ptr, BLOCK_M: ttgl.constexp
acc = acc_tmem.load()

acc = ttgl.convert_layout(acc, block_layout_c)
offs_m = ttgl.arange(0, BLOCK_M)[:, None]
offs_n = ttgl.arange(0, BLOCK_N)[None, :]
ttgl.store(out_ptr + offs_m * BLOCK_N + offs_n, acc)
if use_gather_scatter:
scatter_offsets_layout: ttgl.constexpr = ttgl.SliceLayout(
1, ttgl.BlockedLayout([4, 1], [1, 32], [ttgl.num_warps(), 1], [0, 1],
cga_layout=out_desc.layout.cga_layout))
scatter_offsets = ttgl.load(scatter_idx_ptr + ttgl.arange(0, BLOCK_M, layout=scatter_offsets_layout))
acc_smem = ttgl.allocate_shared_memory(out_desc.dtype, [BLOCK_M, BLOCK_N], out_desc.layout, acc)
blackwell_tma.async_scatter(out_desc, scatter_offsets, 0, acc_smem)
tma.store_wait(0)
acc_smem._keep_alive()
else:
offs_m = ttgl.arange(0, BLOCK_M)[:, None]
offs_n = ttgl.arange(0, BLOCK_N)[None, :]
ttgl.store(out_ptr + offs_m * BLOCK_N + offs_n, acc)


@pytest.mark.skipif(not (is_hopper() or is_blackwell()), reason="Requires Hopper or Blackwell")
Expand All @@ -716,7 +818,8 @@ def tma_mma_shared_inputs_kernel(a_desc, b_desc, out_ptr, BLOCK_M: ttgl.constexp
@pytest.mark.parametrize("ctas_per_cga", [[1, 1], [2, 1], [4, 4]])
@pytest.mark.parametrize("two_ctas", [False, True] if is_blackwell() else [False])
@pytest.mark.parametrize("multicast", [False, True])
def test_tma_mma_shared_inputs(warps, reps, ctas_per_cga, two_ctas, multicast):
@pytest.mark.parametrize("use_gather_scatter", [False, True] if is_blackwell() else [False])
def test_tma_mma_shared_inputs(warps, reps, ctas_per_cga, two_ctas, multicast, use_gather_scatter):
bitwidth = 16
acc_dtype = torch.float32

Expand Down Expand Up @@ -788,19 +891,26 @@ def cast(x, dtype):
gluon_dtype = ttgl.float16
shared_layout_a = ttgl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_K], gluon_dtype, cga_layout=cga_layout_a)
shared_layout_b = ttgl.NVMMASharedLayout.get_default_for([BLOCK_K, BLOCK_N], gluon_dtype, cga_layout=cga_layout_b)
shared_layout_c = ttgl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_N], ttgl.float32, cga_layout=cga_layout_c)
assert shared_layout_a.swizzle_byte_width != 0
assert shared_layout_b.swizzle_byte_width != 0
a_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(a, [BLOCK_M, BLOCK_K], shared_layout_a)
a_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(a, [1 if use_gather_scatter else BLOCK_M, BLOCK_K],
shared_layout_a)
b_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(b, [BLOCK_K, BLOCK_N], shared_layout_b)

num_warps = warps[0] * warps[1]
num_ctas = ctas_per_cga[0] * ctas_per_cga[1]
out_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(out, [1, BLOCK_N], shared_layout_c)
gather_idx = torch.arange(BLOCK_M - 1, -1, -1, dtype=torch.int32, device=device)
scatter_idx = (torch.arange(0, BLOCK_M, dtype=torch.int32, device=device) + 1) % BLOCK_M

try:
tma_mma_shared_inputs_kernel[(1, )](
a_desc,
b_desc,
out,
out_desc,
gather_idx,
scatter_idx,
BLOCK_M,
BLOCK_N,
BLOCK_K,
Expand All @@ -810,6 +920,7 @@ def cast(x, dtype):
acc_tmem_layout,
is_blackwell(),
multicast=multicast,
use_gather_scatter=use_gather_scatter,
num_warps=num_warps,
num_ctas=num_ctas,
)
Expand All @@ -819,7 +930,13 @@ def cast(x, dtype):
try:
allow_tf32 = torch.backends.cuda.matmul.allow_tf32
torch.backends.cuda.matmul.allow_tf32 = True
ref = torch.matmul(a.to(torch.float32), b.to(torch.float32))
if use_gather_scatter:
matmul = torch.matmul(a[gather_idx.to(torch.int64)].to(torch.float32), b.to(torch.float32))
ref = torch.empty_like(matmul)
# Correct as scatter_idx is a permutation!
ref[scatter_idx.to(torch.int64)] = matmul
else:
ref = torch.matmul(a.to(torch.float32), b.to(torch.float32))
finally:
torch.backends.cuda.matmul.allow_tf32 = allow_tf32

Expand Down
Loading