Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ def TTNG_CLCTryCancelOp : TTNG_Op<"clc_try_cancel", [

let arguments = (ins
Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$result,
Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$mbarrier,
I1Attr:$multicast
Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$mbarrier
);

let assemblyFormat = [{
Expand Down
2 changes: 1 addition & 1 deletion python/examples/gluon/03-matmul-multicta.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def matmul_clc_partition(p):
result = p.clc_result_buffers.index(state.index)
# 16: clc.try_cancel has a `.b128` modifier
mbarrier.expect(barrier, 16)
clc.try_cancel(result, barrier, multicast=True)
clc.try_cancel(result, barrier)
mbarrier.wait(barrier, state.phase)
clc_res = clc.load_result(result)
has_work = clc_res.is_canceled()
Expand Down
2 changes: 1 addition & 1 deletion python/examples/gluon/04-2cta-block-scale-matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ def mma_scaled_clc_partition(p):
result = p.clc_result_buffers.index(state.index)
# 16: clc.try_cancel has a `.b128` modifier
mbarrier.expect(barrier, 16)
clc.try_cancel(result, barrier, multicast=True)
clc.try_cancel(result, barrier)
mbarrier.wait(barrier, state.phase)
clc_res = clc.load_result(result)
has_work = clc_res.is_canceled()
Expand Down
5 changes: 2 additions & 3 deletions python/src/gluon_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -870,9 +870,8 @@ void init_gluon_ir(py::module &&m) {
py::arg("relaxed") = false)
// CLC (Cluster Launch Control) ops - SM100+
.def("create_clc_try_cancel",
[](GluonOpBuilder &self, Value result, Value mbarrier,
bool multicast) {
self.create<ttng::CLCTryCancelOp>(result, mbarrier, multicast);
[](GluonOpBuilder &self, Value result, Value mbarrier) {
self.create<ttng::CLCTryCancelOp>(result, mbarrier);
})
.def("create_clc_load_result",
[](GluonOpBuilder &self, Value result) -> Value {
Expand Down
2 changes: 1 addition & 1 deletion python/test/gluon/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3529,7 +3529,7 @@ def clc_kernel(WasLaunched, IsCancelled, ProgramId, smem_size: ttgl.constexpr):
# Large shared memory allocation to force 1 block per SM
dummy = ttgl.allocate_shared_memory(ttgl.int64, [smem_size // 8 - 32], clc_mbar.layout)

clc.try_cancel(clc_result, clc_mbar, multicast=True)
clc.try_cancel(clc_result, clc_mbar)
mbarrier.expect(clc_mbar, 16)
mbarrier.wait(clc_mbar, 0)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,17 @@


@builtin
def try_cancel(result: shared_memory_descriptor, barrier, multicast=False, _semantic=None):
def try_cancel(result: shared_memory_descriptor, barrier, _semantic=None):
"""
Issue a CLC try_cancel request to atomically cancel a pending cluster launch.

Args:
result (shared_memory_descriptor): 16-byte aligned int64x2 shared memory for the response
barrier (shared_memory_descriptor): 8-byte aligned mbarrier for completion signaling
multicast (bool): If True, broadcast result to all CTAs in cluster

Only supported on SM100+ (Blackwell).
"""
_semantic.builder.create_clc_try_cancel(result.handle, barrier.handle, multicast)
_semantic.builder.create_clc_try_cancel(result.handle, barrier.handle)


@builtin
Expand Down
2 changes: 1 addition & 1 deletion python/tutorials/gluon/12-cluster-launch-control.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def initialize(M, N, BLOCK_M, BLOCK_N):

@gluon.jit
def try_cancel(self) -> None:
clc.try_cancel(self.clc_result_buf, self.barrier, multicast=True)
clc.try_cancel(self.clc_result_buf, self.barrier)
mbarrier.expect(self.barrier, 16)

@gluon.jit
Expand Down
2 changes: 1 addition & 1 deletion python/tutorials/gluon/14-multicta.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,7 +830,7 @@ def matmul_clc_partition(p):
barrier = p.clc_barriers.index(state.index)
result = p.clc_result_buffers.index(state.index)
mbarrier.expect(barrier, 16)
clc.try_cancel(result, barrier, multicast=True)
clc.try_cancel(result, barrier)
mbarrier.wait(barrier, state.phase)
clc_res = clc.load_result(result)
has_work = clc_res.is_canceled()
Expand Down
16 changes: 15 additions & 1 deletion test/Conversion/clc_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: clc_try_cancel
tt.func @clc_try_cancel(%result: !ttg.memdesc<2xi64, #shared0, #smem>, %mbar: !ttg.memdesc<1xi64, #shared0, #smem>) {
// CHECK: clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.b128
ttng.clc_try_cancel %result, %mbar {multicast = false} : !ttg.memdesc<2xi64, #shared0, #smem>, !ttg.memdesc<1xi64, #shared0, #smem>
ttng.clc_try_cancel %result, %mbar : !ttg.memdesc<2xi64, #shared0, #smem>, !ttg.memdesc<1xi64, #shared0, #smem>
tt.return
}
}

// -----

#shared_clc = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}>
#barrier = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[1]]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: clc_try_cancel_multicast
tt.func @clc_try_cancel_multicast(%result: !ttg.memdesc<2xi64, #shared_clc, #smem>, %mbar: !ttg.memdesc<1xi64, #barrier, #smem>) {
// CHECK: clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.multicast::cluster::all.b128
ttng.clc_try_cancel %result, %mbar : !ttg.memdesc<2xi64, #shared_clc, #smem>, !ttg.memdesc<1xi64, #barrier, #smem>
tt.return
}
}
Expand Down
4 changes: 2 additions & 2 deletions test/TritonNvidiaGPU/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32} {
%result: !ttg.memdesc<2xi64, #shared_clc, #smem>,
%mbar: !ttg.memdesc<1xi64, #barrier, #smem>) {
// expected-error @below {{completion barrier cga_layout must be}}
ttng.clc_try_cancel %result, %mbar {multicast = false} :
ttng.clc_try_cancel %result, %mbar :
!ttg.memdesc<2xi64, #shared_clc, #smem>, !ttg.memdesc<1xi64, #barrier, #smem>
tt.return
}
Expand All @@ -394,7 +394,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32} {
%result: !ttg.memdesc<2xi64, #shared_clc_bad, #smem>,
%mbar: !ttg.memdesc<1xi64, #barrier, #smem>) {
// expected-error @below {{Expected CLC result buffer cga_layout bases to be all zeros. Got [[1]]}}
ttng.clc_try_cancel %result, %mbar {multicast = false} :
ttng.clc_try_cancel %result, %mbar :
!ttg.memdesc<2xi64, #shared_clc_bad, #smem>, !ttg.memdesc<1xi64, #barrier, #smem>
tt.return
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ struct CLCTryCancelOpConversion

std::string ptxAsm = "@$2 clusterlaunchcontrol.try_cancel.async.shared::cta"
".mbarrier::complete_tx::bytes";
if (op.getMulticast())
if (numCTAs > 1)
ptxAsm += ".multicast::cluster::all";
ptxAsm += ".b128 [$0], [$1];";

Expand Down
2 changes: 1 addition & 1 deletion third_party/proton/test/test_instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,7 +983,7 @@ def gluon_clc_vector_add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: g
y = gl.load(y_ptr + offsets, mask)
gl.store(out_ptr + offsets, x + y, mask)

clc.try_cancel(clc_result, clc_bar, multicast=True)
clc.try_cancel(clc_result, clc_bar)
mbarrier.expect(clc_bar, 16)
mbarrier.wait(clc_bar, phase)

Expand Down
Loading