Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,14 @@ TLX uses **CUDA-native cluster semantics** which differs from Triton's approach:
y = tlx.stoch_round(x, tlx.dtype_of(y_ptr), rbits)
```

- `tlx.vote_ballot_sync(mask, pred)`

Collects a predicate from each thread in the warp and returns a 32-bit
mask where each bit represents the predicate value from the corresponding
lane. Only threads specified by `mask` participate in the vote.
```
ballot_result = tlx.vote_ballot_sync(0xFFFFFFFF, pred)
```

## Kernels Implemented with TLX

Expand Down
39 changes: 39 additions & 0 deletions include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,45 @@ def TTNG_CLCQueryCancelOp : TTNG_Op<"clc_query_cancel", []> {
let assemblyFormat = "$clcResAlloc attr-dict `:` functional-type(operands, $ctaId)";
}

def TTNG_VoteBallotSyncOp : TTNG_Op<"vote_ballot_sync", [Pure]> {
let summary = "Warp-level vote ballot synchronization";

let description = [{
Performs a warp-level vote ballot operation that collects a predicate from
each thread in the warp and returns a 32-bit mask where each bit represents
the predicate value from the corresponding lane.

The `mask` operand specifies which threads participate in the vote. Threads
with their corresponding bit set in the mask must execute the instruction
with the same mask value.

The `pred` operand can be either:
- A scalar i1: Each thread contributes this predicate, returns scalar i32
- A tensor of i1: Each thread contributes its element(s), returns tensor of i32
with the same shape. All threads in a warp receive the same ballot value.

When pred is a tensor, each thread contributes the OR of all its owned
elements to the ballot. The result tensor has the same shape, with each
element containing the warp's ballot result.

This lowers to PTX instruction:
vote.sync.ballot.b32 dest, predicate, membermask;

https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-vote-sync
}];

let arguments = (ins
I32:$mask,
AnyTypeOf<[I1, TT_BoolTensor]>:$pred
);

let results = (outs AnyTypeOf<[I32, TT_IntTensor]>:$result);

let assemblyFormat = "$mask `,` $pred attr-dict `:` type($pred) `->` type($result)";

let hasVerifier = 1;
}

def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local", [AttrSizedOperandSegments]> {
let summary = "copy data based on descriptor from global memory to local memory asynchronously";

Expand Down
3 changes: 2 additions & 1 deletion lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
triton::gpu::AsyncCopyGlobalToLocalOp, triton::gpu::LocalLoadOp,
triton::gpu::LocalStoreOp, triton::gpu::RemoteShmemStoreOp,
triton::gpu::AsyncRemoteShmemStoreOp,
triton::nvidia_gpu::WarpGroupDotWaitOp, triton::tlx::RequireLayoutOp,
triton::nvidia_gpu::WarpGroupDotWaitOp,
triton::nvidia_gpu::VoteBallotSyncOp, triton::tlx::RequireLayoutOp,
triton::tlx::ReleaseLayoutOp, triton::tlx::LocalAliasOp>(
[&](Operation *op) -> bool {
// make sure every RankedTensorType operand has encoding
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
GenericOpPattern<triton::gpu::AsyncRemoteShmemStoreOp>,
GenericOpPattern<triton::gpu::LocalLoadOp>,
GenericOpPattern<triton::nvidia_gpu::WarpGroupDotWaitOp>,
GenericOpPattern<triton::nvidia_gpu::VoteBallotSyncOp>,
TritonFuncOpPattern>(typeConverter, context);
}

Expand Down
56 changes: 56 additions & 0 deletions lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,62 @@ LogicalResult ArriveBarrierOp::verify() {
return success();
}

// -- VoteBallotSyncOp --
LogicalResult VoteBallotSyncOp::verify() {
Type predType = getPred().getType();
Type resultType = getResult().getType();

bool predIsTensor = isa<RankedTensorType>(predType);
bool resultIsTensor = isa<RankedTensorType>(resultType);

// Both must be scalars or both must be tensors
if (predIsTensor != resultIsTensor) {
return emitOpError("predicate and result must both be scalars or both be "
"tensors, got pred=")
<< predType << " and result=" << resultType;
}

if (predIsTensor) {
auto predTensorType = cast<RankedTensorType>(predType);
auto resultTensorType = cast<RankedTensorType>(resultType);

// Check element types
if (!predTensorType.getElementType().isInteger(1)) {
return emitOpError("tensor predicate must have i1 element type, got ")
<< predTensorType.getElementType();
}
if (!resultTensorType.getElementType().isInteger(32)) {
return emitOpError("tensor result must have i32 element type, got ")
<< resultTensorType.getElementType();
}

// Shapes must match
if (predTensorType.getShape() != resultTensorType.getShape()) {
return emitOpError("predicate and result tensor shapes must match, got ")
<< predTensorType.getShape() << " vs "
<< resultTensorType.getShape();
}

// Encodings must match (if present)
if (predTensorType.getEncoding() != resultTensorType.getEncoding()) {
return emitOpError(
"predicate and result tensor encodings must match, got ")
<< predTensorType.getEncoding() << " vs "
<< resultTensorType.getEncoding();
}
} else {
// Scalar case
if (!predType.isInteger(1)) {
return emitOpError("scalar predicate must be i1, got ") << predType;
}
if (!resultType.isInteger(32)) {
return emitOpError("scalar result must be i32, got ") << resultType;
}
}

return success();
}

// -- AsyncTMACopyGlobalToLocalOp --
LogicalResult AsyncTMACopyGlobalToLocalOp::verify() {
if (failed(verifyBarrierType(*this, getBarrier().getType())))
Expand Down
60 changes: 59 additions & 1 deletion python/test/unit/language/test_tlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4999,7 +4999,6 @@ def test_reuse_group_invalid_element_type_raises_error(self):
group_type=tlx.reuse_group_type.shared,
)


class TestToMxfp8:
"""Tests for the _to_mxfp8_block library function callable from JIT code with VEC_SIZE=32."""

Expand Down Expand Up @@ -5185,3 +5184,62 @@ def set_buffer_overlap_nested_kernel(BLOCK_SIZE: tl.constexpr):
# The kernel should compile to IR but fail during lowering
with pytest.raises(RuntimeError):
set_buffer_overlap_nested_kernel[grid](BLOCK_SIZE=64)

@pytest.mark.skipif(not is_hopper_or_newer(), reason="Need Hopper or newer")
def test_vote_ballot_sync(device):
"""Test vote_ballot_sync TLX operation for warp-level voting."""

@triton.jit
def vote_ballot_kernel(
output_ptr,
BLOCK_SIZE: tl.constexpr,
):
# Each thread's lane ID (use x-axis thread ID)
tid = tlx.thread_id(0)

# Create a predicate: lanes 0-15 vote True, lanes 16-31 vote False
pred = tid < 16

# Perform warp-level ballot vote
# 0xFFFFFFFF means all 32 threads in the warp participate
ballot_result = tlx.vote_ballot_sync(0xFFFFFFFF, pred)

# Store the ballot result from thread 0 only
if tid == 0:
tl.store(output_ptr, ballot_result)

output = torch.zeros(1, dtype=torch.int32, device=device)

# Run the kernel with 1 warp
vote_ballot_kernel[(1, )](output, BLOCK_SIZE=32, num_warps=1)
torch.cuda.synchronize()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why needs torch.cuda.synchronize()?


# Expected ballot result: threads 0-15 have pred=True, threads 16-31 have pred=False
# So ballot should be 0x0000FFFF (lower 16 bits set)
expected_ballot = 0x0000FFFF
assert output.item() == expected_ballot, f"Expected {hex(expected_ballot)}, got {hex(output.item())}"


@pytest.mark.skipif(not is_hopper_or_newer(), reason="Need Hopper or newer")
def test_vote_ballot_sync_ir_emission(device):
"""Test that vote_ballot_sync generates the correct IR."""

@triton.jit
def vote_ballot_ir_kernel(output_ptr, ):
tid = tlx.thread_id(0)
pred = tid < 16 # First 16 threads True
ballot_result = tlx.vote_ballot_sync(0xFFFFFFFF, pred)
if tid == 0:
tl.store(output_ptr, ballot_result)

output = torch.zeros(1, dtype=torch.int32, device=device)
kernel = vote_ballot_ir_kernel[(1, )](output, num_warps=1)

# Verify the TTGIR contains the vote_ballot_sync op
ttgir = kernel.asm["ttgir"]
assert "vote_ballot_sync" in ttgir, "Expected vote_ballot_sync in TTGIR"

# Verify the LLVM IR contains the NVVM vote instruction
llir = kernel.asm["llir"]
assert "nvvm.vote.ballot.sync" in llir or "vote.sync.ballot" in llir, (
"Expected nvvm.vote.ballot.sync or vote.sync.ballot in LLVM IR")
11 changes: 11 additions & 0 deletions test/Conversion/tritonnvidiagpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,17 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

// -----

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: vote_ballot_sync
// CHECK: nvvm.vote.sync ballot
tt.func @vote_ballot_sync(%mask: i32, %pred: i1) {
%result = ttng.vote_ballot_sync %mask, %pred : i1 -> i32
tt.return
}
}

// -----

#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#smem = #ttg.shared_memory
Expand Down
63 changes: 63 additions & 0 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,68 @@ struct CLCQueryCancelOpConversion
return success();
}
};

struct VoteBallotSyncOpConversion
: public ConvertOpToLLVMPattern<triton::nvidia_gpu::VoteBallotSyncOp> {
using ConvertOpToLLVMPattern<
triton::nvidia_gpu::VoteBallotSyncOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(triton::nvidia_gpu::VoteBallotSyncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Type predType = op.getPred().getType();

// Scalar case: simple pass-through to NVVM
if (!isa<RankedTensorType>(predType)) {
Value result = rewriter.create<NVVM::VoteSyncOp>(
loc, rewriter.getI32Type(), adaptor.getMask(), adaptor.getPred(),
NVVM::VoteSyncKind::ballot);
rewriter.replaceOp(op, result);
return success();
}

// Tensor case: unpack elements, apply ballot to each, pack results
auto predTensorType = cast<RankedTensorType>(predType);
auto resultType = op.getResult().getType();

// Unpack the tensor predicate elements - each thread owns some elements
SmallVector<Value> predElems =
unpackLLElements(loc, adaptor.getPred(), rewriter);

// For vote_ballot_sync with tensor predicates:
// 1. First, OR all local predicate elements together to get a single bool
// 2. Apply the ballot operation once with the combined predicate
// 3. Replicate the result to all elements of the output tensor

TritonLLVMOpBuilder b(loc, rewriter);

// Combine all local predicate elements with OR
Value combinedPred;
if (predElems.empty()) {
combinedPred = b.i1_val(false);
} else {
combinedPred = predElems[0];
for (size_t i = 1; i < predElems.size(); ++i) {
combinedPred = b.or_(combinedPred, predElems[i]);
}
}

// Perform the warp-level ballot with the combined predicate
Value ballot = rewriter.create<NVVM::VoteSyncOp>(
loc, rewriter.getI32Type(), adaptor.getMask(), combinedPred,
NVVM::VoteSyncKind::ballot);

// Replicate the ballot result to all elements of the output tensor
SmallVector<Value> resultElems(predElems.size(), ballot);

// Pack results back into tensor
Value packedResult = packLLElements(loc, getTypeConverter(), resultElems,
rewriter, resultType);
rewriter.replaceOp(op, packedResult);
return success();
}
};
} // namespace

void mlir::triton::NVIDIA::populateBarrierOpToLLVMPatterns(
Expand All @@ -427,4 +489,5 @@ void mlir::triton::NVIDIA::populateBarrierOpToLLVMPatterns(
patterns.add<NamedBarrierWaitOpConversion>(typeConverter, benefit);
patterns.add<AsyncCLCTryCancelOpConversion>(typeConverter, benefit);
patterns.add<CLCQueryCancelOpConversion>(typeConverter, benefit);
patterns.add<VoteBallotSyncOpConversion>(typeConverter, benefit);
}
20 changes: 20 additions & 0 deletions third_party/tlx/dialect/triton_tlx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,26 @@ void init_triton_tlx_ir(py::module &&m) {
self.create<mlir::arith::SelectOp>(isNegOne, tileId, offset);
return tileId;
})
.def("vote_ballot_sync",
[](TritonOpBuilder &self, Value mask, Value pred) -> Value {
auto &builder = self.getBuilder();
Type predType = pred.getType();

// Determine result type based on predicate type
Type resultType;
if (auto tensorType = dyn_cast<RankedTensorType>(predType)) {
// For tensor input, return tensor of i32 with same
// shape/encoding
resultType = RankedTensorType::get(tensorType.getShape(),
builder.getI32Type(),
tensorType.getEncoding());
} else {
// Scalar input -> scalar i32 result
resultType = builder.getI32Type();
}

return self.create<ttng::VoteBallotSyncOp>(resultType, mask, pred);
})
.def("create_async_TMA_load",
[](TritonOpBuilder &self, std::vector<Value> &multicastTargets,
Value desc, std::vector<Value> &coord, Value mbarrier, Value pred,
Expand Down
4 changes: 4 additions & 0 deletions third_party/tlx/language/tlx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@
thread_id,
)
from .mxfp8_utils import _to_mxfp8_block
from .warp_ops import (
vote_ballot_sync, )

__all__ = [
# async_tasks
Expand Down Expand Up @@ -160,4 +162,6 @@
"DummyRegisterLayoutEncoding",
# MXFP8
"_to_mxfp8_block",
# warp_ops
"vote_ballot_sync",
]
Loading
Loading