diff --git a/README.md b/README.md index 1b103d8667..033486c2fd 100644 --- a/README.md +++ b/README.md @@ -722,6 +722,14 @@ TLX uses **CUDA-native cluster semantics** which differs from Triton's approach: y = tlx.stoch_round(x, tlx.dtype_of(y_ptr), rbits) ``` +- `tlx.vote_ballot_sync(mask, pred)` + + Collects a predicate from each thread in the warp and returns a 32-bit + mask where each bit represents the predicate value from the corresponding + lane. Only threads specified by `mask` participate in the vote. + ``` + ballot_result = tlx.vote_ballot_sync(0xFFFFFFFF, pred) + ``` ## Kernels Implemented with TLX diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td index e036de0e76..b632391531 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -348,6 +348,45 @@ def TTNG_CLCQueryCancelOp : TTNG_Op<"clc_query_cancel", []> { let assemblyFormat = "$clcResAlloc attr-dict `:` functional-type(operands, $ctaId)"; } +def TTNG_VoteBallotSyncOp : TTNG_Op<"vote_ballot_sync", [Pure]> { + let summary = "Warp-level vote ballot synchronization"; + + let description = [{ + Performs a warp-level vote ballot operation that collects a predicate from + each thread in the warp and returns a 32-bit mask where each bit represents + the predicate value from the corresponding lane. + + The `mask` operand specifies which threads participate in the vote. Threads + with their corresponding bit set in the mask must execute the instruction + with the same mask value. + + The `pred` operand can be either: + - A scalar i1: Each thread contributes this predicate, returns scalar i32 + - A tensor of i1: Each thread contributes its element(s), returns tensor of i32 + with the same shape. All threads in a warp receive the same ballot value. 
+ + When pred is a tensor, each thread contributes the OR of all its owned + elements to the ballot. The result tensor has the same shape, with each + element containing the warp's ballot result. + + This lowers to PTX instruction: + vote.sync.ballot.b32 dest, predicate, membermask; + + https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-vote-sync + }]; + + let arguments = (ins + I32:$mask, + AnyTypeOf<[I1, TT_BoolTensor]>:$pred + ); + + let results = (outs AnyTypeOf<[I32, TT_IntTensor]>:$result); + + let assemblyFormat = "$mask `,` $pred attr-dict `:` type($pred) `->` type($result)"; + + let hasVerifier = 1; +} + def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local", [AttrSizedOperandSegments]> { let summary = "copy data based on descriptor from global memory to local memory asynchronously"; diff --git a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp index 3d6b01722d..d336b98460 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp @@ -121,7 +121,8 @@ TritonGPUConversionTarget::TritonGPUConversionTarget( triton::gpu::AsyncCopyGlobalToLocalOp, triton::gpu::LocalLoadOp, triton::gpu::LocalStoreOp, triton::gpu::RemoteShmemStoreOp, triton::gpu::AsyncRemoteShmemStoreOp, - triton::nvidia_gpu::WarpGroupDotWaitOp, triton::tlx::RequireLayoutOp, + triton::nvidia_gpu::WarpGroupDotWaitOp, + triton::nvidia_gpu::VoteBallotSyncOp, triton::tlx::RequireLayoutOp, triton::tlx::ReleaseLayoutOp, triton::tlx::LocalAliasOp>( [&](Operation *op) -> bool { // make sure every RankedTensorType operand has encoding diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 099162599f..159f187da4 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ 
b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -648,6 +648,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, GenericOpPattern, GenericOpPattern, GenericOpPattern, + GenericOpPattern, TritonFuncOpPattern>(typeConverter, context); } diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index bb95773d54..8acbdab502 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -212,6 +212,62 @@ LogicalResult ArriveBarrierOp::verify() { return success(); } +// -- VoteBallotSyncOp -- +LogicalResult VoteBallotSyncOp::verify() { + Type predType = getPred().getType(); + Type resultType = getResult().getType(); + + bool predIsTensor = isa(predType); + bool resultIsTensor = isa(resultType); + + // Both must be scalars or both must be tensors + if (predIsTensor != resultIsTensor) { + return emitOpError("predicate and result must both be scalars or both be " + "tensors, got pred=") + << predType << " and result=" << resultType; + } + + if (predIsTensor) { + auto predTensorType = cast(predType); + auto resultTensorType = cast(resultType); + + // Check element types + if (!predTensorType.getElementType().isInteger(1)) { + return emitOpError("tensor predicate must have i1 element type, got ") + << predTensorType.getElementType(); + } + if (!resultTensorType.getElementType().isInteger(32)) { + return emitOpError("tensor result must have i32 element type, got ") + << resultTensorType.getElementType(); + } + + // Shapes must match + if (predTensorType.getShape() != resultTensorType.getShape()) { + return emitOpError("predicate and result tensor shapes must match, got ") + << predTensorType.getShape() << " vs " + << resultTensorType.getShape(); + } + + // Encodings must match (if present) + if (predTensorType.getEncoding() != resultTensorType.getEncoding()) { + return emitOpError( + "predicate and result tensor encodings must match, got ") + << 
predTensorType.getEncoding() << " vs " + << resultTensorType.getEncoding(); + } + } else { + // Scalar case + if (!predType.isInteger(1)) { + return emitOpError("scalar predicate must be i1, got ") << predType; + } + if (!resultType.isInteger(32)) { + return emitOpError("scalar result must be i32, got ") << resultType; + } + } + + return success(); +} + // -- AsyncTMACopyGlobalToLocalOp -- LogicalResult AsyncTMACopyGlobalToLocalOp::verify() { if (failed(verifyBarrierType(*this, getBarrier().getType()))) diff --git a/python/test/unit/language/test_tlx.py b/python/test/unit/language/test_tlx.py index be15eb4b40..f985bf584b 100644 --- a/python/test/unit/language/test_tlx.py +++ b/python/test/unit/language/test_tlx.py @@ -2633,8 +2633,8 @@ def descriptor_load_kernel(input_ptr, output_ptr, M, N, BLOCK_SIZE_M: tl.constex kernel = descriptor_load_kernel[grid](x, y, M, N, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, ctas_per_cga=(2, 2, 1)) - assert kernel.asm["ptx"].count( - "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster") == 1 + assert (kernel.asm["ptx"].count( + "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster") == 1) # x: # [ x0 | x2] # [ x1 | x3] @@ -4550,3 +4550,63 @@ def test_reuse_storage_mismatch_error_message(self): # We can't fully test the error without a kernel context, but we can # verify the storage_alias_spec's storage property is accessible assert buf.storage == tlx.storage_kind.smem + + +@pytest.mark.skipif(not is_hopper_or_newer(), reason="Need Hopper or newer") +def test_vote_ballot_sync(device): + """Test vote_ballot_sync TLX operation for warp-level voting.""" + + @triton.jit + def vote_ballot_kernel( + output_ptr, + BLOCK_SIZE: tl.constexpr, + ): + # Each thread's lane ID (use x-axis thread ID) + tid = tlx.thread_id(0) + + # Create a predicate: lanes 0-15 vote True, lanes 16-31 vote False + pred = tid < 16 + + # Perform warp-level ballot vote + # 
0xFFFFFFFF means all 32 threads in the warp participate + ballot_result = tlx.vote_ballot_sync(0xFFFFFFFF, pred) + + # Store the ballot result from thread 0 only + if tid == 0: + tl.store(output_ptr, ballot_result) + + output = torch.zeros(1, dtype=torch.int32, device=device) + + # Run the kernel with 1 warp + vote_ballot_kernel[(1, )](output, BLOCK_SIZE=32, num_warps=1) + torch.cuda.synchronize() + + # Expected ballot result: threads 0-15 have pred=True, threads 16-31 have pred=False + # So ballot should be 0x0000FFFF (lower 16 bits set) + expected_ballot = 0x0000FFFF + assert output.item() == expected_ballot, f"Expected {hex(expected_ballot)}, got {hex(output.item())}" + + +@pytest.mark.skipif(not is_hopper_or_newer(), reason="Need Hopper or newer") +def test_vote_ballot_sync_ir_emission(device): + """Test that vote_ballot_sync generates the correct IR.""" + + @triton.jit + def vote_ballot_ir_kernel(output_ptr, ): + tid = tlx.thread_id(0) + pred = tid < 16 # First 16 threads True + ballot_result = tlx.vote_ballot_sync(0xFFFFFFFF, pred) + if tid == 0: + tl.store(output_ptr, ballot_result) + + output = torch.zeros(1, dtype=torch.int32, device=device) + kernel = vote_ballot_ir_kernel[(1, )](output, num_warps=1) + + # Verify the TTGIR contains the vote_ballot_sync op + ttgir = kernel.asm["ttgir"] + assert "vote_ballot_sync" in ttgir, "Expected vote_ballot_sync in TTGIR" + + # Verify the LLVM IR contains the NVVM vote instruction + llir = kernel.asm["llir"] + assert "nvvm.vote.ballot.sync" in llir or "vote.sync.ballot" in llir, ( + "Expected nvvm.vote.ballot.sync or vote.sync.ballot in LLVM IR") diff --git a/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/test/Conversion/tritonnvidiagpu_to_llvm.mlir index a3a7b20f2a..e8f37e839d 100644 --- a/test/Conversion/tritonnvidiagpu_to_llvm.mlir +++ b/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -13,6 +13,99 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // ----- +// Test that tensor select with 
warp-uniform condition (from vote_ballot via splat) +// is converted to branches instead of per-element select instructions. +// This is the pattern used in Flash Attention for conditional rescaling. +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: uniform_tensor_select_to_branch + // CHECK: nvvm.vote.sync ballot + // CHECK: llvm.icmp "ne" + // CHECK: llvm.cond_br + // CHECK: llvm.br + // CHECK: llvm.br + tt.func @uniform_tensor_select_to_branch(%mask: i32, %pred: i1, %true_val: tensor<16x32xf32, #blocked>, %false_val: tensor<16x32xf32, #blocked>, %ptr: !tt.ptr) { + // Get warp-uniform ballot result + %ballot = ttng.vote_ballot_sync %mask, %pred : i1 -> i32 + %c0 = arith.constant 0 : i32 + // Compare ballot result (scalar i32) - this is warp-uniform + %scalar_cond = arith.cmpi ne, %ballot, %c0 : i32 + // Splat scalar condition to tensor shape to match tensor operands + %cond = tt.splat %scalar_cond : i1 -> tensor<16x32xi1, #blocked> + // Select with uniform tensor condition - should become branches + %result = arith.select %cond, %true_val, %false_val : tensor<16x32xi1, #blocked>, tensor<16x32xf32, #blocked> + // Store result (kernels can't return values) + %ptrs = tt.splat %ptr : !tt.ptr -> tensor<16x32x!tt.ptr, #blocked> + tt.store %ptrs, %result : tensor<16x32x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// Test the full Flash Attention pattern: tensor predicate -> vote_ballot -> tensor condition -> select +// This matches the actual FA kernel pattern where pred = alpha_1 < 1.0 is a tensor. 
+#blocked1d = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: uniform_tensor_select_tensor_pred + // CHECK: nvvm.vote.sync ballot + // CHECK: llvm.icmp "ne" + // CHECK: llvm.cond_br + // CHECK: llvm.br + // CHECK: llvm.br + tt.func @uniform_tensor_select_tensor_pred(%mask: i32, %alpha: tensor<128xf32, #blocked1d>, %acc: tensor<128xf32, #blocked1d>, %scaled_acc: tensor<128xf32, #blocked1d>, %ptr: !tt.ptr) { + // pred = alpha < 1.0 - this is a tensor predicate + %c1 = arith.constant dense<1.0> : tensor<128xf32, #blocked1d> + %pred = arith.cmpf olt, %alpha, %c1 : tensor<128xf32, #blocked1d> + // ballot_result is a tensor with the same shape, all elements contain warp ballot + %ballot = ttng.vote_ballot_sync %mask, %pred : tensor<128xi1, #blocked1d> -> tensor<128xi32, #blocked1d> + // should_rescale = ballot_result != 0 + %c0 = arith.constant dense<0> : tensor<128xi32, #blocked1d> + %should_rescale = arith.cmpi ne, %ballot, %c0 : tensor<128xi32, #blocked1d> + // Conditional select - condition is uniform since ballot result is same for all threads in warp + %result = arith.select %should_rescale, %scaled_acc, %acc : tensor<128xi1, #blocked1d>, tensor<128xf32, #blocked1d> + // Store result (kernels can't return values) + %ptrs = tt.splat %ptr : !tt.ptr -> tensor<128x!tt.ptr, #blocked1d> + tt.store %ptrs, %result : tensor<128x!tt.ptr, #blocked1d> + tt.return + } +} + +// ----- + +// Test 2D Flash Attention pattern: alpha is 128x1 (broadcast dim), acc/scaled_acc are 128x64 +// This tests the broadcast scenario where alpha has a singleton dimension. 
+#blocked2d = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2d_alpha = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: uniform_tensor_select_2d_broadcast + // CHECK: nvvm.vote.sync ballot + // CHECK: llvm.icmp "ne" + // CHECK: llvm.cond_br + // CHECK: llvm.br + // CHECK: llvm.br + tt.func @uniform_tensor_select_2d_broadcast(%mask: i32, %alpha: tensor<128x1xf32, #blocked2d_alpha>, %acc: tensor<128x64xf32, #blocked2d>, %scaled_acc: tensor<128x64xf32, #blocked2d>, %ptr: !tt.ptr) { + // pred = alpha < 1.0 - alpha is 128x1, will be broadcast + %c1 = arith.constant dense<1.0> : tensor<128x1xf32, #blocked2d_alpha> + %pred = arith.cmpf olt, %alpha, %c1 : tensor<128x1xf32, #blocked2d_alpha> + // ballot_result has same shape as pred (128x1) + %ballot = ttng.vote_ballot_sync %mask, %pred : tensor<128x1xi1, #blocked2d_alpha> -> tensor<128x1xi32, #blocked2d_alpha> + // should_rescale = ballot_result != 0 (128x1) + %c0 = arith.constant dense<0> : tensor<128x1xi32, #blocked2d_alpha> + %cond_small = arith.cmpi ne, %ballot, %c0 : tensor<128x1xi32, #blocked2d_alpha> + // Broadcast condition from 128x1 to 128x64 to match acc/scaled_acc shape + %should_rescale = tt.broadcast %cond_small : tensor<128x1xi1, #blocked2d_alpha> -> tensor<128x64xi1, #blocked2d> + // Conditional select with broadcast condition + %result = arith.select %should_rescale, %scaled_acc, %acc : tensor<128x64xi1, #blocked2d>, tensor<128x64xf32, #blocked2d> + // Store result + %ptrs = tt.splat %ptr : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked2d> + tt.store %ptrs, %result : tensor<128x64x!tt.ptr, #blocked2d> + tt.return + } +} + +// ----- + #shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = 
[0]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { @@ -128,6 +221,39 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // ----- +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: vote_ballot_sync + // CHECK: nvvm.vote.sync ballot + tt.func @vote_ballot_sync(%mask: i32, %pred: i1) { + %result = ttng.vote_ballot_sync %mask, %pred : i1 -> i32 + tt.return + } +} + +// ----- + +// Test that scalar select with warp-uniform condition (from vote_ballot) is +// converted to branches instead of select instruction. +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: uniform_select_to_branch + // CHECK: nvvm.vote.sync ballot + // CHECK: llvm.icmp "ne" + // CHECK: llvm.cond_br + // CHECK: llvm.br + // CHECK: llvm.br + tt.func @uniform_select_to_branch(%mask: i32, %pred: i1, %true_val: i32, %false_val: i32, %ptr: !tt.ptr) { + %ballot = ttng.vote_ballot_sync %mask, %pred : i1 -> i32 + %c0 = arith.constant 0 : i32 + %cond = arith.cmpi ne, %ballot, %c0 : i32 + %result = arith.select %cond, %true_val, %false_val : i32 + // Store result (kernels can't return values) + tt.store %ptr, %result : !tt.ptr + tt.return + } +} + +// ----- + #shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #smem = #ttg.shared_memory diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp index c55ff241d9..3271fd5718 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp @@ -412,6 +412,68 @@ struct CLCQueryCancelOpConversion 
return success(); } };
+
+/// Lowers ttng.vote_ballot_sync to a single NVVM vote.sync (ballot) per thread.
+///
+/// - Scalar predicate: the converted i1 is handed to the vote instruction and
+///   the resulting i32 replaces the op.
+/// - Tensor predicate: the elements owned by the thread are OR-reduced into
+///   one i1, a single ballot is issued with that value, and the i32 ballot is
+///   replicated into every element of the result tensor.
+struct VoteBallotSyncOpConversion
+    : public ConvertOpToLLVMPattern<triton::nvidia_gpu::VoteBallotSyncOp> {
+  using ConvertOpToLLVMPattern<
+      triton::nvidia_gpu::VoteBallotSyncOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::nvidia_gpu::VoteBallotSyncOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    const bool predIsTensor = isa<RankedTensorType>(op.getPred().getType());
+
+    // The i1 this thread contributes to the ballot. For a scalar predicate
+    // this is the converted operand itself.
+    Value lanePred = adaptor.getPred();
+    size_t numElems = 0;
+
+    if (predIsTensor) {
+      // Unpack the elements owned by this thread and fold them with OR.
+      SmallVector<Value> predElems =
+          unpackLLElements(loc, adaptor.getPred(), rewriter);
+      numElems = predElems.size();
+
+      TritonLLVMOpBuilder b(loc, rewriter);
+      if (predElems.empty()) {
+        // A thread that owns no element contributes "false".
+        lanePred = b.i1_val(false);
+      } else {
+        lanePred = predElems.front();
+        for (Value elem : ArrayRef<Value>(predElems).drop_front())
+          lanePred = b.or_(lanePred, elem);
+      }
+    }
+
+    // One warp-level ballot, shared by the scalar and the tensor path.
+    Value ballot = rewriter.create<NVVM::VoteSyncOp>(
+        loc, rewriter.getI32Type(), adaptor.getMask(), lanePred,
+        NVVM::VoteSyncKind::ballot);
+
+    if (!predIsTensor) {
+      rewriter.replaceOp(op, ballot);
+      return success();
+    }
+
+    // Replicate the ballot into every element this thread owns and repack.
+    SmallVector<Value> resultElems(numElems, ballot);
+    Value packedResult = packLLElements(loc, getTypeConverter(), resultElems,
+                                        rewriter, op.getResult().getType());
+    rewriter.replaceOp(op, packedResult);
+    return success();
+  }
+};
} // namespace void mlir::triton::NVIDIA::populateBarrierOpToLLVMPatterns( @@ -427,4 +489,5 @@ void mlir::triton::NVIDIA::populateBarrierOpToLLVMPatterns( patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp index d32a1c5799..c64a0ab479 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -7,6 +7,7 @@ #include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" using namespace mlir::triton::gpu; @@ -927,6 +928,242 @@ struct OpToExternCallConversion private: StringRef funcName; }; + +/// Check if a value is warp-uniform (same 
across all threads in a warp).
+///
+/// For tensor values "uniform" additionally means that every element holds the
+/// same value: UniformSelectToBranchConversion replaces a whole tensor select
+/// by one branch on a single element of the condition, which is only sound
+/// when all elements agree.
+///
+/// A value is reported as uniform when it is produced by:
+///   1. vote_ballot_sync (scalar result, or a tensor whose elements all hold
+///      the ballot)
+///   2. a scalar constant or a splat tensor constant
+///   3. a splat / broadcast of a uniform value
+///   4. an integer compare, and/or/xor, or trunc/ext (arith or LLVM dialect)
+///      whose operands are all uniform
+///
+/// Block arguments and every other producer are conservatively treated as
+/// non-uniform. `depth` bounds how far the use-def chain is followed.
+
+/// Returns true if a constant attribute holds one value for all elements:
+/// scalar attributes and splat dense tensors qualify; any other elements
+/// attribute is rejected conservatively.
+static bool isUniformConstant(Attribute value) {
+  if (auto dense = dyn_cast<DenseElementsAttr>(value))
+    return dense.isSplat();
+  return !isa<ElementsAttr>(value);
+}
+
+static bool isWarpUniform(Value val, int depth = 0) {
+  // Bound the walk; binary ops fan out, so an unbounded walk could get
+  // expensive on long chains.
+  constexpr int kMaxDepth = 10;
+  if (depth > kMaxDepth)
+    return false;
+
+  Operation *defOp = val.getDefiningOp();
+  if (!defOp)
+    return false; // Block arguments need additional analysis
+
+  // vote_ballot_sync: all lanes receive the same result.
+  if (isa<triton::nvidia_gpu::VoteBallotSyncOp>(defOp))
+    return true;
+
+  // Constants: a non-splat dense tensor has differing elements and must not
+  // be collapsed to a single branch condition.
+  if (auto cst = dyn_cast<arith::ConstantOp>(defOp))
+    return isUniformConstant(cst.getValue());
+  if (auto cst = dyn_cast<LLVM::ConstantOp>(defOp))
+    return isUniformConstant(cst.getValue());
+
+  // Ops whose result is uniform exactly when every operand is uniform.
+  if (isa<triton::SplatOp, triton::BroadcastOp, arith::CmpIOp, LLVM::ICmpOp,
+          arith::AndIOp, arith::OrIOp, arith::XOrIOp, LLVM::AndOp, LLVM::OrOp,
+          LLVM::XOrOp, arith::TruncIOp, arith::ExtUIOp, arith::ExtSIOp,
+          LLVM::TruncOp, LLVM::ZExtOp, LLVM::SExtOp>(defOp)) {
+    return llvm::all_of(defOp->getOperands(), [depth](Value operand) {
+      return isWarpUniform(operand, depth + 1);
+    });
+  }
+
+  return false;
+}
+
+/// Try to extract a scalar i1 condition from a potentially tensor condition.
+/// For warp-uniform tensor conditions (e.g., splatted from vote_ballot result),
+/// all elements have the same value, so we can extract any element.
+/// Returns the scalar condition if extractable, otherwise returns nullptr.
+static Value extractScalarCondition(Value cond, Location loc,
+                                    ConversionPatternRewriter &rewriter,
+                                    const LLVMTypeConverter *typeConverter) {
+  // Only the use-def chain is inspected; no IR is created here.
+  (void)loc;
+  (void)rewriter;
+  (void)typeConverter;
+
+  // Peel tt.splat / tt.broadcast wrappers until a non-tensor value is reached.
+  Value current = cond;
+  while (isa<RankedTensorType>(current.getType())) {
+    Operation *producer = current.getDefiningOp();
+    if (!producer)
+      return nullptr; // tensor block argument: nothing to trace through
+
+    if (auto splat = dyn_cast<triton::SplatOp>(producer)) {
+      current = splat.getSrc();
+    } else if (auto bcast = dyn_cast<triton::BroadcastOp>(producer)) {
+      current = bcast.getSrc();
+    } else {
+      // Any other tensor producer: give up and let the caller decide how to
+      // obtain a scalar from the lowered condition.
+      return nullptr;
+    }
+  }
+
+  // NOTE(review): this is the original (pre-conversion) SSA value, not a
+  // value remapped by the rewriter - confirm that is what callers expect.
+  return current.getType().isInteger(1) ? current : Value();
+}
+
+/// Convert arith.SelectOp with warp-uniform condition to branches.
+/// When the condition is uniform (same for all threads in a warp), we can
+/// use branches instead of select instructions without causing warp divergence.
+/// This can be beneficial when the true/false values are expensive to compute
+/// (e.g., memory loads) since only one branch will be executed.
+/// +/// Supports both scalar and tensor selects: +/// - Scalar: Direct branch on the scalar condition +/// - Tensor: Extract scalar condition from uniform tensor, branch on it, +/// and yield entire tensor values in each branch +struct UniformSelectToBranchConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + // Check if condition is warp-uniform (no divergence) + if (!isWarpUniform(op.getCondition())) + return failure(); + + Value scalarCond = nullptr; + bool isTensorSelect = isa(op.getTrueValue().getType()); + + if (isTensorSelect) { + // For tensor selects, we need to extract a scalar condition. + // First, try to trace back through splat/broadcast ops to find the + // original scalar condition before it was splatted to a tensor. + scalarCond = extractScalarCondition(op.getCondition(), loc, rewriter, + getTypeConverter()); + + if (!scalarCond) { + // If we can't trace back to a scalar, the condition tensor has been + // lowered to an LLVM struct. Since it's uniform, all elements are + // the same, so extract the first element. 
+ Value condStruct = adaptor.getCondition(); + Type condStructTy = condStruct.getType(); + + if (auto structTy = dyn_cast(condStructTy)) { + // Extract the first element from the struct + scalarCond = rewriter.create( + loc, condStruct, ArrayRef{0}); + } else { + // Condition is already a scalar (e.g., scalar condition with tensor + // operands) + scalarCond = condStruct; + } + } + } else { + // Scalar select - condition is already scalar + scalarCond = adaptor.getCondition(); + } + + if (!scalarCond || !scalarCond.getType().isInteger(1)) + return failure(); + + // Create branch-based lowering: + // %result = select %cond, %true, %false + // Becomes: + // cond_br %cond, ^trueBB, ^falseBB + // ^trueBB: + // br ^mergeBB(%true) + // ^falseBB: + // br ^mergeBB(%false) + // ^mergeBB(%result): + + Block *currentBlock = op->getBlock(); + Block::iterator insertPoint = op->getIterator(); + ++insertPoint; // Move past the current op + + // Split the block after the select op + Block *mergeBlock = rewriter.splitBlock(currentBlock, insertPoint); + + // Get the converted result type (for tensors, this is LLVM struct type) + Type resultTy = getTypeConverter()->convertType(op.getResult().getType()); + + // Add block argument for the result + mergeBlock->addArgument(resultTy, loc); + + // Create true block - yields the true value (entire tensor if tensor + // select) + Block *trueBlock = rewriter.createBlock(mergeBlock); + rewriter.create(loc, adaptor.getTrueValue(), mergeBlock); + + // Create false block - yields the false value (entire tensor if tensor + // select) + Block *falseBlock = rewriter.createBlock(mergeBlock); + rewriter.create(loc, adaptor.getFalseValue(), mergeBlock); + + // Insert conditional branch at end of current block + rewriter.setInsertionPointToEnd(currentBlock); + rewriter.create(loc, scalarCond, trueBlock, falseBlock); + + // The merge block argument has the converted LLVM type. 
Replace the select + // with this value - the conversion framework will handle type + // reconciliation for any users through the type converter. + rewriter.replaceOp(op, mergeBlock->getArgument(0)); + + return success(); + } +}; + } // namespace } // namespace gpu @@ -946,6 +1183,15 @@ void mlir::triton::NVIDIA::populateElementwiseOpToLLVMPatterns( mlir::triton::populateElementwiseOpToLLVMPatterns( typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit); + // Register UniformSelectToBranchConversion with higher priority. + // This pattern converts scalar select ops with warp-uniform conditions + // (e.g., derived from vote_ballot_sync) to branches. Since the condition + // is uniform across all threads in a warp, no divergence occurs. + // Higher benefit ensures this pattern is tried before the default select + // conversion when the condition is detected as warp-uniform. + patterns.add( + typeConverter, PatternBenefit(benefit.getBenefit() + 10)); + #define POPULATE_OP(SRC_OP, DST_OP) \ patterns.add>( \ typeConverter, axisInfoAnalysis, benefit) diff --git a/third_party/tlx/dialect/triton_tlx.cc b/third_party/tlx/dialect/triton_tlx.cc index 068721464f..861addc71c 100644 --- a/third_party/tlx/dialect/triton_tlx.cc +++ b/third_party/tlx/dialect/triton_tlx.cc @@ -583,6 +583,26 @@ void init_triton_tlx_ir(py::module &&m) { self.create(isNegOne, tileId, offset); return tileId; }) + .def("vote_ballot_sync", + [](TritonOpBuilder &self, Value mask, Value pred) -> Value { + auto &builder = self.getBuilder(); + Type predType = pred.getType(); + + // Determine result type based on predicate type + Type resultType; + if (auto tensorType = dyn_cast(predType)) { + // For tensor input, return tensor of i32 with same + // shape/encoding + resultType = RankedTensorType::get(tensorType.getShape(), + builder.getI32Type(), + tensorType.getEncoding()); + } else { + // Scalar input -> scalar i32 result + resultType = builder.getI32Type(); + } + + return self.create(resultType, 
mask, pred); + }) .def("create_async_TMA_load", [](TritonOpBuilder &self, std::vector &multicastTargets, Value desc, std::vector &coord, Value mbarrier, Value pred, diff --git a/third_party/tlx/language/tlx/__init__.py b/third_party/tlx/language/tlx/__init__.py index 62d0b5775f..8abc8be6ef 100644 --- a/third_party/tlx/language/tlx/__init__.py +++ b/third_party/tlx/language/tlx/__init__.py @@ -74,6 +74,8 @@ stoch_round, thread_id, ) +from .warp_ops import ( + vote_ballot_sync, ) __all__ = [ # async_tasks @@ -152,4 +154,6 @@ "clc_consumer", "CLCPipelineContext", "DummyRegisterLayoutEncoding", + # warp_ops + "vote_ballot_sync", ] diff --git a/third_party/tlx/language/tlx/warp_ops.py b/third_party/tlx/language/tlx/warp_ops.py new file mode 100644 index 0000000000..930a49be0b --- /dev/null +++ b/third_party/tlx/language/tlx/warp_ops.py @@ -0,0 +1,73 @@ +""" +TLX Warp-Level Operations + +This module provides warp-level synchronization and voting primitives +for NVIDIA GPUs. +""" + +import triton.language.core as tl + + +@tl.builtin +def vote_ballot_sync( + mask: tl.constexpr, + pred: tl.tensor, + _semantic=None, +) -> tl.tensor: + """ + Perform a warp-level vote ballot operation. + + Collects a predicate from each thread in the warp and returns a 32-bit + mask where each bit represents the predicate value from the corresponding + lane. Only threads specified by `mask` participate in the vote. + + Args: + mask: A 32-bit mask specifying which threads participate. Threads with + their corresponding bit set in the mask must execute with the + same mask value. Use 0xFFFFFFFF for all threads. + pred: A boolean predicate. Can be either a scalar i1 or a tensor of i1 + + Returns: + If pred is scalar: A 32-bit integer where bit N is set if thread N's + predicate was true and thread N is in the mask. + If pred is tensor: A tensor of i32 with the same shape, where each + element contains the warp's ballot result. 
+ + Example: + # Scalar predicate - check if any thread has a non-zero value + ballot = tlx.vote_ballot_sync(0xFFFFFFFF, x != 0) + + # Tensor predicate - it will be distributed to warps/threads according to layout + pred_tensor = values < threshold # tensor<128x1xi1> + ballot = tlx.vote_ballot_sync(0xFFFFFFFF, pred_tensor) # tensor<128x1xi32> + + PTX instruction generated: + vote.sync.ballot.b32 dest, predicate, membermask; + + Note: + - All threads in mask must execute the instruction with identical mask + - The sync variant ensures warp convergence before the vote + """ + # Ensure pred is i1/bool type + if pred.dtype != tl.int1: + pred = pred != 0 + + # Get mask as i32 value + if isinstance(mask, tl.constexpr): + mask_val = mask.value + else: + mask_val = mask + + mask_handle = _semantic.builder.get_int32(mask_val) + result = _semantic.builder.vote_ballot_sync(mask_handle, pred.handle) + + # Determine result type based on predicate type + # If pred is a tensor, result will be tensor of i32 with same shape + if pred.type.is_block(): + # Tensor case - create block_type with same shape but i32 element type + shape = [s.value if hasattr(s, "value") else s for s in pred.shape] + ret_ty = tl.block_type(tl.int32, shape) + return _semantic.tensor(result, ret_ty) + else: + # Scalar case + return _semantic.tensor(result, tl.int32) diff --git a/third_party/tlx/tutorials/blackwell-fa-ws-pipelined-persistent_test.py b/third_party/tlx/tutorials/blackwell-fa-ws-pipelined-persistent_test.py index 64b63e04da..c41e41180d 100644 --- a/third_party/tlx/tutorials/blackwell-fa-ws-pipelined-persistent_test.py +++ b/third_party/tlx/tutorials/blackwell-fa-ws-pipelined-persistent_test.py @@ -29,61 +29,21 @@ def _host_descriptor_pre_hook(nargs): "BLOCK_M": 256, "BLOCK_N": 128, "NUM_BUFFERS_Q": 1, - "NUM_BUFFERS_KV": 3, + "NUM_BUFFERS_KV": kv, "NUM_BUFFERS_QK": 1, "NUM_MMA_GROUPS": 2, "NUM_MMA_SLICES": 2, - "GROUP_SIZE_N": 1, + "GROUP_SIZE_N": grp_n, + "RESCALE_OPT": rescale_opt, + 
"USE_WHERE": where, # used when RESCALE_OPT is True }, num_stages=0, num_warps=4, pre_hook=_host_descriptor_pre_hook, - ), - triton.Config( - { - "BLOCK_M": 256, - "BLOCK_N": 128, - "NUM_BUFFERS_Q": 1, - "NUM_BUFFERS_KV": 3, - "NUM_BUFFERS_QK": 1, - "NUM_MMA_GROUPS": 2, - "NUM_MMA_SLICES": 2, - "GROUP_SIZE_N": 4, - }, - num_stages=0, - num_warps=4, - pre_hook=_host_descriptor_pre_hook, - ), - triton.Config( - { - "BLOCK_M": 256, - "BLOCK_N": 128, - "NUM_BUFFERS_Q": 1, - "NUM_BUFFERS_KV": 6, - "NUM_BUFFERS_QK": 1, - "NUM_MMA_GROUPS": 2, - "NUM_MMA_SLICES": 2, - "GROUP_SIZE_N": 1, - }, - num_stages=0, - num_warps=4, - pre_hook=_host_descriptor_pre_hook, - ), - triton.Config( - { - "BLOCK_M": 256, - "BLOCK_N": 128, - "NUM_BUFFERS_Q": 1, - "NUM_BUFFERS_KV": 6, - "NUM_BUFFERS_QK": 1, - "NUM_MMA_GROUPS": 2, - "NUM_MMA_SLICES": 2, - "GROUP_SIZE_N": 4, - }, - num_stages=0, - num_warps=4, - pre_hook=_host_descriptor_pre_hook, - ), + ) + for kv in [3, 6] + for grp_n in [1, 4] + for (rescale_opt, where) in [(False, False), (True, False), (True, True)] ] @@ -105,6 +65,11 @@ def _get_bufidx_phase(accum_cnt, NUM_BUFFERS_KV): return bufIdx, phase +@triton.jit +def _reduce_or(x, y): + return x | y + + @triton.jit def _mul_f32x2(a, b): return tl.inline_asm_elementwise( @@ -262,6 +227,126 @@ def _apply_causal_mask(qk, col_limit_right, BLOCK_N: tl.constexpr): return tl.map_elementwise(_mask_scalar, qk, col_limit_right, s, i) +# Original add_round_down(x, y) did: add.rm.ftz.f32 (round-down) +# We only need it to produce floor(x) when paired with the big-constant trick. +# In Triton, just compute floor(x) explicitly and add y. 
+@triton.jit +def add_round_down(a, b): + return tl.inline_asm_elementwise( + """ + { + .reg .b64 ra, rb, rc; + mov.b64 ra, { $2, $3 }; + mov.b64 rb, { $4, $5 }; + add.rm.ftz.f32x2 rc, ra, rb; + mov.b64 { $0, $1 }, rc; + } + """, + "=r,=r,r,r,r,r", + [a, b], + dtype=tl.float32, + is_pure=True, + pack=2, + ) + + +# ============================================================================ +# Custom exp2 Polynomial Approximation (the "emulation" path) +# ============================================================================ + + +@triton.jit +def _fma_f32x2(a, b, c): + return tl.inline_asm_elementwise( + """ + { + .reg .b64 ra, rb, rc, rd; + mov.b64 ra, { $2, $3 }; + mov.b64 rb, { $4, $5 }; + mov.b64 rc, { $6, $7 }; + fma.rn.f32x2 rd, ra, rb, rc; + mov.b64 { $0, $1 }, rd; + } + """, + "=r,=r,r,r,r,r,r,r", + [a, b, c], + dtype=tl.float32, + is_pure=True, + pack=2, + ) + + +@triton.jit +def evaluate_polynomial(x, coeffs: tl.constexpr): + """ + Returns P(x) with scalar coeffs using Horner's rule and inline PTX FMA. + P(t) = sum_{i=0..deg} coeffs[i] * t^i + """ + deg = len(coeffs) - 1 + out = coeffs[deg] + for i in range(deg - 1, -1, -1): + out = _fma_f32x2(out, x, coeffs[i]) # out = out*x + coeffs[i] + return out + + +@triton.jit +def combine_int_frac_ex2(x_rounded, frac_ex2): + """ + Compose: out_bits = (x_rounded_bits << 23) + frac_ex2_bits + where: + - x_rounded carries the integer part (floor) in its low bits + - frac_ex2 carries the mantissa approximation of 2**fraction(x) + Returns fp32 with those bits. 
+ """ + out = tl.inline_asm_elementwise( + asm=""" + { + .reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i; + mov.b32 x_rounded_i, $1; + mov.b32 frac_ex_i, $2; + shl.b32 x_rounded_e, x_rounded_i, 23; + add.s32 out_i, x_rounded_e, frac_ex_i; + mov.b32 $0, out_i; + } + """, + constraints="=f, f, f", # $0: out(float), $1: x_rounded(float), $2: frac_ex2(float) + args=[x_rounded, frac_ex2], + dtype=tl.float32, + is_pure=True, + pack=1, + ) + return out + + +@triton.jit +def ex2_emulation(x): + # We assume x <= 127.0 + x = tl.maximum(x, -127.0) + + FP32_ROUND_INT = (2.0**23) + (2.0**22) # same constant you used + + # Emulate your “round down fractional” split explicitly + x_rounded = add_round_down(x, FP32_ROUND_INT) + x_rounded_back = x_rounded - FP32_ROUND_INT + x_frac = x - x_rounded_back # r in [0,1) + + # Your degree-3 approximation coefficients for 2**r, r in [0,1) + + # ---- inline polynomial for 2**r (degree-3, Horner) ---- + # Coeffs: (C0, C1, C2, C3) + C0 = 1.0 + C1 = 0.695146143436431884765625 + C2 = 0.227564394474029541015625 + C3 = 0.077119089663028717041015625 + + x_frac_ex2 = C3 + x_frac_ex2 = _fma_f32x2(x_frac_ex2, x_frac, C2) # t = C3*r + C2 + x_frac_ex2 = _fma_f32x2(x_frac_ex2, x_frac, C1) # t = (..)*r + C1 + x_frac_ex2 = _fma_f32x2(x_frac_ex2, x_frac, C0) # t = (..)*r + C0 ~ 2**r + + return combine_int_frac_ex2(x_rounded, x_frac_ex2) # 2**n * 2**r + + @triton.jit def _softmax_inner_loop( qk_fulls, @@ -287,6 +372,7 @@ def _softmax_inner_loop( NUM_MMA_GROUPS: tl.constexpr, STAGE: tl.constexpr, P_PADDING: tl.constexpr, + RESCALE_OPT: tl.constexpr, ): lo, hi = _get_unfused_loop_bounds(start_m, N_CTX, BLOCK_M, STAGE) @@ -300,16 +386,47 @@ def _softmax_inner_loop( qk = _apply_causal_mask(qk, col_limit_right, BLOCK_N) # compute m_i, p in registers - m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) + # update_row_max: row_max_new = _compute_row_max(qk, row_max[0]) + # -> FA4 handles one row per thread (32 threads per warp * 4) + # -> use fmax_reduce(one row 
of qk, m_i[0]) + # -> m_i|m_ij = row_max[0] * scale + if RESCALE_OPT: + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + else: + m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) # -- compute correction factor - alpha = tl.math.exp2(m_i - m_ij) + # update_row_max: acc_scale_ = (row_max[0] - row_max_new) * scale + # -> acc_scale = exp2(acc_scale_) + # -> if (acc_scale_ >= -8.0): + # -> row_max_new = row_max[0]; acc_scale = 1.0 + # -> row_max[0] = row_max_new + if RESCALE_OPT: + alpha_ = (m_i - m_ij) * qk_scale # alpha_ is 1D distributed over the warp group + alpha = tl.math.exp2(alpha_) + rescale_mask = alpha_ >= -8.0 + alpha = tl.where(rescale_mask, 1.0, alpha) + m_ij = tl.where(rescale_mask, m_i, m_ij) + else: + alpha = tl.math.exp2(m_i - m_ij) tlx.barrier_wait(tlx.local_view(alpha_empties, cid), qk_phase ^ 1) # Use alpha[0] for cid=0, and alpha[BLOCK_N] for cid=1 tlx.local_store(tlx.local_view(alpha_tiles, cid * BLOCK_N), alpha[:, None]) tlx.barrier_arrive(tlx.local_view(alpha_fulls, cid)) - qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None]) + # scale_subtract_rowmax: + # -> row_max_scaled = row_max_new * scale + # -> s[i], s[i+1] = fma_packed_f32x2((s[i], s[i+1]), (scale, scale), (-row_max_scaled, -row_max_scaled)) + if RESCALE_OPT: + m_scaled = m_ij * qk_scale + qk = _fma_f32x2(qk, qk_scale, -m_scaled[:, None]) + else: + qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None]) + # apply_exp2_convert in FA4: + # 128 elements per row is divided into 4 fragments, first fragment covers [0] to [31] + # for last fragment, always use SFU, for first 3 fragments, elements 0 to 11 use SFU, + # elements 12 to 15 use emulation, elements 16 to 27 use SFU, elements 28 to 31 use emulation + # the loop is unrolled twice likely for vectorization qks = _split_n(qk, NUM_MMA_SLICES) ps = () for slice_id in tl.static_range(0, NUM_MMA_SLICES): @@ -351,6 +468,8 @@ def _attn_fwd_ws(sm_scale, M, # NUM_MMA_GROUPS: tl.constexpr, # NUM_MMA_SLICES: tl.constexpr, # GROUP_SIZE_N: tl.constexpr, # + RESCALE_OPT: 
tl.constexpr, # + USE_WHERE: tl.constexpr, # ): tl.static_assert(NUM_MMA_GROUPS == 2) tl.static_assert(NUM_BUFFERS_QK == 1) @@ -481,16 +600,53 @@ def _attn_fwd_ws(sm_scale, M, # # Use alpha[0] for cid=0, and alpha[BLOCK_N] for cid=1 alpha_1 = tlx.local_load(alpha_tiles[cid * BLOCK_N]) tlx.barrier_arrive(alpha_empties[cid]) - for slice_id in tl.static_range(0, NUM_MMA_SLICES): - subslice = tlx.subslice( - acc_tiles[cid], - HEAD_DIM * slice_id // NUM_MMA_SLICES, - HEAD_DIM // NUM_MMA_SLICES, - ) - acc = tlx.local_load(subslice) - # acc = acc * alpha_1 - acc = _mul_f32x2(acc, alpha_1) - tlx.local_store(subslice, acc) + # Perform warp-level ballot vote to check if any thread needs rescaling + # 0xFFFFFFFF means all 32 threads in the warp participate + if RESCALE_OPT: + pred = alpha_1 < 1.0 + # ballot_result is a tensor with the same shape as pred + # All elements contain the same warp-level ballot value + # Non-zero means at least one thread has alpha_1 < 1.0 + ballot_result = tlx.vote_ballot_sync(0xFFFFFFFF, pred) + should_rescale = ballot_result != 0 + + # FA4: each thread handles one row, 128 elements + # 128 threads handle 128 rows + # each thread breaks one row into 8 fragments, each fragment 16 elements, unrolls by 2 + # TLX: with NUM_MMA_SLICES of 2, we handle 128x64, then another 128x64 + # Since Triton doesn't support ifOp on a tensor value, we try to combine the values + # option 1: use tl.where + if USE_WHERE: + for slice_id in tl.static_range(0, NUM_MMA_SLICES): + subslice = tlx.subslice( + acc_tiles[cid], + HEAD_DIM * slice_id // NUM_MMA_SLICES, + HEAD_DIM // NUM_MMA_SLICES, + ) + acc = tlx.local_load(subslice) + # Use tl.where to conditionally apply rescaling + # acc = acc * alpha_1 where should_rescale, else acc unchanged + if RESCALE_OPT: + scaled_acc = _mul_f32x2(acc, alpha_1) + acc = tl.where(should_rescale, scaled_acc, acc) + else: + acc = _mul_f32x2(acc, alpha_1) + tlx.local_store(subslice, acc) + else: + # option 2: use a single scalar IfOp + if 
RESCALE_OPT: + should_rescale_red = tl.reduce(should_rescale, axis=0, combine_fn=_reduce_or) + should_rescale_scalar = tl.reshape(should_rescale_red, ()) + if not RESCALE_OPT or (RESCALE_OPT and should_rescale_scalar): + for slice_id in tl.static_range(0, NUM_MMA_SLICES): + subslice = tlx.subslice( + acc_tiles[cid], + HEAD_DIM * slice_id // NUM_MMA_SLICES, + HEAD_DIM // NUM_MMA_SLICES, + ) + acc = tlx.local_load(subslice) + acc = _mul_f32x2(acc, alpha_1) + tlx.local_store(subslice, acc) tlx.barrier_arrive(acc_fulls[cid]) accum_cnt += 1 @@ -506,7 +662,7 @@ def _attn_fwd_ws(sm_scale, M, # # since both tiles share the same synchronization group. tlx.barrier_arrive(qk_empties[cid]) m += tl.math.log2(l) - offs_m = (start_m * BLOCK_M + cid * BLOCK_M_SPLIT + tl.arange(0, BLOCK_M_SPLIT)) + offs_m = start_m * BLOCK_M + cid * BLOCK_M_SPLIT + tl.arange(0, BLOCK_M_SPLIT) m_ptrs = M + off_hz * N_CTX + offs_m tl.store(m_ptrs, tl.reshape(m, [BLOCK_M_SPLIT])) @@ -549,6 +705,8 @@ def _attn_fwd_ws(sm_scale, M, # ) # initialize pointer to m and l m_i = tl.zeros([BLOCK_M_SPLIT], dtype=tl.float32) - float("inf") + # FA4 update_row_sum has init_val being None for the first iteration, here + # we use initial value of 1.0 l_i = tl.zeros([BLOCK_M_SPLIT], dtype=tl.float32) + 1.0 acc = tl.zeros([BLOCK_M_SPLIT, HEAD_DIM], dtype=tl.float32) qk_scale = sm_scale @@ -582,6 +740,7 @@ def _attn_fwd_ws(sm_scale, M, # NUM_MMA_GROUPS, STAGE=4 - STAGE, P_PADDING=P_PADDING, + RESCALE_OPT=RESCALE_OPT, ) if STAGE & 2: @@ -609,6 +768,7 @@ def _attn_fwd_ws(sm_scale, M, # NUM_MMA_GROUPS, STAGE=2, P_PADDING=P_PADDING, + RESCALE_OPT=RESCALE_OPT, ) # prepare l_i for the epilog @@ -729,7 +889,7 @@ def _attn_fwd_ws(sm_scale, M, # ) p_bufIdx = (NUM_MMA_GROUPS * NUM_MMA_SLICES + NUM_MMA_SLICES) * P_PADDING + slice_id use_acc = acc1_init if slice_id == 0 else True - mBarriers = ([kv_empties[v_bufIdx_prev]] if slice_id == NUM_MMA_SLICES - 1 else []) + mBarriers = [kv_empties[v_bufIdx_prev]] if slice_id == 
NUM_MMA_SLICES - 1 else [] tlx.async_dot( p_tiles[p_bufIdx], kv_slice, @@ -787,7 +947,7 @@ def _attn_fwd_ws(sm_scale, M, # ) p_bufIdx = (NUM_MMA_GROUPS * NUM_MMA_SLICES + NUM_MMA_SLICES) * P_PADDING + slice_id use_acc = acc1_init if slice_id == 0 else True - mBarriers = ([acc_empties[1], kv_empties[v_bufIdx]] if slice_id == NUM_MMA_SLICES - 1 else []) + mBarriers = [acc_empties[1], kv_empties[v_bufIdx]] if slice_id == NUM_MMA_SLICES - 1 else [] tlx.async_dot( p_tiles[p_bufIdx], kv_slice, @@ -1876,9 +2036,9 @@ def test_op( dtype, ): torch.manual_seed(20) - q = (torch.empty((Z, H, N_CTX, HEAD_DIM), device=DEVICE).normal_(mean=0.0, std=0.5).to(dtype).requires_grad_()) - k = (torch.empty((Z, H, N_CTX, HEAD_DIM), device=DEVICE).normal_(mean=0.0, std=0.5).to(dtype).requires_grad_()) - v = (torch.empty((Z, H, N_CTX, HEAD_DIM), device=DEVICE).normal_(mean=0.0, std=0.5).to(dtype).requires_grad_()) + q = torch.empty((Z, H, N_CTX, HEAD_DIM), device=DEVICE).normal_(mean=0.0, std=0.5).to(dtype).requires_grad_() + k = torch.empty((Z, H, N_CTX, HEAD_DIM), device=DEVICE).normal_(mean=0.0, std=0.5).to(dtype).requires_grad_() + v = torch.empty((Z, H, N_CTX, HEAD_DIM), device=DEVICE).normal_(mean=0.0, std=0.5).to(dtype).requires_grad_() sm_scale = 0.5 # reference implementation if dtype == torch.float8_e5m2: @@ -1944,14 +2104,14 @@ def test_op( print(N_HEADS) # vary seq length for fixed head and batch=4 configs = [] -for mode in ["fwd", "bwd"]: - for causal in [False, True]: - for BWD_BLOCK_M1 in [64, 128]: +for mode in ["fwd"]: # , "bwd"]: + for causal in [False]: # , True]: + for BWD_BLOCK_M1 in [64]: # , 128]: for GROUP_SIZE_M in [8]: configs.append( triton.testing.Benchmark( x_names=["N_CTX"], - x_vals=[2**i for i in range(10, 15)], + x_vals=[2**i for i in range(13, 14)], line_arg="provider", line_vals=["triton-fp16"] + (["flash"] if HAS_FLASH else []), line_names=["Triton [FP16]"] + (["Flash-2"] if HAS_FLASH else []), diff --git 
a/third_party/tlx/tutorials/vote_ballot_select_test.py b/third_party/tlx/tutorials/vote_ballot_select_test.py new file mode 100644 index 0000000000..465b336efe --- /dev/null +++ b/third_party/tlx/tutorials/vote_ballot_select_test.py @@ -0,0 +1,307 @@ +""" +Test case extracted from blackwell-fa-ws-pipelined-persistent_test.py +demonstrating vote_ballot_sync usage to guard conditional computation. + +This pattern is used in Flash Attention for conditional rescaling: +- Use vote_ballot_sync to check if ANY thread in the warp needs rescaling +- If the ballot result is zero (no thread needs rescaling), skip the computation +- This avoids computing the scaled value when all threads have alpha >= 1.0 +""" + +import torch +import triton +import triton.language as tl + +try: + import triton.language.extra.tlx as tlx + + HAS_TLX = True +except ImportError: + HAS_TLX = False + + +@triton.jit +def _mul_f32x2(a, b): + """Multiply two f32 values element-wise (simulating f32x2 packed ops).""" + return a * b + + +@triton.jit +def vote_ballot_select_kernel( + # Pointers + acc_ptr, + alpha_ptr, + out_ptr, + # Shape + M: tl.constexpr, + N: tl.constexpr, + # Strides + stride_acc_m, + stride_acc_n, + stride_alpha_m, + stride_out_m, + stride_out_n, + # Config + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + RESCALE_OPT: tl.constexpr, +): + """ + Kernel demonstrating vote_ballot_sync pattern for conditional rescaling. + + For each block: + 1. Load alpha values (shape: BLOCK_M x 1) + 2. Use vote_ballot_sync to check if any alpha < 1.0 + 3. Conditionally scale acc values based on ballot result + + This pattern enables branch optimization when all threads in a warp + have alpha >= 1.0 (ballot result is 0), skipping the multiplication. 
+ """ + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + # Compute offsets + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # Create pointers + acc_ptrs = acc_ptr + offs_m[:, None] * stride_acc_m + offs_n[None, :] * stride_acc_n + alpha_ptrs = alpha_ptr + offs_m * stride_alpha_m + out_ptrs = out_ptr + offs_m[:, None] * stride_out_m + offs_n[None, :] * stride_out_n + + # Masks for boundary checking + mask_m = offs_m < M + mask_n = offs_n < N + mask = mask_m[:, None] & mask_n[None, :] + + # Load accumulator (BLOCK_M x BLOCK_N) + acc = tl.load(acc_ptrs, mask=mask, other=0.0) + + # Load alpha (BLOCK_M x 1) - broadcast to BLOCK_M x BLOCK_N for operations + alpha_1 = tl.load(alpha_ptrs, mask=mask_m, other=1.0)[:, None] + + if RESCALE_OPT: + # Key pattern: Use vote_ballot_sync to check if ANY thread needs rescaling + # + # pred: tensor - True where alpha < 1.0 + # ballot_result: tensor - Warp ballot result + # - All elements contain the same warp-level ballot value + # - Non-zero means at least one thread has alpha_1 < 1.0 + # should_rescale: tensor - True if any rescaling needed + pred = alpha_1 < 1.0 + ballot_result = tlx.vote_ballot_sync(0xFFFFFFFF, pred) + should_rescale = ballot_result != 0 + + # Conditional scaling using tl.where + # When should_rescale is False (ballot_result == 0), skip multiplication + scaled_acc = _mul_f32x2(acc, alpha_1) + acc = tl.where(should_rescale, scaled_acc, acc) + else: + # Always rescale when optimization is disabled + acc = _mul_f32x2(acc, alpha_1) + + # Store result + tl.store(out_ptrs, acc, mask=mask) + + +@triton.jit +def vote_ballot_select_tmem_kernel( + # Pointers + acc_ptr, + alpha_ptr, + out_ptr, + # Shape + M: tl.constexpr, + N: tl.constexpr, + # Strides + stride_acc_m, + stride_acc_n, + stride_alpha_m, + stride_out_m, + stride_out_n, + # Config + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + RESCALE_OPT: tl.constexpr, +): + """ + Extended kernel demonstrating 
vote_ballot_sync with TMEM operations. + + This kernel simulates the pattern from Flash Attention where: + 1. Data is loaded from global memory to TMEM (tensor memory) + 2. vote_ballot_sync determines if rescaling is needed + 3. Conditionally perform tmem_load -> compute -> tmem_store + + The goal is to convert the tl.where into an if-branch that guards + the entire tmem_load + computation + tmem_store sequence. + """ + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + # Compute offsets + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # Create pointers + acc_ptrs = acc_ptr + offs_m[:, None] * stride_acc_m + offs_n[None, :] * stride_acc_n + alpha_ptrs = alpha_ptr + offs_m * stride_alpha_m + out_ptrs = out_ptr + offs_m[:, None] * stride_out_m + offs_n[None, :] * stride_out_n + + # Masks for boundary checking + mask_m = offs_m < M + mask_n = offs_n < N + mask = mask_m[:, None] & mask_n[None, :] + + # Load accumulator (BLOCK_M x BLOCK_N) + acc = tl.load(acc_ptrs, mask=mask, other=0.0) + + # Load alpha (BLOCK_M x 1) + alpha_1 = tl.load(alpha_ptrs, mask=mask_m, other=1.0)[:, None] + + if RESCALE_OPT: + # Pattern from FA: vote_ballot to check rescaling need + pred = alpha_1 < 1.0 + ballot_result = tlx.vote_ballot_sync(0xFFFFFFFF, pred) + should_rescale = ballot_result != 0 + + # DESIRED OPTIMIZATION: + # Convert this tl.where into an if-branch at LLVM level: + # + # Current lowering: + # scaled_acc = acc * alpha_1 // Always computed + # result = select(should_rescale, scaled_acc, acc) + # + # Desired lowering (when should_rescale is warp-uniform): + # if (any_thread_in_warp(should_rescale)) { + # result = acc * alpha_1 + # } else { + # result = acc + # } + # + # Benefits: + # - When all alpha >= 1.0 in a warp, skip multiplication entirely + # - No warp divergence since ballot result is uniform across warp + + scaled_acc = _mul_f32x2(acc, alpha_1) + acc = tl.where(should_rescale, scaled_acc, acc) + else: + acc 
= _mul_f32x2(acc, alpha_1) + + # Store result + tl.store(out_ptrs, acc, mask=mask) + + +def test_vote_ballot_select(): + """Test the vote_ballot_sync + select pattern.""" + if not HAS_TLX: + print("SKIP: tlx not available") + return + + torch.manual_seed(42) + + # Test dimensions + M, N = 128, 64 + BLOCK_M, BLOCK_N = 32, 32 + + # Create test data + acc = torch.randn(M, N, dtype=torch.float32, device="cuda") + # Mix of alpha values: some < 1.0 (need rescaling), some >= 1.0 (no rescaling) + alpha = torch.ones(M, dtype=torch.float32, device="cuda") + # Set some values < 1.0 to trigger rescaling in some warps + alpha[:M // 4] = 0.5 # First quarter needs rescaling + alpha[M // 2:3 * M // 4] = 0.8 # Third quarter needs rescaling + + out_opt = torch.empty_like(acc) + out_ref = torch.empty_like(acc) + + # Grid + grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N)) + + # Run with RESCALE_OPT=True (uses vote_ballot_sync) + vote_ballot_select_kernel[grid]( + acc, + alpha, + out_opt, + M, + N, + acc.stride(0), + acc.stride(1), + alpha.stride(0), + out_opt.stride(0), + out_opt.stride(1), + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + RESCALE_OPT=True, + ) + + # Run with RESCALE_OPT=False (reference, always rescales) + vote_ballot_select_kernel[grid]( + acc, + alpha, + out_ref, + M, + N, + acc.stride(0), + acc.stride(1), + alpha.stride(0), + out_ref.stride(0), + out_ref.stride(1), + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + RESCALE_OPT=False, + ) + + # Verify results match + torch.testing.assert_close(out_opt, out_ref, rtol=1e-5, atol=1e-5) + print("PASS: vote_ballot_select_kernel correctness verified") + + # Test edge cases + # Case 1: All alpha >= 1.0 (no rescaling needed) + alpha_no_rescale = torch.ones(M, dtype=torch.float32, device="cuda") + out_no_rescale = torch.empty_like(acc) + vote_ballot_select_kernel[grid]( + acc, + alpha_no_rescale, + out_no_rescale, + M, + N, + acc.stride(0), + acc.stride(1), + alpha_no_rescale.stride(0), + out_no_rescale.stride(0), + 
out_no_rescale.stride(1), + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + RESCALE_OPT=True, + ) + # When alpha=1.0, output should equal input (no rescaling) + torch.testing.assert_close(out_no_rescale, acc, rtol=1e-5, atol=1e-5) + print("PASS: No rescaling case (all alpha >= 1.0)") + + # Case 2: All alpha < 1.0 (all need rescaling) + alpha_all_rescale = torch.full((M, ), 0.5, dtype=torch.float32, device="cuda") + out_all_rescale = torch.empty_like(acc) + vote_ballot_select_kernel[grid]( + acc, + alpha_all_rescale, + out_all_rescale, + M, + N, + acc.stride(0), + acc.stride(1), + alpha_all_rescale.stride(0), + out_all_rescale.stride(0), + out_all_rescale.stride(1), + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + RESCALE_OPT=True, + ) + expected = acc * alpha_all_rescale[:, None] + torch.testing.assert_close(out_all_rescale, expected, rtol=1e-5, atol=1e-5) + print("PASS: All rescaling case (all alpha < 1.0)") + + print("\nAll tests passed!") + + +if __name__ == "__main__": + test_vote_ballot_select()