86 changes: 58 additions & 28 deletions lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
@@ -290,6 +290,62 @@ LogicalResult ClusterBarrierOp::verify() {
}

// -- TMA operation verifiers --
+static std::string formatCGALayout(CGAEncodingAttr cgaLayout) {
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  auto kBlock = StringAttr::get(cgaLayout.getContext(), "block");
+  os << "[";
+  llvm::interleaveComma(cgaLayout.getLinearLayout().getBases().lookup(kBlock),
+                        os, [&](const auto &basis) {
+                          os << "[";
+                          llvm::interleaveComma(basis, os);
+                          os << "]";
+                        });
+  os << "]";
+  return os.str();
+}
+
+static LogicalResult verifyBarrierCGALayout(Operation *op, Value barrier,
+                                            CGAEncodingAttr expectedCGALayout,
+                                            StringRef barrierName) {
+  auto barrierTy = cast<MemDescType>(barrier.getType());
+  auto actualCGALayout = getCGALayout(barrierTy.getEncoding());
+  if (actualCGALayout != expectedCGALayout)
+    return op->emitOpError() << barrierName << " cga_layout must be "
+                             << formatCGALayout(expectedCGALayout) << ", got "
+                             << formatCGALayout(actualCGALayout);
+  return success();
+}
+
+static LogicalResult verifyCompletionBarrierLayout(Operation *op,
+                                                   Value barrier) {
+  auto expectedCGALayout =
+      CGAEncodingAttr::get1DLayout(op->getContext(), gpu::lookupNumCTAs(op));
+  return verifyBarrierCGALayout(op, barrier, expectedCGALayout,
+                                "completion barrier");
+}
+
+static LogicalResult verifyTMABarrierLayout(Operation *op, Value barrier) {
+  auto twoCTAsAttr =
+      op->getParentOfType<ModuleOp>()->getAttrOfType<BoolAttr>(AttrTwoCTAsName);
+  if (!twoCTAsAttr)
+    return success();
+
+  auto ctx = op->getContext();
+  int numCTAs = gpu::lookupNumCTAs(op);
+  CGAEncodingAttr expectedCGALayout;
+  if (twoCTAsAttr.getValue()) {
+    auto kBlock = StringAttr::get(ctx, "block");
+    auto dim = standardOutDimNames(ctx, /*rank=*/1)[0];
+    auto layout = LinearLayout::zeros1D(2, kBlock, dim) *
+                  LinearLayout::identity1D(numCTAs / 2, kBlock, dim);
+    expectedCGALayout = CGAEncodingAttr::get(ctx, std::move(layout));
+  } else {
+    expectedCGALayout = CGAEncodingAttr::get1DLayout(ctx, numCTAs);
+  }
+  return verifyBarrierCGALayout(op, barrier, expectedCGALayout, "TMA barrier");
+}
+
static LogicalResult verifyTMAEncoding(Operation *op, TensorDescInterface desc,
                                       Attribute enc) {
  auto nvmma = dyn_cast<NVMMASharedEncodingAttr>(enc);
@@ -318,6 +374,8 @@ static LogicalResult verifyAsyncTMALoadOp(Operation *op,
                                          MemDescType resultType) {
  if (failed(verifyBarrierType(op, barrier.getType())))
    return failure();
+  if (failed(verifyTMABarrierLayout(op, barrier)))
+    return failure();
  if (!resultType.getMutableMemory())
    return op->emitOpError("cannot store into immutable memory");
  if (failed(verifyTMAEncoding(op, desc, resultType.getEncoding())))
@@ -599,34 +657,6 @@ static LogicalResult verifyMMADType(Operation *op, Type a, Type b, Type d) {
  return success();
}

-static std::string formatCGALayout(CGAEncodingAttr cgaLayout) {
-  std::string str;
-  llvm::raw_string_ostream os(str);
-  auto kBlock = StringAttr::get(cgaLayout.getContext(), "block");
-  os << "[";
-  llvm::interleaveComma(cgaLayout.getLinearLayout().getBases().lookup(kBlock),
-                        os, [&](const auto &basis) {
-                          os << "[";
-                          llvm::interleaveComma(basis, os);
-                          os << "]";
-                        });
-  os << "]";
-  return os.str();
-}
-
-static LogicalResult verifyCompletionBarrierLayout(Operation *op,
-                                                   Value barrier) {
-  auto barrierTy = cast<MemDescType>(barrier.getType());
-  auto expectedCGALayout =
-      CGAEncodingAttr::get1DLayout(op->getContext(), gpu::lookupNumCTAs(op));
-  auto actualCGALayout = getCGALayout(barrierTy.getEncoding());
-  if (actualCGALayout != expectedCGALayout)
-    return op->emitOpError("completion barrier cga_layout must be ")
-           << formatCGALayout(expectedCGALayout) << ", got "
-           << formatCGALayout(actualCGALayout);
-  return success();
-}
-
LogicalResult TCGen5MMAOp::verify() {
  if (!getIsAsync() && !getBarriers().empty()) {
    return emitOpError("The op is synchronous but a barrier is present.");
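For intuition about the layouts this verifier compares: in two-CTA mode the expected layout carries a leading zero basis, so CTAs 2k and 2k+1 alias the same position, whereas outside two-CTA mode the expected layout is the identity. Below is an illustrative Python sketch of the "block" bases that verifyTMABarrierLayout expects, mirroring the zeros1D/identity1D product above; the helper name and standalone-script form are ours, not part of the patch.

# Illustrative sketch (not part of the patch): the "block" bases that
# verifyTMABarrierLayout expects in two-CTA mode, mirroring
# LinearLayout::zeros1D(2, ...) * LinearLayout::identity1D(numCTAs / 2, ...).
def expected_two_cta_bases(num_ctas: int) -> list[list[int]]:
    assert num_ctas >= 2 and num_ctas % 2 == 0
    bases = [[0]]  # zeros1D(2): the lowest CTA bit maps to nothing,
                   # so CTAs 2k and 2k+1 alias the same position
    pos = 1
    while pos < num_ctas // 2:
        bases.append([pos])  # identity1D(numCTAs / 2): remaining CTA bits
        pos *= 2
    return bases

print(expected_two_cta_bases(4))  # [[0], [1]]   (identity would be [[1], [2]])
print(expected_two_cta_bases(8))  # [[0], [1], [2]]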
98 changes: 24 additions & 74 deletions python/test/gluon/test_consan.py
@@ -158,13 +158,10 @@ def test_consan_uses_profile_scratch(device, fresh_knobs, num_ctas):

@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
@pytest.mark.parametrize("FAILURE", [True, False])
-@pytest.mark.parametrize("TWO_CTA_BARRIER", [False, True])
-def test_async_tma_kernel(FAILURE, TWO_CTA_BARRIER, device, run_wrapper, monkeypatch, num_ctas):
-    if TWO_CTA_BARRIER and num_ctas == 1:
-        pytest.skip("Need at least 2 CTAs for a two-CTA barrier")
+def test_async_tma_kernel(FAILURE, device, run_wrapper, monkeypatch, num_ctas):
    if run_wrapper:
-        result = run_in_process(test_async_tma_kernel, (FAILURE, TWO_CTA_BARRIER, device, False, monkeypatch, num_ctas))
-        if FAILURE or TWO_CTA_BARRIER:
+        result = run_in_process(test_async_tma_kernel, (FAILURE, device, False, monkeypatch, num_ctas))
+        if FAILURE:
            assert_expected_cuda_failure(result.exc)
            assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output
        else:
@@ -177,13 +174,13 @@ def test_async_tma_kernel(FAILURE, TWO_CTA_BARRIER, device, run_wrapper, monkeyp
    knobs.refresh_knobs()

    @gluon.jit
-    def kernel(input_desc, out, FAILURE: ttgl.constexpr, TWO_CTA_BARRIER: ttgl.constexpr):
+    def kernel(input_desc, out, FAILURE: ttgl.constexpr):
        block_m: ttgl.constexpr = XBLOCK * ttgl.num_ctas()
        cga_layout: ttgl.constexpr = default_cga_layout(ttgl.num_ctas(), 2)
        blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[32, 1],
                                                            warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout)
        smem = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], input_desc.layout)
-        bar = mbarrier.allocate_mbarrier(two_ctas=TWO_CTA_BARRIER)
+        bar = mbarrier.allocate_mbarrier()
        mbarrier.init(bar, count=1)
        mbarrier.expect(bar, input_desc.nbytes_per_cta)
        tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem)
@@ -203,19 +200,17 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr, TWO_CTA_BARRIER: ttgl.const
    shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2,
                                           cga_layout=default_cga_layout(num_ctas, 2))
    input_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(input, [block_m, XBLOCK.value], shared_layout)
-    kernel[(1, )](input_desc, output, FAILURE=FAILURE, TWO_CTA_BARRIER=TWO_CTA_BARRIER, num_warps=4, num_ctas=num_ctas)
+    kernel[(1, )](input_desc, output, FAILURE=FAILURE, num_warps=4, num_ctas=num_ctas)


@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
@pytest.mark.parametrize("FAILURE", [True, False])
-@pytest.mark.parametrize("TWO_CTA_BARRIER", [False, True])
-def test_async_tma_multicast_kernel(FAILURE, TWO_CTA_BARRIER, device, run_wrapper, monkeypatch, num_ctas):
+def test_async_tma_multicast_kernel(FAILURE, device, run_wrapper, monkeypatch, num_ctas):
    if num_ctas == 1:
        pytest.skip("Need at least 2 CTAs for multicast in this test")
    if run_wrapper:
-        result = run_in_process(test_async_tma_multicast_kernel,
-                                (FAILURE, TWO_CTA_BARRIER, device, False, monkeypatch, num_ctas))
-        if FAILURE or TWO_CTA_BARRIER:
+        result = run_in_process(test_async_tma_multicast_kernel, (FAILURE, device, False, monkeypatch, num_ctas))
+        if FAILURE:
            assert_expected_cuda_failure(result.exc)
            assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output
        else:
@@ -228,12 +223,12 @@ def test_async_tma_multicast_kernel(FAILURE, TWO_CTA_BARRIER, device, run_wrappe
    knobs.refresh_knobs()

    @gluon.jit
-    def kernel(input_desc, out, FAILURE: ttgl.constexpr, TWO_CTA_BARRIER: ttgl.constexpr):
+    def kernel(input_desc, out, FAILURE: ttgl.constexpr):
        cga_layout: ttgl.constexpr = multicast_cga_layout(ttgl.num_ctas(), 2)
        blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[32, 1],
                                                            warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout)
        smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], input_desc.layout)
-        bar = mbarrier.allocate_mbarrier(two_ctas=TWO_CTA_BARRIER)
+        bar = mbarrier.allocate_mbarrier()
        mbarrier.init(bar, count=1)
        mbarrier.expect(bar, input_desc.nbytes_per_cta)
        tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem, multicast=True)
@@ -252,54 +247,7 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr, TWO_CTA_BARRIER: ttgl.const
    shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2,
                                           cga_layout=multicast_cga_layout(num_ctas, 2))
    input_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(input, [XBLOCK.value, XBLOCK.value], shared_layout)
-    kernel[(1, )](input_desc, output, FAILURE=FAILURE, TWO_CTA_BARRIER=TWO_CTA_BARRIER, num_warps=4, num_ctas=num_ctas)
-
-
-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
-@pytest.mark.parametrize("TWO_CTA_BARRIER", [False, True])
-def test_async_tma_multicast_kernel_two_cta_barrier(TWO_CTA_BARRIER, device, run_wrapper, monkeypatch, num_ctas):
-    if num_ctas != 2:
-        pytest.skip("This test covers a single 2-CTA multicast group")
-    if run_wrapper:
-        result = run_in_process(test_async_tma_multicast_kernel_two_cta_barrier,
-                                (TWO_CTA_BARRIER, device, False, monkeypatch, num_ctas))
-        if TWO_CTA_BARRIER:
-            assert_expected_cuda_failure(result.exc)
-            assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output
-        else:
-            assert result.exc is None
-            assert result.driver_stderr_output == ""
-        return
-
-    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
-    monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
-    knobs.refresh_knobs()
-
-    @gluon.jit
-    def kernel(input_desc, out, TWO_CTA_BARRIER: ttgl.constexpr):
-        cga_layout: ttgl.constexpr = multicast_cga_layout(ttgl.num_ctas(), 2)
-        blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[32, 1],
-                                                            warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout)
-        smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], input_desc.layout)
-        bar = mbarrier.allocate_mbarrier(two_ctas=TWO_CTA_BARRIER)
-        mbarrier.init(bar, count=1)
-        mbarrier.expect(bar, input_desc.nbytes_per_cta)
-        tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem, multicast=True)
-        mbarrier.wait(bar, 0, deps=[smem])
-        val = smem.load(blocked_layout)
-        mbarrier.invalidate(bar)
-
-        out_m = ttgl.arange(0, XBLOCK, ttgl.SliceLayout(1, blocked_layout))[:, None]
-        out_n = ttgl.arange(0, XBLOCK, ttgl.SliceLayout(0, blocked_layout))[None, :]
-        out_ptr = out + out_m * XBLOCK + out_n
-        ttgl.store(out_ptr, val)
-
-    input = torch.randn((XBLOCK.value, XBLOCK.value), device=device, dtype=torch.float16)
-    output = torch.empty((XBLOCK.value, XBLOCK.value), device=device, dtype=torch.float16)
-    shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2,
-                                           cga_layout=multicast_cga_layout(num_ctas, 2))
-    input_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(input, [XBLOCK.value, XBLOCK.value], shared_layout)
-    kernel[(1, )](input_desc, output, TWO_CTA_BARRIER=TWO_CTA_BARRIER, num_warps=4, num_ctas=num_ctas)
+    kernel[(1, )](input_desc, output, FAILURE=FAILURE, num_warps=4, num_ctas=num_ctas)


@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
@@ -741,21 +689,24 @@ def kernel(input_desc, output_desc, FAILURE: ttgl.constexpr, MEM_ACCESS_KIND: tt
            ttgl.NVMMASharedLayout.get_default_for([XBLOCK, block_n], ttgl.float16,
                                                   cga_layout=mma_cga_layout(ttgl.num_ctas(), 1, TWO_CTAS)),
        )
-        bar = mbarrier.allocate_mbarrier(batch=2)
+        mma_bar = mbarrier.allocate_mbarrier()
        acc = blackwell.allocate_tensor_memory(ttgl.float32, [block_m, block_n], acc_layout)
-        mbarrier.init(bar.index(0), count=1)
-        mbarrier.init(bar.index(1), count=1)
+        mbarrier.init(mma_bar, count=1)
+        if MEM_ACCESS_KIND == "tma_cp":
+            tma_bar = mbarrier.allocate_mbarrier(two_ctas=TWO_CTAS)
+            mbarrier.init(tma_bar, count=1)

        blackwell.tcgen05_mma(smemA, smemB, acc)
-        blackwell.tcgen05_commit(bar.index(0))
+        blackwell.tcgen05_commit(mma_bar)

        if not FAILURE:
-            mbarrier.wait(bar.index(0), 0)
+            mbarrier.wait(mma_bar, 0)

        if MEM_ACCESS_KIND == "tma_cp":
-            mbarrier.expect(bar.index(1), input_desc.nbytes_per_cta)
-            tma.async_copy_global_to_shared(input_desc, [0, 0], bar.index(1), smemA)
-            mbarrier.wait(bar.index(1), 0)
+            mbarrier.expect(tma_bar, input_desc.nbytes_per_cta)
+            tma.async_copy_global_to_shared(input_desc, [0, 0], tma_bar, smemA)
+            mbarrier.wait(tma_bar, 0)
+            mbarrier.invalidate(tma_bar)
        elif MEM_ACCESS_KIND == "local_store":
            smemA.store(ttgl.full([block_m, XBLOCK], 42, ttgl.float16, smem_a_blocked_layout))
        elif MEM_ACCESS_KIND == "tmem_load":
@@ -769,8 +720,7 @@
elif MEM_ACCESS_KIND == "tmem_store":
acc.store(ttgl.full([block_m, block_n], 42, ttgl.float32, acc_blocked_layout))

mbarrier.invalidate(bar.index(0))
mbarrier.invalidate(bar.index(1))
mbarrier.invalidate(mma_bar)

block_m = mma_block_m(num_ctas)
block_n = mma_block_n(num_ctas)
6 changes: 4 additions & 2 deletions python/tutorials/gluon/14-multicta.py
@@ -477,8 +477,9 @@ def broadcast(b):
# every CTA in the multicast group atomically, so the wait side does not need a
# different API.
#
-# The only new ingredient is the layout. The TMA destination must use a
-# broadcast `cga_layout`, so that both CTAs view the same shared-memory tile.
+# The TMA destination must use a broadcast `cga_layout`, so that both CTAs
+# receive the same shared-memory tile. The barrier stays a regular 1D TMA
+# barrier unless the kernel is in 2CTA mode.
#
# The example below keeps things intentionally simple: it multicasts one tile
# into shared memory and then materializes that same tile back to global memory.
@@ -489,6 +490,7 @@ def tma_multicast_copy_kernel(in_desc, out_desc):
    gl.static_assert(gl.num_ctas() == 2)

    smem = gl.allocate_shared_memory(in_desc.dtype, in_desc.block_shape, in_desc.layout)
+    # This kernel is not in 2CTA mode, so the TMA barrier is per-CTA.
    bar = mbarrier.allocate_mbarrier()
    mbarrier.init(bar, count=1)

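For contrast with the per-CTA barrier above: a kernel that does run in 2CTA mode must allocate its TMA barrier so that it spans the CTA pair, matching the paired cga_layout the new verifier enforces. A minimal sketch, assuming the two_ctas flag exercised in test_consan.py earlier in this diff:

# Minimal sketch, assuming the two_ctas flag shown in test_consan.py:
# in 2CTA mode the TMA barrier spans the CTA pair rather than a single CTA.
bar = mbarrier.allocate_mbarrier(two_ctas=True)
mbarrier.init(bar, count=1)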
4 changes: 2 additions & 2 deletions test/Conversion/tritonnvidiagpu_to_llvm.mlir
@@ -167,14 +167,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
#shared0_cluster = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}>
#shared1_cga = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, CGALayout = [[0, 0]]}>
#smem = #ttg.shared_memory
-module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32} {
+module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttng.two-ctas" = true} {
  // CHECK-LABEL: tma_copy_barrier_mask_nonzero
  // Barrier pointer is modified when barrier mask != 0
  // CHECK: llvm.ptrtoint
  // CHECK: llvm.and
  // CHECK: llvm.inttoptr
  // TMA uses shared::cluster when barrier mask is non-zero
-  // CHECK: cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes
+  // CHECK: cp.async.bulk.tensor.2d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes
  // CHECK-NOT: cp.async.bulk.tensor.2d.shared::cta.global.mbarrier
  tt.func @tma_copy_barrier_mask_nonzero(%tma: !tt.tensordesc<128x128xf32, #shared1_cga>, %alloc: !ttg.memdesc<128x128xf32, #shared1_cga, #smem, mutable>, %x: i32, %barrier: !ttg.memdesc<1xi64, #shared0_cluster, #smem>, %pred: i1) {
    ttng.async_tma_copy_global_to_local %tma[%x, %x] %alloc, %barrier, %pred : !tt.tensordesc<128x128xf32, #shared1_cga>, !ttg.memdesc<1xi64, #shared0_cluster, #smem> -> !ttg.memdesc<128x128xf32, #shared1_cga, #smem, mutable>
4 changes: 2 additions & 2 deletions test/TritonGPU/consan.mlir
@@ -264,7 +264,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32, ttg.shar
  // CHECK: arith.shrui
  // CHECK-LABEL: @outstanding_commits_multicast_tma_recipients
  tt.func public @outstanding_commits_multicast_tma_recipients(
-      %desc: !tt.tensordesc<tensor<32x32xf32, #shared>>,
+      %desc: !tt.tensordesc<32x32xf32, #shared>,
      %ptr: tensor<32x32x!tt.ptr<f32>, #blocked>) {
    %true = arith.constant true
    %c0_i32 = arith.constant 0 : i32
@@ -286,7 +286,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32, ttg.shar
  // CHECK: %[[RECIPIENTS:.*]] = arith.shli %[[PATTERN]],
  // CHECK: tt.call @__triton_consan_check_outstanding_commits{{.*}}({{.*}}, %[[RECIPIENTS]])
  // CHECK: ttng.async_tma_copy_global_to_local
-    ttng.async_tma_copy_global_to_local %desc[%c0_i32, %c0_i32] %shmem, %bar, %true {multicast} : !tt.tensordesc<tensor<32x32xf32, #shared>>, !ttg.memdesc<2xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable>
+    ttng.async_tma_copy_global_to_local %desc[%c0_i32, %c0_i32] %shmem, %bar, %true {multicast} : !tt.tensordesc<32x32xf32, #shared>, !ttg.memdesc<2xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable>
    tt.return
  }
}