Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,14 @@ TLX uses **CUDA-native cluster semantics** which differs from Triton's approach:
y = tlx.stoch_round(x, tlx.dtype_of(y_ptr), rbits)
```

- `tlx.vote_ballot_sync(mask, pred)`

Collects a predicate from each thread in the warp and returns a 32-bit
mask where each bit represents the predicate value from the corresponding
lane. Only threads specified by `mask` participate in the vote.
```
ballot_result = tlx.vote_ballot_sync(0xFFFFFFFF, pred)
```

## Kernels Implemented with TLX

Expand Down
39 changes: 39 additions & 0 deletions include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,45 @@ def TTNG_CLCQueryCancelOp : TTNG_Op<"clc_query_cancel", []> {
let assemblyFormat = "$clcResAlloc attr-dict `:` functional-type(operands, $ctaId)";
}

def TTNG_VoteBallotSyncOp : TTNG_Op<"vote_ballot_sync", [Pure]> {
let summary = "Warp-level vote ballot synchronization";

let description = [{
Performs a warp-level vote ballot operation that collects a predicate from
each thread in the warp and returns a 32-bit mask where each bit represents
the predicate value from the corresponding lane.

The `mask` operand specifies which threads participate in the vote. Threads
with their corresponding bit set in the mask must execute the instruction
with the same mask value.

The `pred` operand can be either:
- A scalar i1: Each thread contributes this predicate, returns scalar i32
- A tensor of i1: Each thread contributes its element(s), returns tensor of i32
with the same shape. All threads in a warp receive the same ballot value.

When pred is a tensor, each thread contributes the OR of all its owned
elements to the ballot. The result tensor has the same shape, with each
element containing the warp's ballot result.

This lowers to PTX instruction:
vote.sync.ballot.b32 dest, predicate, membermask;

https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-vote-sync
}];

let arguments = (ins
I32:$mask,
AnyTypeOf<[I1, TT_BoolTensor]>:$pred
);

let results = (outs AnyTypeOf<[I32, TT_IntTensor]>:$result);

let assemblyFormat = "$mask `,` $pred attr-dict `:` type($pred) `->` type($result)";

let hasVerifier = 1;
}

def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local", [AttrSizedOperandSegments]> {
let summary = "copy data based on descriptor from global memory to local memory asynchronously";

Expand Down
3 changes: 2 additions & 1 deletion lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
triton::gpu::AsyncCopyGlobalToLocalOp, triton::gpu::LocalLoadOp,
triton::gpu::LocalStoreOp, triton::gpu::RemoteShmemStoreOp,
triton::gpu::AsyncRemoteShmemStoreOp,
triton::nvidia_gpu::WarpGroupDotWaitOp, triton::tlx::RequireLayoutOp,
triton::nvidia_gpu::WarpGroupDotWaitOp,
triton::nvidia_gpu::VoteBallotSyncOp, triton::tlx::RequireLayoutOp,
triton::tlx::ReleaseLayoutOp, triton::tlx::LocalAliasOp>(
[&](Operation *op) -> bool {
// make sure every RankedTensorType operand has encoding
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
GenericOpPattern<triton::gpu::AsyncRemoteShmemStoreOp>,
GenericOpPattern<triton::gpu::LocalLoadOp>,
GenericOpPattern<triton::nvidia_gpu::WarpGroupDotWaitOp>,
GenericOpPattern<triton::nvidia_gpu::VoteBallotSyncOp>,
TritonFuncOpPattern>(typeConverter, context);
}

Expand Down
56 changes: 56 additions & 0 deletions lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,62 @@ LogicalResult ArriveBarrierOp::verify() {
return success();
}

// -- VoteBallotSyncOp --
LogicalResult VoteBallotSyncOp::verify() {
Type predType = getPred().getType();
Type resultType = getResult().getType();

bool predIsTensor = isa<RankedTensorType>(predType);
bool resultIsTensor = isa<RankedTensorType>(resultType);

// Both must be scalars or both must be tensors
if (predIsTensor != resultIsTensor) {
return emitOpError("predicate and result must both be scalars or both be "
"tensors, got pred=")
<< predType << " and result=" << resultType;
}

if (predIsTensor) {
auto predTensorType = cast<RankedTensorType>(predType);
auto resultTensorType = cast<RankedTensorType>(resultType);

// Check element types
if (!predTensorType.getElementType().isInteger(1)) {
return emitOpError("tensor predicate must have i1 element type, got ")
<< predTensorType.getElementType();
}
if (!resultTensorType.getElementType().isInteger(32)) {
return emitOpError("tensor result must have i32 element type, got ")
<< resultTensorType.getElementType();
}

// Shapes must match
if (predTensorType.getShape() != resultTensorType.getShape()) {
return emitOpError("predicate and result tensor shapes must match, got ")
<< predTensorType.getShape() << " vs "
<< resultTensorType.getShape();
}

// Encodings must match (if present)
if (predTensorType.getEncoding() != resultTensorType.getEncoding()) {
return emitOpError(
"predicate and result tensor encodings must match, got ")
<< predTensorType.getEncoding() << " vs "
<< resultTensorType.getEncoding();
}
} else {
// Scalar case
if (!predType.isInteger(1)) {
return emitOpError("scalar predicate must be i1, got ") << predType;
}
if (!resultType.isInteger(32)) {
return emitOpError("scalar result must be i32, got ") << resultType;
}
}

return success();
}

// -- AsyncTMACopyGlobalToLocalOp --
LogicalResult AsyncTMACopyGlobalToLocalOp::verify() {
if (failed(verifyBarrierType(*this, getBarrier().getType())))
Expand Down
64 changes: 62 additions & 2 deletions python/test/unit/language/test_tlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2633,8 +2633,8 @@ def descriptor_load_kernel(input_ptr, output_ptr, M, N, BLOCK_SIZE_M: tl.constex
kernel = descriptor_load_kernel[grid](x, y, M, N, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N,
ctas_per_cga=(2, 2, 1))

assert kernel.asm["ptx"].count(
"cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster") == 1
assert (kernel.asm["ptx"].count(
"cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster") == 1)
# x:
# [ x0 | x2]
# [ x1 | x3]
Expand Down Expand Up @@ -4550,3 +4550,63 @@ def test_reuse_storage_mismatch_error_message(self):
# We can't fully test the error without a kernel context, but we can
# verify the storage_alias_spec's storage property is accessible
assert buf.storage == tlx.storage_kind.smem


@pytest.mark.skipif(not is_hopper_or_newer(), reason="Need Hopper or newer")
def test_vote_ballot_sync(device):
"""Test vote_ballot_sync TLX operation for warp-level voting."""

@triton.jit
def vote_ballot_kernel(
output_ptr,
BLOCK_SIZE: tl.constexpr,
):
# Each thread's lane ID (use x-axis thread ID)
tid = tlx.thread_id(0)

# Create a predicate: lanes 0-15 vote True, lanes 16-31 vote False
pred = tid < 16

# Perform warp-level ballot vote
# 0xFFFFFFFF means all 32 threads in the warp participate
ballot_result = tlx.vote_ballot_sync(0xFFFFFFFF, pred)

# Store the ballot result from thread 0 only
if tid == 0:
tl.store(output_ptr, ballot_result)

output = torch.zeros(1, dtype=torch.int32, device=device)

# Run the kernel with 1 warp
vote_ballot_kernel[(1, )](output, BLOCK_SIZE=32, num_warps=1)
torch.cuda.synchronize()

# Expected ballot result: threads 0-15 have pred=True, threads 16-31 have pred=False
# So ballot should be 0x0000FFFF (lower 16 bits set)
expected_ballot = 0x0000FFFF
assert output.item() == expected_ballot, f"Expected {hex(expected_ballot)}, got {hex(output.item())}"


@pytest.mark.skipif(not is_hopper_or_newer(), reason="Need Hopper or newer")
def test_vote_ballot_sync_ir_emission(device):
"""Test that vote_ballot_sync generates the correct IR."""

@triton.jit
def vote_ballot_ir_kernel(output_ptr, ):
tid = tlx.thread_id(0)
pred = tid < 16 # First 16 threads True
ballot_result = tlx.vote_ballot_sync(0xFFFFFFFF, pred)
if tid == 0:
tl.store(output_ptr, ballot_result)

output = torch.zeros(1, dtype=torch.int32, device=device)
kernel = vote_ballot_ir_kernel[(1, )](output, num_warps=1)

# Verify the TTGIR contains the vote_ballot_sync op
ttgir = kernel.asm["ttgir"]
assert "vote_ballot_sync" in ttgir, "Expected vote_ballot_sync in TTGIR"

# Verify the LLVM IR contains the NVVM vote instruction
llir = kernel.asm["llir"]
assert "nvvm.vote.ballot.sync" in llir or "vote.sync.ballot" in llir, (
"Expected nvvm.vote.ballot.sync or vote.sync.ballot in LLVM IR")
126 changes: 126 additions & 0 deletions test/Conversion/tritonnvidiagpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,99 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

// -----

// Test that tensor select with warp-uniform condition (from vote_ballot via splat)
// is converted to branches instead of per-element select instructions.
// This is the pattern used in Flash Attention for conditional rescaling.
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: uniform_tensor_select_to_branch
// CHECK: nvvm.vote.sync ballot
// CHECK: llvm.icmp "ne"
// CHECK: llvm.cond_br
// CHECK: llvm.br
// CHECK: llvm.br
tt.func @uniform_tensor_select_to_branch(%mask: i32, %pred: i1, %true_val: tensor<16x32xf32, #blocked>, %false_val: tensor<16x32xf32, #blocked>, %ptr: !tt.ptr<f32>) {
// Get warp-uniform ballot result
%ballot = ttng.vote_ballot_sync %mask, %pred : i1 -> i32
%c0 = arith.constant 0 : i32
// Compare ballot result (scalar i32) - this is warp-uniform
%scalar_cond = arith.cmpi ne, %ballot, %c0 : i32
// Splat scalar condition to tensor shape to match tensor operands
%cond = tt.splat %scalar_cond : i1 -> tensor<16x32xi1, #blocked>
// Select with uniform tensor condition - should become branches
%result = arith.select %cond, %true_val, %false_val : tensor<16x32xi1, #blocked>, tensor<16x32xf32, #blocked>
// Store result (kernels can't return values)
%ptrs = tt.splat %ptr : !tt.ptr<f32> -> tensor<16x32x!tt.ptr<f32>, #blocked>
tt.store %ptrs, %result : tensor<16x32x!tt.ptr<f32>, #blocked>
tt.return
}
}

// -----

// Test the full Flash Attention pattern: tensor predicate -> vote_ballot -> tensor condition -> select
// This matches the actual FA kernel pattern where pred = alpha_1 < 1.0 is a tensor.
#blocked1d = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: uniform_tensor_select_tensor_pred
// CHECK: nvvm.vote.sync ballot
// CHECK: llvm.icmp "ne"
// CHECK: llvm.cond_br
// CHECK: llvm.br
// CHECK: llvm.br
tt.func @uniform_tensor_select_tensor_pred(%mask: i32, %alpha: tensor<128xf32, #blocked1d>, %acc: tensor<128xf32, #blocked1d>, %scaled_acc: tensor<128xf32, #blocked1d>, %ptr: !tt.ptr<f32>) {
// pred = alpha < 1.0 - this is a tensor predicate
%c1 = arith.constant dense<1.0> : tensor<128xf32, #blocked1d>
%pred = arith.cmpf olt, %alpha, %c1 : tensor<128xf32, #blocked1d>
// ballot_result is a tensor with the same shape, all elements contain warp ballot
%ballot = ttng.vote_ballot_sync %mask, %pred : tensor<128xi1, #blocked1d> -> tensor<128xi32, #blocked1d>
// should_rescale = ballot_result != 0
%c0 = arith.constant dense<0> : tensor<128xi32, #blocked1d>
%should_rescale = arith.cmpi ne, %ballot, %c0 : tensor<128xi32, #blocked1d>
// Conditional select - condition is uniform since ballot result is same for all threads in warp
%result = arith.select %should_rescale, %scaled_acc, %acc : tensor<128xi1, #blocked1d>, tensor<128xf32, #blocked1d>
// Store result (kernels can't return values)
%ptrs = tt.splat %ptr : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>, #blocked1d>
tt.store %ptrs, %result : tensor<128x!tt.ptr<f32>, #blocked1d>
tt.return
}
}

// -----

// Test 2D Flash Attention pattern: alpha is 128x1 (broadcast dim), acc/scaled_acc are 128x64
// This tests the broadcast scenario where alpha has a singleton dimension.
#blocked2d = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked2d_alpha = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: uniform_tensor_select_2d_broadcast
// CHECK: nvvm.vote.sync ballot
// CHECK: llvm.icmp "ne"
// CHECK: llvm.cond_br
// CHECK: llvm.br
// CHECK: llvm.br
tt.func @uniform_tensor_select_2d_broadcast(%mask: i32, %alpha: tensor<128x1xf32, #blocked2d_alpha>, %acc: tensor<128x64xf32, #blocked2d>, %scaled_acc: tensor<128x64xf32, #blocked2d>, %ptr: !tt.ptr<f32>) {
// pred = alpha < 1.0 - alpha is 128x1, will be broadcast
%c1 = arith.constant dense<1.0> : tensor<128x1xf32, #blocked2d_alpha>
%pred = arith.cmpf olt, %alpha, %c1 : tensor<128x1xf32, #blocked2d_alpha>
// ballot_result has same shape as pred (128x1)
%ballot = ttng.vote_ballot_sync %mask, %pred : tensor<128x1xi1, #blocked2d_alpha> -> tensor<128x1xi32, #blocked2d_alpha>
// should_rescale = ballot_result != 0 (128x1)
%c0 = arith.constant dense<0> : tensor<128x1xi32, #blocked2d_alpha>
%cond_small = arith.cmpi ne, %ballot, %c0 : tensor<128x1xi32, #blocked2d_alpha>
// Broadcast condition from 128x1 to 128x64 to match acc/scaled_acc shape
%should_rescale = tt.broadcast %cond_small : tensor<128x1xi1, #blocked2d_alpha> -> tensor<128x64xi1, #blocked2d>
// Conditional select with broadcast condition
%result = arith.select %should_rescale, %scaled_acc, %acc : tensor<128x64xi1, #blocked2d>, tensor<128x64xf32, #blocked2d>
// Store result
%ptrs = tt.splat %ptr : !tt.ptr<f32> -> tensor<128x64x!tt.ptr<f32>, #blocked2d>
tt.store %ptrs, %result : tensor<128x64x!tt.ptr<f32>, #blocked2d>
tt.return
}
}

// -----

#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
Expand Down Expand Up @@ -128,6 +221,39 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

// -----

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: vote_ballot_sync
// CHECK: nvvm.vote.sync ballot
tt.func @vote_ballot_sync(%mask: i32, %pred: i1) {
%result = ttng.vote_ballot_sync %mask, %pred : i1 -> i32
tt.return
}
}

// -----

// Test that scalar select with warp-uniform condition (from vote_ballot) is
// converted to branches instead of select instruction.
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: uniform_select_to_branch
// CHECK: nvvm.vote.sync ballot
// CHECK: llvm.icmp "ne"
// CHECK: llvm.cond_br
// CHECK: llvm.br
// CHECK: llvm.br
tt.func @uniform_select_to_branch(%mask: i32, %pred: i1, %true_val: i32, %false_val: i32, %ptr: !tt.ptr<i32>) {
%ballot = ttng.vote_ballot_sync %mask, %pred : i1 -> i32
%c0 = arith.constant 0 : i32
%cond = arith.cmpi ne, %ballot, %c0 : i32
%result = arith.select %cond, %true_val, %false_val : i32
// Store result (kernels can't return values)
tt.store %ptr, %result : !tt.ptr<i32>
tt.return
}
}

// -----

#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#smem = #ttg.shared_memory
Expand Down
Loading
Loading