From b6b4cd37cd39ea7520d16520e72fe86d3f773090 Mon Sep 17 00:00:00 2001 From: Yinzuo Jiang Date: Sun, 8 Mar 2026 18:30:54 +0800 Subject: [PATCH 1/2] feat: deterministic topk also add cub stable radix sort and overflow handling Co-authored-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com> Signed-off-by: Yinzuo Jiang --- benchmarks/bench_topk.py | 756 +++++++++++++++--- csrc/flashinfer_topk_binding.cu | 8 +- csrc/topk.cu | 14 +- flashinfer/topk.py | 67 +- include/flashinfer/topk.cuh | 1329 ++++++++++++++++++++++--------- tests/utils/test_topk.py | 823 +++++++++++++++++-- 6 files changed, 2440 insertions(+), 557 deletions(-) diff --git a/benchmarks/bench_topk.py b/benchmarks/bench_topk.py index de36e47f77..55557d021f 100644 --- a/benchmarks/bench_topk.py +++ b/benchmarks/bench_topk.py @@ -9,6 +9,8 @@ import argparse import os +from contextlib import contextmanager +from dataclasses import dataclass import numpy as np import torch @@ -25,6 +27,15 @@ def set_topk_algo(algo: str): os.environ["FLASHINFER_TOPK_ALGO"] = algo +def classify_benchmark_runtime_error(exc: RuntimeError) -> str | None: + message = str(exc).lower() + if "out of memory" in message: + return "OOM" + if "invalid" in message or "operation not supported" in message: + return "UNSUPPORTED" + return None + + # Try to import sgl_kernel for comparison try: import sgl_kernel @@ -34,69 +45,299 @@ def set_topk_algo(algo: str): HAS_SGL_KERNEL = False -def bench_top_k( - batch_size: int, - seq_len: int, - k: int, - dtype: torch.dtype = torch.float32, - compare_sglang: bool = False, -) -> dict: - """Benchmark basic top_k operation.""" - scores = torch.randn(batch_size, seq_len, device="cuda", dtype=dtype) +@contextmanager +def torch_deterministic_algorithms(enabled: bool): + """Temporarily set PyTorch deterministic algorithm mode.""" + previous = torch.are_deterministic_algorithms_enabled() + if previous != enabled: + torch.use_deterministic_algorithms(enabled) + try: + yield + finally: + if torch.are_deterministic_algorithms_enabled() != previous: + torch.use_deterministic_algorithms(previous) + - # FlashInfer top_k +def bench_median_ms(fn) -> float: measurements = bench_gpu_time( - lambda: flashinfer.top_k(scores, k), + fn, enable_cupti=True, dry_run_iters=10, repeat_iters=100, ) - fi_ms = np.median(measurements) + return float(np.median(measurements)) + + +def bench_flashinfer_modes( + run_flashinfer, deterministic: bool +) -> tuple[float, float | None]: + selected_ms = bench_median_ms(lambda: run_flashinfer(deterministic)) + nondeterministic_ms = ( + bench_median_ms(lambda: run_flashinfer(False)) if deterministic else None + ) + return selected_ms, nondeterministic_ms + + +def bench_top_k_from_scores( + scores: torch.Tensor, + k: int, + deterministic: bool = False, + compare_torch_deterministic: bool = False, + compare_sglang: bool = False, +) -> dict: + """Benchmark top-k on a pre-generated score tensor.""" + batch_size, seq_len = scores.shape + + fi_ms, fi_nondeterministic_ms = bench_flashinfer_modes( + lambda deterministic_mode: flashinfer.top_k( + scores, + k, + deterministic=deterministic_mode, + ), + deterministic, + ) result = { "batch_size": batch_size, "seq_len": seq_len, "k": k, - "dtype": str(dtype), + "dtype": str(scores.dtype), "flashinfer_us": fi_ms * 1e3, } + if fi_nondeterministic_ms is not None: + result["flashinfer_nondeterministic_us"] = fi_nondeterministic_ms * 1e3 + result["deterministic_slowdown_vs_nondeterministic"] = ( + fi_ms / fi_nondeterministic_ms + ) - # Compare with torch.topk - measurements = bench_gpu_time( - lambda: torch.topk(scores, k, dim=-1), - enable_cupti=True, - dry_run_iters=10, - repeat_iters=100, - ) - torch_ms = np.median(measurements) + with torch_deterministic_algorithms(deterministic): + torch_ms = bench_median_ms(lambda: torch.topk(scores, k, dim=-1)) result["torch_us"] = torch_ms * 1e3 result["speedup_vs_torch"] = torch_ms / fi_ms + if compare_torch_deterministic and not deterministic: + with torch_deterministic_algorithms(True): + torch_det_ms = bench_median_ms(lambda: torch.topk(scores, k, dim=-1)) + result["torch_deterministic_us"] = torch_det_ms * 1e3 + result["speedup_vs_torch_deterministic"] = torch_det_ms / fi_ms + # SGLang comparison (only supports k=2048 and float32) - if compare_sglang and HAS_SGL_KERNEL and k == 2048 and dtype == torch.float32: + if ( + compare_sglang + and HAS_SGL_KERNEL + and k == 2048 + and scores.dtype == torch.float32 + ): lengths = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda") - measurements = bench_gpu_time( + sg_ms = bench_median_ms( lambda: sgl_kernel.fast_topk_v2(scores, lengths, k, row_starts=None), - enable_cupti=True, - dry_run_iters=10, - repeat_iters=100, ) - sg_ms = np.median(measurements) result["sglang_us"] = sg_ms * 1e3 result["speedup_vs_sglang"] = sg_ms / fi_ms return result +def generate_scores( + batch_size: int, + seq_len: int, + k: int, + dtype: torch.dtype, + input_pattern: str, +) -> torch.Tensor: + """Generate benchmark input scores with controllable tie patterns.""" + + if input_pattern == "random": + return torch.randn(batch_size, seq_len, device="cuda", dtype=dtype) + + if input_pattern in {"quantized_random", "relu_quantized"}: + base = torch.randn(batch_size, seq_len, device="cuda", dtype=torch.float32) + if input_pattern == "relu_quantized": + base = torch.relu(base) + scores = (torch.round(base * 32.0) / 32.0).to(dtype) + return scores + + if input_pattern == "tie_heavy": + pattern = ( + torch.arange(seq_len, device="cuda", dtype=torch.float32) % 64 + ) / 64.0 + return pattern.unsqueeze(0).expand(batch_size, -1).contiguous().to(dtype) + + if input_pattern == "pivot_tie": + # Severe tie at pivot: + # - majority entries are identical (1.0) + # - a small tail region is strictly larger (2.0) + # This creates truncation in == pivot region when k exceeds tail size. + scores = torch.ones(batch_size, seq_len, device="cuda", dtype=dtype) + gt_count = max(1, min(k // 4, seq_len // 8)) + scores[:, seq_len - gt_count :] = 2.0 + return scores + + raise ValueError(f"Unsupported input_pattern: {input_pattern}") + + +def generate_dsa_scores( + batch_size: int, + q_len: int, + seq_len: int, + dtype: torch.dtype, + input_pattern: str, + causal_chunk: bool, +) -> torch.Tensor: + """Generate DeepSeek DSA-like indexer score workload. + + Source context: + - DeepSeek-V3.2-Exp config uses index_topk=2048: + https://huggingface.co/deepseek-ai/DeepSeek-V3.2-Exp/blob/main/inference/config_671B_v3.2.json + - Indexer runs topk over index_score last dim: + https://huggingface.co/deepseek-ai/DeepSeek-V3.2-Exp/blob/main/inference/model.py + + Returns a 2D tensor with shape (batch_size * q_len, seq_len). + """ + rows = batch_size * q_len + + if input_pattern == "random": + scores = torch.randn(rows, seq_len, device="cuda", dtype=dtype) + elif input_pattern in {"quantized_random", "dsa_relu"}: + base = torch.randn(rows, seq_len, device="cuda", dtype=torch.float32) + if input_pattern == "dsa_relu": + base = torch.relu(base) + scores = (torch.round(base * 32.0) / 32.0).to(dtype) + else: + raise ValueError(f"Unsupported dsa input_pattern: {input_pattern}") + + if causal_chunk: + # Simulate prefill chunk near end of long context: + # each query row i can only attend [0, start_pos + i]. + start_pos = seq_len - q_len + lengths = torch.arange( + start_pos + 1, + start_pos + q_len + 1, + device="cuda", + dtype=torch.int32, + ).repeat(batch_size) + col = torch.arange(seq_len, device="cuda", dtype=torch.int32).unsqueeze(0) + invalid = col >= lengths.unsqueeze(1) + neg_inf = -torch.inf if dtype == torch.float32 else torch.finfo(dtype).min + scores = scores.masked_fill(invalid, neg_inf) + + return scores.contiguous() + + +@dataclass(frozen=True) +class DSATopKCase: + name: str + batch_size: int + q_len: int + seq_len: int + causal_chunk: bool + + +@dataclass(frozen=True) +class TopKCase: + name: str + batch_size: int + seq_len: int + k: int + + +def build_top_k_cases( + batch_sizes: list[int], + seq_lens: list[int], + k_values: list[int], +) -> list[TopKCase]: + cases: list[TopKCase] = [] + + for batch_size in batch_sizes: + for seq_len in seq_lens: + for k in k_values: + if k <= seq_len: + cases.append( + TopKCase( + name=f"grid_b{batch_size}_l{seq_len}_k{k}", + batch_size=batch_size, + seq_len=seq_len, + k=k, + ) + ) + + # These deterministic large-batch/long-vocab cases are not covered by the + # original grid but surfaced real correctness/performance differences. + cases.extend( + [ + TopKCase("stress_b2048_l131072_k1024", 2048, 131072, 1024), + TopKCase("stress_b4096_l200000_k1024", 4096, 200000, 1024), + ] + ) + + return cases + + +def bench_dsa_top_k( + batch_size: int, + q_len: int, + seq_len: int, + k: int, + dtype: torch.dtype = torch.bfloat16, + input_pattern: str = "dsa_relu", + deterministic: bool = False, + compare_torch_deterministic: bool = False, + compare_sglang: bool = False, + causal_chunk: bool = False, +) -> dict: + scores = generate_dsa_scores( + batch_size=batch_size, + q_len=q_len, + seq_len=seq_len, + dtype=dtype, + input_pattern=input_pattern, + causal_chunk=causal_chunk, + ) + result = bench_top_k_from_scores( + scores=scores, + k=k, + deterministic=deterministic, + compare_torch_deterministic=compare_torch_deterministic, + compare_sglang=compare_sglang, + ) + result["rows"] = batch_size * q_len + result["q_len"] = q_len + result["case_type"] = "prefill" if causal_chunk else "decode" + return result + + +def bench_top_k( + batch_size: int, + seq_len: int, + k: int, + dtype: torch.dtype = torch.float32, + input_pattern: str = "random", + deterministic: bool = False, + compare_torch_deterministic: bool = False, + compare_sglang: bool = False, +) -> dict: + """Benchmark basic top_k operation.""" + scores = generate_scores(batch_size, seq_len, k, dtype, input_pattern) + return bench_top_k_from_scores( + scores=scores, + k=k, + deterministic=deterministic, + compare_torch_deterministic=compare_torch_deterministic, + compare_sglang=compare_sglang, + ) + + def bench_page_table_transform( batch_size: int, seq_len: int, k: int, dtype: torch.dtype = torch.float32, + input_pattern: str = "random", + deterministic: bool = False, compare_sglang: bool = False, ) -> dict: """Benchmark fused top_k + page table transform.""" - scores = torch.randn(batch_size, seq_len, device="cuda", dtype=dtype) + scores = generate_scores(batch_size, seq_len, k, dtype, input_pattern) lengths = torch.full((batch_size,), seq_len, device="cuda", dtype=torch.int32) src_page_table = ( torch.arange(seq_len, device="cuda", dtype=torch.int32) @@ -105,16 +346,16 @@ def bench_page_table_transform( .contiguous() ) - # FlashInfer - measurements = bench_gpu_time( - lambda: flashinfer.top_k_page_table_transform( - scores, src_page_table, lengths, k + fi_ms, fi_nondeterministic_ms = bench_flashinfer_modes( + lambda deterministic_mode: flashinfer.top_k_page_table_transform( + scores, + src_page_table, + lengths, + k, + deterministic=deterministic_mode, ), - enable_cupti=True, - dry_run_iters=10, - repeat_iters=100, + deterministic, ) - fi_ms = np.median(measurements) result = { "batch_size": batch_size, @@ -123,19 +364,20 @@ def bench_page_table_transform( "dtype": str(dtype), "flashinfer_us": fi_ms * 1e3, } + if fi_nondeterministic_ms is not None: + result["flashinfer_nondeterministic_us"] = fi_nondeterministic_ms * 1e3 + result["deterministic_slowdown_vs_nondeterministic"] = ( + fi_ms / fi_nondeterministic_ms + ) # SGLang comparison (only supports k=2048 and float32) if compare_sglang and HAS_SGL_KERNEL and k == 2048 and dtype == torch.float32: cu_seqlens_q = torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") - measurements = bench_gpu_time( + sg_ms = bench_median_ms( lambda: sgl_kernel.fast_topk_transform_fused( scores, lengths, src_page_table, cu_seqlens_q, k ), - enable_cupti=True, - dry_run_iters=10, - repeat_iters=100, ) - sg_ms = np.median(measurements) result["sglang_us"] = sg_ms * 1e3 result["speedup_vs_sglang"] = sg_ms / fi_ms @@ -147,23 +389,27 @@ def bench_ragged_transform( seq_len: int, k: int, dtype: torch.dtype = torch.float32, + input_pattern: str = "random", + deterministic: bool = False, compare_sglang: bool = False, ) -> dict: """Benchmark fused top_k + ragged index transform.""" - scores = torch.randn(batch_size, seq_len, device="cuda", dtype=dtype) + scores = generate_scores(batch_size, seq_len, k, dtype, input_pattern) lengths = torch.full((batch_size,), seq_len, device="cuda", dtype=torch.int32) offsets = torch.arange( 0, batch_size * seq_len, seq_len, device="cuda", dtype=torch.int32 ) - # FlashInfer - measurements = bench_gpu_time( - lambda: flashinfer.top_k_ragged_transform(scores, offsets, lengths, k), - enable_cupti=True, - dry_run_iters=10, - repeat_iters=100, + fi_ms, fi_nondeterministic_ms = bench_flashinfer_modes( + lambda deterministic_mode: flashinfer.top_k_ragged_transform( + scores, + offsets, + lengths, + k, + deterministic=deterministic_mode, + ), + deterministic, ) - fi_ms = np.median(measurements) result = { "batch_size": batch_size, @@ -172,18 +418,19 @@ def bench_ragged_transform( "dtype": str(dtype), "flashinfer_us": fi_ms * 1e3, } + if fi_nondeterministic_ms is not None: + result["flashinfer_nondeterministic_us"] = fi_nondeterministic_ms * 1e3 + result["deterministic_slowdown_vs_nondeterministic"] = ( + fi_ms / fi_nondeterministic_ms + ) # SGLang comparison (only supports k=2048 and float32) if compare_sglang and HAS_SGL_KERNEL and k == 2048 and dtype == torch.float32: - measurements = bench_gpu_time( + sg_ms = bench_median_ms( lambda: sgl_kernel.fast_topk_transform_ragged_fused( scores, lengths, offsets, k ), - enable_cupti=True, - dry_run_iters=10, - repeat_iters=100, ) - sg_ms = np.median(measurements) result["sglang_us"] = sg_ms * 1e3 result["speedup_vs_sglang"] = sg_ms / fi_ms @@ -214,7 +461,7 @@ def main(): ) parser.add_argument( "--op", - choices=["all", "top_k", "page_table", "ragged"], + choices=["all", "top_k", "dsa_topk", "page_table", "ragged"], default="all", help="Which operation to benchmark", ) @@ -229,6 +476,49 @@ def main(): action="store_true", help="Compare multi-CTA vs filtered algorithms", ) + parser.add_argument( + "--deterministic", + action="store_true", + help="Enable deterministic mode for FlashInfer top-k kernels", + ) + parser.add_argument( + "--compare-torch-deterministic", + action="store_true", + help="Also benchmark torch.topk under deterministic algorithm mode", + ) + parser.add_argument( + "--input-pattern", + choices=[ + "random", + "quantized_random", + "relu_quantized", + "tie_heavy", + "pivot_tie", + ], + default="random", + help=( + "Input score pattern: random | quantized_random | relu_quantized | " + "tie_heavy | pivot_tie" + ), + ) + parser.add_argument( + "--dsa-input-pattern", + choices=["random", "quantized_random", "dsa_relu"], + default="dsa_relu", + help="DSA top-k input pattern: random | quantized_random | dsa_relu", + ) + parser.add_argument( + "--dsa-case", + choices=["all", "decode", "prefill"], + default="all", + help="DSA case group: all | decode | prefill", + ) + parser.add_argument( + "--dsa-topk", + type=int, + default=2048, + help="Top-k for DSA workload (default: 2048, matching DeepSeek DSA config)", + ) args = parser.parse_args() dtype = parse_dtype(args.dtype) @@ -239,16 +529,27 @@ def main(): # Test configurations batch_sizes = [1, 16, 64, 256] - seq_lens = [4096, 16384, 65536, 131072, 262144, 524288] + seq_lens = [256, 512, 1024, 2048, 4096, 16384, 65536, 131072, 262144, 524288] k_values = [256, 512, 1024, 2048, 4096] + top_k_cases = build_top_k_cases( + batch_sizes=batch_sizes, + seq_lens=seq_lens, + k_values=k_values, + ) dtype_str = args.dtype.upper() # Algorithm comparison mode if args.compare_algorithms: + if args.deterministic: + print( + "ERROR: --compare-algorithms is only meaningful with non-deterministic mode" + ) + return print("=" * 100) print( - f"Algorithm comparison: Multi-CTA vs Filtered (dtype={dtype_str}, k=2048)" + "Algorithm comparison: Multi-CTA vs Filtered " + f"(dtype={dtype_str}, k=2048, pattern={args.input_pattern})" ) print("=" * 100) print( @@ -265,13 +566,15 @@ def main(): # Benchmark Multi-CTA set_topk_algo("multi_cta") result_mc = bench_page_table_transform( - batch_size, seq_len, k, dtype + batch_size, seq_len, k, dtype, args.input_pattern ) mc_us = result_mc["flashinfer_us"] # Benchmark Filtered set_topk_algo("filtered") - result_f = bench_page_table_transform(batch_size, seq_len, k, dtype) + result_f = bench_page_table_transform( + batch_size, seq_len, k, dtype, args.input_pattern + ) f_us = result_f["flashinfer_us"] # Reset to auto @@ -285,70 +588,252 @@ def main(): f"{winner:>8} {speedup:.2f}x" ) except RuntimeError as e: - if "out of memory" in str(e): + error_label = classify_benchmark_runtime_error(e) + if error_label == "OOM": print(f"{batch_size:>6} {seq_len:>10} | OOM") torch.cuda.empty_cache() + elif error_label == "UNSUPPORTED": + print(f"{batch_size:>6} {seq_len:>10} | UNSUPPORTED") else: raise return if args.op in ["all", "top_k"]: print("=" * 100) - print(f"top_k: Basic radix-based top-k selection (dtype={dtype_str})") + print( + "top_k: Basic radix-based top-k selection " + f"(dtype={dtype_str}, deterministic={args.deterministic}, " + f"pattern={args.input_pattern})" + ) if args.compare_sglang: print("NOTE: SGLang only supports k=2048 and float32") + if args.deterministic: + print( + "NOTE: deterministic mode also benchmarks FlashInfer(non-det) " + "for direct comparison" + ) + if args.compare_torch_deterministic: + print( + "NOTE: torch.det means torch.topk with torch.use_deterministic_algorithms(True)" + ) + elif args.deterministic: + print( + "NOTE: torch column uses torch.topk with " + "torch.use_deterministic_algorithms(True)" + ) + print( + "NOTE: default top-k sweep includes two extra large-batch/long-vocab " + "stress cases beyond the original grid" + ) print("=" * 100) - header = f"{'batch':>6} {'seq_len':>10} {'k':>6} | {'FlashInfer':>12} {'torch.topk':>12} {'Speedup':>10}" + if args.deterministic: + header = ( + f"{'batch':>6} {'seq_len':>10} {'k':>6} | " + f"{'FlashInfer':>12} {'FlashInfer(det)':>14} {'DetSlowdown':>11} " + f"{'torch.det':>12} {'Speedup':>10}" + ) + else: + header = ( + f"{'batch':>6} {'seq_len':>10} {'k':>6} | " + f"{'FlashInfer':>12} {'torch.topk':>12} {'Speedup':>10}" + ) + if args.compare_torch_deterministic and not args.deterministic: + header += f" {'torch.det':>12} {'Speedup':>10}" if args.compare_sglang: header += f" {'SGLang':>12} {'Speedup':>10}" print(header) - print("-" * (70 if not args.compare_sglang else 90)) + divider_len = 96 if args.deterministic else 72 + if args.compare_torch_deterministic and not args.deterministic: + divider_len += 24 + if args.compare_sglang: + divider_len += 24 + print("-" * divider_len) + + for case in top_k_cases: + try: + result = bench_top_k( + case.batch_size, + case.seq_len, + case.k, + dtype, + input_pattern=args.input_pattern, + deterministic=args.deterministic, + compare_torch_deterministic=args.compare_torch_deterministic, + compare_sglang=args.compare_sglang, + ) + if args.deterministic: + line = ( + f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} | " + f"{result['flashinfer_nondeterministic_us']:>10.2f}us " + f"{result['flashinfer_us']:>12.2f}us " + f"{result['deterministic_slowdown_vs_nondeterministic']:>10.2f}x " + f"{result['torch_us']:>10.2f}us " + f"{result['speedup_vs_torch']:>9.2f}x" + ) + else: + line = ( + f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} | " + f"{result['flashinfer_us']:>12.2f}us {result['torch_us']:>10.2f}us " + f"{result['speedup_vs_torch']:>9.2f}x" + ) + if "torch_deterministic_us" in result: + line += ( + f" {result['torch_deterministic_us']:>10.2f}us " + f"{result['speedup_vs_torch_deterministic']:>9.2f}x" + ) + if "sglang_us" in result: + line += ( + f" {result['sglang_us']:>10.2f}us " + f"{result['speedup_vs_sglang']:>9.2f}x" + ) + elif args.compare_sglang and case.k == 2048: + line += " (SGLang error)" + print(line) + except RuntimeError as e: + error_label = classify_benchmark_runtime_error(e) + if error_label is not None: + print( + f"{case.batch_size:>6} {case.seq_len:>10} {case.k:>6} | {error_label}" + ) + torch.cuda.empty_cache() + else: + raise - for batch_size in batch_sizes: - for seq_len in seq_lens: - for k in k_values: - if k > seq_len: - continue - try: - result = bench_top_k( - batch_size, - seq_len, - k, - dtype, - compare_sglang=args.compare_sglang, - ) - line = ( - f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} | " - f"{result['flashinfer_us']:>10.2f}us {result['torch_us']:>10.2f}us " - f"{result['speedup_vs_torch']:>9.2f}x" - ) - if "sglang_us" in result: - line += f" {result['sglang_us']:>10.2f}us {result['speedup_vs_sglang']:>9.2f}x" - elif args.compare_sglang and k == 2048: - line += " (SGLang error)" - print(line) - except RuntimeError as e: - if "out of memory" in str(e): - print(f"{batch_size:>6} {seq_len:>10} {k:>6} | OOM") - torch.cuda.empty_cache() - else: - raise + if args.op in ["all", "dsa_topk"]: + print("\n" + "=" * 100) + print( + "dsa_topk: DeepSeek DSA-like indexer top-k workload " + f"(dtype={dtype_str}, deterministic={args.deterministic}, " + f"dsa_pattern={args.dsa_input_pattern}, k={args.dsa_topk})" + ) + if args.deterministic: + print( + "NOTE: deterministic mode also benchmarks FlashInfer(non-det) " + "for direct comparison" + ) + if args.compare_torch_deterministic: + print( + "NOTE: torch.det means torch.topk with torch.use_deterministic_algorithms(True)" + ) + elif args.deterministic: + print( + "NOTE: torch column uses torch.topk with " + "torch.use_deterministic_algorithms(True)" + ) + print("=" * 100) + + if args.deterministic: + header = ( + f"{'case':>24} {'rows':>8} {'seq_len':>10} {'k':>6} | " + f"{'FlashInfer':>12} {'FlashInfer(det)':>14} {'DetSlowdown':>11} " + f"{'torch.det':>12} {'Speedup':>10}" + ) + else: + header = ( + f"{'case':>24} {'rows':>8} {'seq_len':>10} {'k':>6} | " + f"{'FlashInfer':>12} {'torch.topk':>12} {'Speedup':>10}" + ) + if args.compare_torch_deterministic and not args.deterministic: + header += f" {'torch.det':>12} {'Speedup':>10}" + print(header) + divider_len = 110 if args.deterministic else 86 + if args.compare_torch_deterministic and not args.deterministic: + divider_len += 24 + print("-" * divider_len) + + dsa_cases = [ + # DeepSeek Sparse Attention proxy cases: + # - decode: q_len=1 + # - prefill chunk: q_len>1 with causal availability growth + # Ref: https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/main/DeepSeek_V3_2.pdf + DSATopKCase("decode_b1_q1_l128k", 1, 1, 131072, False), + DSATopKCase("decode_b8_q1_l64k", 8, 1, 65536, False), + DSATopKCase("decode_b32_q1_l128k", 32, 1, 131072, False), + DSATopKCase("prefill_b1_q128_l128k", 1, 128, 131072, True), + ] + + for case in dsa_cases: + if args.dsa_case == "decode" and case.causal_chunk: + continue + if args.dsa_case == "prefill" and not case.causal_chunk: + continue + if args.dsa_topk > case.seq_len: + continue + try: + result = bench_dsa_top_k( + batch_size=case.batch_size, + q_len=case.q_len, + seq_len=case.seq_len, + k=args.dsa_topk, + dtype=dtype, + input_pattern=args.dsa_input_pattern, + deterministic=args.deterministic, + compare_torch_deterministic=args.compare_torch_deterministic, + compare_sglang=False, + causal_chunk=case.causal_chunk, + ) + if args.deterministic: + line = ( + f"{case.name:>24} {result['rows']:>8} {result['seq_len']:>10} {result['k']:>6} | " + f"{result['flashinfer_nondeterministic_us']:>10.2f}us " + f"{result['flashinfer_us']:>12.2f}us " + f"{result['deterministic_slowdown_vs_nondeterministic']:>10.2f}x " + f"{result['torch_us']:>10.2f}us " + f"{result['speedup_vs_torch']:>9.2f}x" + ) + else: + line = ( + f"{case.name:>24} {result['rows']:>8} {result['seq_len']:>10} {result['k']:>6} | " + f"{result['flashinfer_us']:>10.2f}us {result['torch_us']:>10.2f}us " + f"{result['speedup_vs_torch']:>9.2f}x" + ) + if "torch_deterministic_us" in result: + line += ( + f" {result['torch_deterministic_us']:>10.2f}us " + f"{result['speedup_vs_torch_deterministic']:>9.2f}x" + ) + print(line) + except RuntimeError as e: + error_label = classify_benchmark_runtime_error(e) + if error_label is not None: + print( + f"{case.name:>24} {case.batch_size * case.q_len:>8} {case.seq_len:>10} " + f"{args.dsa_topk:>6} | {error_label}" + ) + torch.cuda.empty_cache() + else: + raise if args.op in ["all", "page_table"]: print("\n" + "=" * 100) print( - f"top_k_page_table_transform: Fused top-k + page table gather (dtype={dtype_str})" + "top_k_page_table_transform: Fused top-k + page table gather " + f"(dtype={dtype_str}, deterministic={args.deterministic}, pattern={args.input_pattern})" ) if args.compare_sglang: print("NOTE: SGLang only supports k=2048 and float32") + if args.deterministic: + print( + "NOTE: deterministic mode also benchmarks FlashInfer(non-det) " + "for direct comparison" + ) print("=" * 100) - header = f"{'batch':>6} {'seq_len':>10} {'k':>6} | {'FlashInfer':>12}" + if args.deterministic: + header = ( + f"{'batch':>6} {'seq_len':>10} {'k':>6} | " + f"{'FlashInfer':>12} {'FlashInfer(det)':>14} {'DetSlowdown':>11}" + ) + else: + header = f"{'batch':>6} {'seq_len':>10} {'k':>6} | {'FlashInfer':>12}" if args.compare_sglang: header += f" {'SGLang':>12} {'Speedup':>10}" print(header) - print("-" * (70 if not args.compare_sglang else 90)) + divider_len = 87 if args.deterministic else 70 + if args.compare_sglang: + divider_len += 20 + print("-" * divider_len) for batch_size in batch_sizes: for seq_len in seq_lens: @@ -361,20 +846,36 @@ def main(): seq_len, k, dtype, + input_pattern=args.input_pattern, + deterministic=args.deterministic, compare_sglang=args.compare_sglang, ) - line = ( - f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} | " - f"{result['flashinfer_us']:>10.2f}us" - ) + if args.deterministic: + line = ( + f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} | " + f"{result['flashinfer_nondeterministic_us']:>10.2f}us " + f"{result['flashinfer_us']:>12.2f}us " + f"{result['deterministic_slowdown_vs_nondeterministic']:>10.2f}x" + ) + else: + line = ( + f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} | " + f"{result['flashinfer_us']:>10.2f}us" + ) if "sglang_us" in result: - line += f" {result['sglang_us']:>10.2f}us {result['speedup_vs_sglang']:>9.2f}x" + line += ( + f" {result['sglang_us']:>10.2f}us " + f"{result['speedup_vs_sglang']:>9.2f}x" + ) elif args.compare_sglang and k == 2048: line += " (SGLang error)" print(line) except RuntimeError as e: - if "out of memory" in str(e): - print(f"{batch_size:>6} {seq_len:>10} {k:>6} | OOM") + error_label = classify_benchmark_runtime_error(e) + if error_label is not None: + print( + f"{batch_size:>6} {seq_len:>10} {k:>6} | {error_label}" + ) torch.cuda.empty_cache() else: raise @@ -382,17 +883,32 @@ def main(): if args.op in ["all", "ragged"]: print("\n" + "=" * 100) print( - f"top_k_ragged_transform: Fused top-k + ragged index transform (dtype={dtype_str})" + "top_k_ragged_transform: Fused top-k + ragged index transform " + f"(dtype={dtype_str}, deterministic={args.deterministic}, pattern={args.input_pattern})" ) if args.compare_sglang: print("NOTE: SGLang only supports k=2048 and float32") + if args.deterministic: + print( + "NOTE: deterministic mode also benchmarks FlashInfer(non-det) " + "for direct comparison" + ) print("=" * 100) - header = f"{'batch':>6} {'seq_len':>10} {'k':>6} | {'FlashInfer':>12}" + if args.deterministic: + header = ( + f"{'batch':>6} {'seq_len':>10} {'k':>6} | " + f"{'FlashInfer':>12} {'FlashInfer(det)':>14} {'DetSlowdown':>11}" + ) + else: + header = f"{'batch':>6} {'seq_len':>10} {'k':>6} | {'FlashInfer':>12}" if args.compare_sglang: header += f" {'SGLang':>12} {'Speedup':>10}" print(header) - print("-" * (70 if not args.compare_sglang else 90)) + divider_len = 87 if args.deterministic else 70 + if args.compare_sglang: + divider_len += 20 + print("-" * divider_len) for batch_size in batch_sizes: for seq_len in seq_lens: @@ -405,20 +921,36 @@ def main(): seq_len, k, dtype, + input_pattern=args.input_pattern, + deterministic=args.deterministic, compare_sglang=args.compare_sglang, ) - line = ( - f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} | " - f"{result['flashinfer_us']:>10.2f}us" - ) + if args.deterministic: + line = ( + f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} | " + f"{result['flashinfer_nondeterministic_us']:>10.2f}us " + f"{result['flashinfer_us']:>12.2f}us " + f"{result['deterministic_slowdown_vs_nondeterministic']:>10.2f}x" + ) + else: + line = ( + f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} | " + f"{result['flashinfer_us']:>10.2f}us" + ) if "sglang_us" in result: - line += f" {result['sglang_us']:>10.2f}us {result['speedup_vs_sglang']:>9.2f}x" + line += ( + f" {result['sglang_us']:>10.2f}us " + f"{result['speedup_vs_sglang']:>9.2f}x" + ) elif args.compare_sglang and k == 2048: line += " (SGLang error)" print(line) except RuntimeError as e: - if "out of memory" in str(e): - print(f"{batch_size:>6} {seq_len:>10} {k:>6} | OOM") + error_label = classify_benchmark_runtime_error(e) + if error_label is not None: + print( + f"{batch_size:>6} {seq_len:>10} {k:>6} | {error_label}" + ) torch.cuda.empty_cache() else: raise diff --git a/csrc/flashinfer_topk_binding.cu b/csrc/flashinfer_topk_binding.cu index 090c340631..44ce7b5349 100644 --- a/csrc/flashinfer_topk_binding.cu +++ b/csrc/flashinfer_topk_binding.cu @@ -18,16 +18,18 @@ using tvm::ffi::Optional; void radix_topk(TensorView input, TensorView output_indices, TensorView output_values, - Optional maybe_row_states_buffer, int64_t top_k); + Optional maybe_row_states_buffer, int64_t top_k, bool sorted_output, + bool deterministic); void radix_topk_page_table_transform(TensorView input, TensorView output_page_table, TensorView src_page_table, Optional maybe_row_to_batch, TensorView lengths, - Optional maybe_row_states_buffer, int64_t top_k); + Optional maybe_row_states_buffer, int64_t top_k, + bool deterministic); void radix_topk_ragged_transform(TensorView input, TensorView output_indices, TensorView offsets, TensorView lengths, Optional maybe_row_states_buffer, - int64_t top_k); + int64_t top_k, bool deterministic); bool can_implement_filtered_topk(); diff --git a/csrc/topk.cu b/csrc/topk.cu index 57690c1608..64f661ac3c 100644 --- a/csrc/topk.cu +++ b/csrc/topk.cu @@ -23,7 +23,8 @@ using namespace flashinfer; using tvm::ffi::Optional; void radix_topk(TensorView input, TensorView output_indices, TensorView output_values, - Optional maybe_row_states_buffer, int64_t top_k) { + Optional maybe_row_states_buffer, int64_t top_k, bool sorted_output, + bool deterministic) { CHECK_INPUT(input); CHECK_INPUT(output_indices); CHECK_INPUT(output_values); @@ -52,7 +53,7 @@ void radix_topk(TensorView input, TensorView output_indices, TensorView output_v status = sampling::TopKDispatch( static_cast(input.data_ptr()), static_cast(output_indices.data_ptr()), static_cast(output_values.data_ptr()), batch_size, static_cast(top_k), d, - row_states_ptr, stream); + row_states_ptr, sorted_output, deterministic, stream); return true; }); @@ -63,7 +64,8 @@ void radix_topk(TensorView input, TensorView output_indices, TensorView output_v void radix_topk_page_table_transform(TensorView input, TensorView output_page_table, TensorView src_page_table, Optional maybe_row_to_batch, TensorView lengths, - Optional maybe_row_states_buffer, int64_t top_k) { + Optional maybe_row_states_buffer, int64_t top_k, + bool deterministic) { CHECK_INPUT(input); CHECK_INPUT(output_page_table); CHECK_INPUT(src_page_table); @@ -100,7 +102,7 @@ void radix_topk_page_table_transform(TensorView input, TensorView output_page_ta static_cast(input.data_ptr()), static_cast(output_page_table.data_ptr()), static_cast(src_page_table.data_ptr()), src_stride, row_to_batch_ptr, static_cast(lengths.data_ptr()), num_rows, static_cast(top_k), max_len, - row_states_ptr, stream); + row_states_ptr, deterministic, stream); return true; }); @@ -110,7 +112,7 @@ void radix_topk_page_table_transform(TensorView input, TensorView output_page_ta void radix_topk_ragged_transform(TensorView input, TensorView output_indices, TensorView offsets, TensorView lengths, Optional maybe_row_states_buffer, - int64_t top_k) { + int64_t top_k, bool deterministic) { CHECK_INPUT(input); CHECK_INPUT(output_indices); CHECK_INPUT(offsets); @@ -140,7 +142,7 @@ void radix_topk_ragged_transform(TensorView input, TensorView output_indices, Te status = sampling::TopKRaggedTransformDispatch( static_cast(input.data_ptr()), static_cast(output_indices.data_ptr()), static_cast(offsets.data_ptr()), static_cast(lengths.data_ptr()), - num_rows, static_cast(top_k), max_len, row_states_ptr, stream); + num_rows, static_cast(top_k), max_len, row_states_ptr, deterministic, stream); return true; }); diff --git a/flashinfer/topk.py b/flashinfer/topk.py index d1cfb754be..79f6eb131b 100644 --- a/flashinfer/topk.py +++ b/flashinfer/topk.py @@ -35,6 +35,8 @@ def get_topk_module(): def radix_topk( input: torch.Tensor, top_k: int, + sorted_output: bool, + deterministic: bool, row_states_buffer: Optional[torch.Tensor], output_values: torch.Tensor, ) -> torch.Tensor: @@ -48,7 +50,13 @@ def radix_topk( batch_size, top_k, dtype=torch.int32, device=device ) module.radix_topk( - input, output_indices, output_values, row_states_buffer, top_k + input, + output_indices, + output_values, + row_states_buffer, + top_k, + sorted_output, + deterministic, ) return output_indices @@ -56,6 +64,8 @@ def radix_topk( def _fake_radix_topk( input: torch.Tensor, top_k: int, + sorted_output: bool, + deterministic: bool, row_states_buffer: Optional[torch.Tensor], output_values: torch.Tensor, ) -> torch.Tensor: @@ -74,6 +84,7 @@ def radix_topk_page_table_transform( lengths: torch.Tensor, row_states_buffer: Optional[torch.Tensor], top_k: int, + deterministic: bool, ) -> None: assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], ( f"Unsupported dtype {input.dtype}, expected float32, float16, or bfloat16" @@ -86,6 +97,7 @@ def radix_topk_page_table_transform( lengths, row_states_buffer, top_k, + deterministic, ) @register_fake_op("flashinfer::radix_topk_page_table_transform") @@ -97,6 +109,7 @@ def _fake_radix_topk_page_table_transform( lengths: torch.Tensor, row_states_buffer: Optional[torch.Tensor], top_k: int, + deterministic: bool, ) -> None: pass @@ -111,12 +124,19 @@ def radix_topk_ragged_transform( lengths: torch.Tensor, row_states_buffer: Optional[torch.Tensor], top_k: int, + deterministic: bool, ) -> None: assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], ( f"Unsupported dtype {input.dtype}, expected float32, float16, or bfloat16" ) module.radix_topk_ragged_transform( - input, output_indices, offsets, lengths, row_states_buffer, top_k + input, + output_indices, + offsets, + lengths, + row_states_buffer, + top_k, + deterministic, ) @register_fake_op("flashinfer::radix_topk_ragged_transform") @@ -127,6 +147,7 @@ def _fake_radix_topk_ragged_transform( lengths: torch.Tensor, row_states_buffer: Optional[torch.Tensor], top_k: int, + deterministic: bool, ) -> None: pass @@ -157,6 +178,7 @@ def top_k( input: torch.Tensor, k: int, sorted: bool = False, + deterministic: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Radix-based Top-K selection. @@ -177,6 +199,12 @@ def top_k( sorted : bool, optional If True, the returned top-k elements will be sorted in descending order. Default is False (unsorted, which is faster). + deterministic : bool, optional + If True, uses deterministic mode. + Default is False (non-deterministic, which is faster). + + Deterministic mode guarantees repeatable FlashInfer output ordering for + the selected top-k set on a fixed input and system. Returns ------- @@ -213,6 +241,10 @@ def top_k( >>> values_sorted, indices_sorted = flashinfer.top_k(logits, k, sorted=True) >>> # Values are now in descending order within each row + Deterministic mode (bitwise-reproducible output): + + >>> values, indices = flashinfer.top_k(logits, k, deterministic=True) + See Also -------- torch.topk : PyTorch's built-in top-k function @@ -223,7 +255,8 @@ def top_k( device = input.device # Allocate row_states buffer for multi-CTA path - # 1MB is enough for any reasonable GPU (covers up to ~500 groups) + # 1MB is enough for any reasonable GPU (covers up to ~200 groups for deterministic + # mode and ~300 groups for non-deterministic mode) row_states_buffer: Optional[torch.Tensor] = _get_cache_buf( f"radix_topk_row_states_{input.device}", 1024 * 1024, # 1MB @@ -234,17 +267,20 @@ def top_k( # Allocate output_values for kernel to write directly output_values = torch.empty(batch_size, k, dtype=input.dtype, device=device) - # Get indices using radix-based selection + # For deterministic + sorted + k <= 2048: CUDA handles the stable value sort on device. + sorted_cuda = sorted and deterministic and k <= 2048 indices_int32 = get_topk_module().radix_topk( - input, k, row_states_buffer, output_values + input, k, sorted_cuda, deterministic, row_states_buffer, output_values ) # Convert to int64 for compatibility indices = indices_int32.long() - if sorted: + if sorted and not sorted_cuda: # Sort within each row by value (descending) - sorted_values, sort_indices = torch.sort(output_values, dim=-1, descending=True) + sorted_values, sort_indices = torch.sort( + output_values, dim=-1, descending=True, stable=deterministic + ) sorted_indices = torch.gather(indices, dim=-1, index=sort_indices) return sorted_values, sorted_indices @@ -262,6 +298,7 @@ def top_k_page_table_transform( lengths: torch.Tensor, k: int, row_to_batch: Optional[torch.Tensor] = None, + deterministic: bool = False, ) -> torch.Tensor: r"""Fused Top-K selection + Page Table Transform for sparse attention. @@ -290,6 +327,9 @@ def top_k_page_table_transform( Mapping from row index to batch index of shape ``(num_rows,)`` with dtype ``int32``. If None, uses 1:1 mapping (row_idx == batch_idx). Default is None. + deterministic : bool, optional + If True, uses deterministic mode. + Default is False (non-deterministic, which is faster). Returns ------- @@ -340,6 +380,7 @@ def top_k_page_table_transform( lengths, row_states_buffer, k, + deterministic, ) return output_page_table @@ -351,6 +392,7 @@ def top_k_ragged_transform( offsets: torch.Tensor, lengths: torch.Tensor, k: int, + deterministic: bool = False, ) -> torch.Tensor: r"""Fused Top-K selection + Ragged Index Transform for sparse attention. @@ -372,6 +414,9 @@ def top_k_ragged_transform( Actual KV lengths per row of shape ``(num_rows,)`` with dtype ``int32``. k : int Number of top elements to select from each row. + deterministic : bool, optional + If True, uses deterministic mode. + Default is False (non-deterministic, which is faster). Returns ------- @@ -416,7 +461,13 @@ def top_k_ragged_transform( output_indices = torch.empty(num_rows, k, dtype=torch.int32, device=device) get_topk_module().radix_topk_ragged_transform( - input, output_indices, offsets, lengths, row_states_buffer, k + input, + output_indices, + offsets, + lengths, + row_states_buffer, + k, + deterministic, ) return output_indices diff --git a/include/flashinfer/topk.cuh b/include/flashinfer/topk.cuh index 468ea55495..62071097b1 100644 --- a/include/flashinfer/topk.cuh +++ b/include/flashinfer/topk.cuh @@ -18,7 +18,9 @@ #include +#include #include +#include #include #include #include @@ -30,6 +32,25 @@ namespace flashinfer { namespace sampling { +template +inline size_t GetRadixTopKAvailableOrderedSmemBytes(size_t max_smem_per_block, + size_t fixed_smem_aligned, + bool reserve_launch_headroom) { + using RadixTopKDetBlockScanT = + cub::BlockScan; + constexpr size_t RADIX_TOPK_DETERMINISTIC_BLOCK_SCAN_SMEM = + sizeof(typename RadixTopKDetBlockScanT::TempStorage); + constexpr size_t RADIX_TOPK_LAUNCH_SMEM_HEADROOM = 2 * RADIX_TOPK_DETERMINISTIC_BLOCK_SCAN_SMEM; + const size_t launch_headroom = + reserve_launch_headroom ? RADIX_TOPK_LAUNCH_SMEM_HEADROOM : size_t(0); + if (max_smem_per_block <= fixed_smem_aligned + launch_headroom) { + return 0; + } + // Reserve enough launch-time headroom for deterministic radix kernels that + // instantiate additional static shared scratch such as BlockScan temp storage. + return max_smem_per_block - fixed_smem_aligned - launch_headroom; +} + // ============================================================================ // RadixTopK Type Traits - supports float, half, and bfloat16 // OrderedType: uint32_t for float, uint16_t for half/bf16 @@ -182,7 +203,90 @@ struct RadixRowState { float sum_topk; // For RenormProb: sum of top-k elements }; +constexpr uint32_t RADIX_TOPK_MAX_DETERMINISTIC_CTAS_PER_GROUP = 256; + +struct RadixDeterministicCollectScratch { + uint32_t gt_count[RADIX_TOPK_MAX_DETERMINISTIC_CTAS_PER_GROUP]; + uint32_t eq_count[RADIX_TOPK_MAX_DETERMINISTIC_CTAS_PER_GROUP]; +}; + +inline RadixDeterministicCollectScratch* MaybeGetRadixDeterministicCollectScratchBuffer( + RadixRowState* row_states_buffer, uint32_t num_groups, bool single_cta, bool deterministic) { + return (single_cta || !deterministic || row_states_buffer == nullptr) + ? nullptr + : reinterpret_cast(row_states_buffer + num_groups); +} + // ==================== Common Device Functions for Radix Top-K ==================== +/*! + * \brief Software barrier across all CTAs in the same radix group. + * + * Each CTA contributes exactly one arrival via tx==0, then waits until the + * group-wide arrival counter reaches the current phase target. + * + * \param state Per-group radix row state that owns the arrival counter + * \param barrier_phase Current software-barrier phase for this CTA group + * \param ctas_per_group Number of CTAs participating in the group barrier + * \param tx Thread index within the block + */ +__device__ __forceinline__ void AdvanceRadixGroupBarrier(RadixRowState* state, int& barrier_phase, + uint32_t ctas_per_group, uint32_t tx) { + if (tx == 0) { + red_release(&state->arrival_counter, 1); + } + int target = (barrier_phase + 1) * ctas_per_group; + wait_ge(&state->arrival_counter, target, tx); + barrier_phase++; + __syncthreads(); +} + +/*! + * \brief Deterministically collect thread-strided matches with a full CTA scan. + * + * Threads traverse indices in the fixed order `tx, tx + BLOCK_THREADS, ...`, compute + * per-thread match counts over the full strided chain, exclusive-scan those counts across + * the CTA, then emit matches in that same deterministic thread-strided order. + * + * \tparam BLOCK_THREADS Number of threads in the CTA + * \param tx Thread index within the CTA + * \param length Number of elements to scan + * \param scan_temp_storage CUB BlockScan temp storage reused by the caller + * \param is_selected Predicate over the thread-strided index + * \param emit_limit Maximum number of selected elements to emit + * \param emit_selected Callback invoked as emit_selected(index, local_pos) + */ +template +__device__ __forceinline__ void DeterministicThreadStridedCollect(uint32_t tx, uint32_t length, + TempStorage& scan_temp_storage, + Predicate is_selected, + uint32_t emit_limit, + EmitFn emit_selected) { + using BlockScan = cub::BlockScan; + + uint32_t thread_local_selected_count = 0; + for (uint32_t i = tx; i < length; i += BLOCK_THREADS) { + thread_local_selected_count += static_cast(is_selected(i)); + } + + uint32_t thread_local_selected_prefix = 0; + BlockScan(scan_temp_storage) + .ExclusiveSum(thread_local_selected_count, thread_local_selected_prefix); + + if (thread_local_selected_count > 0 && thread_local_selected_prefix < emit_limit) { + uint32_t thread_local_emit_pos = thread_local_selected_prefix; + const uint32_t thread_local_emit_end = + min(thread_local_selected_prefix + thread_local_selected_count, emit_limit); + for (uint32_t i = tx; i < length; i += BLOCK_THREADS) { + if (is_selected(i)) { + emit_selected(i, thread_local_emit_pos); + if (++thread_local_emit_pos == thread_local_emit_end) { + break; + } + } + } + } + __syncthreads(); +} /*! * \brief Compute suffix sum in shared memory using parallel reduction. @@ -363,13 +467,7 @@ __device__ __forceinline__ void RadixSelectOneRound( } // Barrier: wait for all CTAs to finish atomicAdd and clearing - if (tx == 0) { - red_release(&state->arrival_counter, 1); - } - int target = (barrier_phase + 1) * ctas_per_group; - wait_ge(&state->arrival_counter, target, tx); - barrier_phase++; - __syncthreads(); + AdvanceRadixGroupBarrier(state, barrier_phase, ctas_per_group, tx); // Read current histogram (after barrier, all atomicAdds are complete) for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { @@ -462,12 +560,12 @@ __device__ __forceinline__ void LoadToSharedOrdered(const DType* input, * \param iter Current iteration (for triple-buffer indexing) * \return The pivot value in ordered representation */ -template +template __device__ __forceinline__ OrderedType RadixSelectFromSharedMemory( const OrderedType* shared_ordered, uint32_t actual_chunk_size, uint32_t k, uint32_t* local_histogram, uint32_t* suffix_sum, uint32_t* shared_scalars, RadixRowState* state, int& barrier_phase, uint32_t ctas_per_group, uint32_t cta_in_group, uint32_t tx, uint32_t iter, - uint32_t& out_local_gt_count) { + uint32_t& out_local_gt_count, uint32_t& out_local_eq_count) { constexpr uint32_t RADIX = 256; constexpr uint32_t RADIX_BITS = 8; constexpr uint32_t ORDERED_BITS = sizeof(OrderedType) * 8; @@ -492,13 +590,7 @@ __device__ __forceinline__ OrderedType RadixSelectFromSharedMemory( // Initial barrier (skip for single CTA) if constexpr (!SINGLE_CTA) { - if (tx == 0) { - red_release(&state->arrival_counter, 1); - } - int target = (barrier_phase + 1) * ctas_per_group; - wait_ge(&state->arrival_counter, target, tx); - barrier_phase++; - __syncthreads(); + AdvanceRadixGroupBarrier(state, barrier_phase, ctas_per_group, tx); // CTA 0 clears output counter AFTER barrier if (cta_in_group == 0 && tx == 0) { @@ -554,13 +646,7 @@ __device__ __forceinline__ OrderedType RadixSelectFromSharedMemory( next_hist[i] = 0; } } - if (tx == 0) { - red_release(&state->arrival_counter, 1); - } - int target = (barrier_phase + 1) * ctas_per_group; - wait_ge(&state->arrival_counter, target, tx); - barrier_phase++; - __syncthreads(); + AdvanceRadixGroupBarrier(state, barrier_phase, ctas_per_group, tx); for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { suffix_sum[i] = current_hist[i]; @@ -602,25 +688,38 @@ __device__ __forceinline__ OrderedType RadixSelectFromSharedMemory( OrderedType ordered_pivot = static_cast(prefix_cache); - // Count > pivot elements by scanning shared_ordered + // Count > pivot (and optionally == pivot) elements by scanning shared_ordered. // This is needed because suffix_sum only tracks elements matching the current prefix, // not all elements > pivot (which includes elements with higher-order bits > pivot) if (tx == 0) { suffix_sum[0] = 0; + if constexpr (TRACK_EQ_COUNT) { + suffix_sum[1] = 0; + } } __syncthreads(); uint32_t my_gt_count = 0; + uint32_t my_eq_count = 0; #pragma unroll 2 for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) { - if (shared_ordered[i] > ordered_pivot) { + const OrderedType ordered = shared_ordered[i]; + if (ordered > ordered_pivot) { my_gt_count++; } + if constexpr (TRACK_EQ_COUNT) { + if (ordered == ordered_pivot) { + my_eq_count++; + } + } } // Warp-level reduction for (int offset = 16; offset > 0; offset /= 2) { my_gt_count += __shfl_down_sync(0xffffffff, my_gt_count, offset); + if constexpr (TRACK_EQ_COUNT) { + my_eq_count += __shfl_down_sync(0xffffffff, my_eq_count, offset); + } } // First thread of each warp atomics to shared @@ -628,9 +727,19 @@ __device__ __forceinline__ OrderedType RadixSelectFromSharedMemory( if (lane == 0 && my_gt_count > 0) { atomicAdd(&suffix_sum[0], my_gt_count); } + if constexpr (TRACK_EQ_COUNT) { + if (lane == 0 && my_eq_count > 0) { + atomicAdd(&suffix_sum[1], my_eq_count); + } + } __syncthreads(); out_local_gt_count = suffix_sum[0]; + if constexpr (TRACK_EQ_COUNT) { + out_local_eq_count = suffix_sum[1]; + } else { + out_local_eq_count = 0; + } #undef prefix_cache #undef remaining_k_cache @@ -642,53 +751,30 @@ __device__ __forceinline__ OrderedType RadixSelectFromSharedMemory( } /*! - * \brief Find the k-th largest element pivot using radix select. - * - * This is the main entry point for the radix select algorithm. - * It performs NUM_ROUNDS of radix select to find the exact pivot value. - * - * \tparam BLOCK_THREADS Number of threads per block - * \tparam VEC_SIZE Vector size for memory access - * \tparam SINGLE_CTA True if single-CTA mode - * \tparam DType Data type (float, half, nv_bfloat16) + * \brief Load one CTA chunk into ordered shared memory, then find the pivot with radix select. * - * \param input Input data pointer (for this row) - * \param shared_ordered Shared memory for ordered values - * \param local_histogram Shared memory for local histogram - * \param suffix_sum Shared memory for suffix sum - * \param shared_scalars Shared memory for temporary scalar values (size >= 5) - * \param state RadixRowState pointer (nullptr if SINGLE_CTA) - * \param chunk_start Start index in vocab for this CTA - * \param actual_chunk_size Number of elements in this chunk - * \param k Number of top elements to select - * \param barrier_phase Reference to barrier phase counter - * \param ctas_per_group Number of CTAs per group - * \param cta_in_group CTA index within group - * \param tx Thread index - * \param iter Current iteration (for triple-buffer indexing) - * \return The pivot value (k-th largest element) + * This helper centralizes the shared-memory load and the exact k-th-element radix + * select. It returns the pivot in ordered representation. Callers can optionally request the + * CTA-local counts of elements + * `> pivot` and `== pivot`, which are needed by deterministic collect paths. */ -template -__device__ __forceinline__ DType RadixSelectFindPivot( +template +__device__ __forceinline__ typename RadixTopKTraits::OrderedType RadixSelectFindPivot( const DType* input, typename RadixTopKTraits::OrderedType* shared_ordered, uint32_t* local_histogram, uint32_t* suffix_sum, uint32_t* shared_scalars, RadixRowState* state, uint32_t chunk_start, uint32_t actual_chunk_size, uint32_t k, int& barrier_phase, - uint32_t ctas_per_group, uint32_t cta_in_group, uint32_t tx, uint32_t iter = 0) { + uint32_t ctas_per_group, uint32_t cta_in_group, uint32_t tx, uint32_t iter, + uint32_t& out_local_gt_count, uint32_t& out_local_eq_count) { using Traits = RadixTopKTraits; using OrderedType = typename Traits::OrderedType; - // Stage 1: Load and convert to ordered representation LoadToSharedOrdered(input, shared_ordered, chunk_start, actual_chunk_size, tx); - - // Stage 2: Radix select to find pivot - uint32_t local_gt_count = 0; // Not used in this function - OrderedType ordered_pivot = RadixSelectFromSharedMemory( + return RadixSelectFromSharedMemory( shared_ordered, actual_chunk_size, k, local_histogram, suffix_sum, shared_scalars, state, - barrier_phase, ctas_per_group, cta_in_group, tx, iter, local_gt_count); - - // Convert ordered representation back to DType pivot - return Traits::FromOrdered(ordered_pivot); + barrier_phase, ctas_per_group, cta_in_group, tx, iter, out_local_gt_count, + out_local_eq_count); } /*! @@ -764,14 +850,10 @@ __device__ __forceinline__ void RadixCollectIndices( // This is critical: without this barrier, CTAs may write == pivot elements while // other CTAs are still writing > pivot elements, causing incorrect positions. if constexpr (!SINGLE_CTA) { - if (tx == 0) { - red_release(&state->arrival_counter, 1); - } - int target = (barrier_phase + 1) * ctas_per_group; - wait_ge(&state->arrival_counter, target, tx); - barrier_phase++; + AdvanceRadixGroupBarrier(state, barrier_phase, ctas_per_group, tx); + } else { + __syncthreads(); } - __syncthreads(); // Pass 2: Write elements == pivot // Use global atomic directly since we need cross-CTA coordination to respect @@ -796,6 +878,178 @@ __device__ __forceinline__ void RadixCollectIndices( #undef global_base_gt } +struct DeterministicCollectCountPair { + uint32_t gt; + uint32_t eq; +}; + +struct DeterministicCollectCountPairSum { + __device__ __forceinline__ DeterministicCollectCountPair operator()( + const DeterministicCollectCountPair& lhs, const DeterministicCollectCountPair& rhs) const { + return {lhs.gt + rhs.gt, lhs.eq + rhs.eq}; + } +}; + +/*! + * \brief Collect top-k indices with deterministic cross-CTA ordering. + * + * This variant preserves repeatable output by replacing cross-CTA atomic tie + * claiming with a fixed allocation scheme: + * - All > pivot elements are assigned output ranges in CTA order. + * - == pivot elements are then assigned deterministic prefixes from + * per-CTA gt/eq counts stored in \p det_scratch. + * + * Single-CTA mode degenerates to a block-local deterministic collect without + * using \p det_scratch. + * + * \tparam BLOCK_THREADS Number of threads per block + * \tparam SINGLE_CTA True if single-CTA mode + * \tparam OrderedType The ordered integer type + * \tparam OutputFunc Functor type: void(uint32_t original_idx, OrderedType ordered_val, int + * output_pos) + * + * \param shared_ordered Shared memory containing ordered values + * \param actual_chunk_size Number of elements in this CTA's chunk + * \param chunk_start Start index in input for this chunk + * \param k Number of top elements to select + * \param ordered_pivot The pivot value in ordered representation + * \param cta_local_gt_count Number of > pivot elements in this CTA (from radix select) + * \param cta_local_eq_count Number of == pivot elements in this CTA (from radix select) + * \param local_histogram Shared memory scratch reused for deterministic collect state + * \param state RadixRowState pointer for multi-CTA sync (nullptr if SINGLE_CTA) + * \param det_scratch Per-group scratch for multi-CTA gt/eq counts (nullptr if SINGLE_CTA) + * \param barrier_phase Reference to barrier phase counter + * \param ctas_per_group Number of CTAs per group + * \param cta_in_group CTA index within the current group + * \param tx Thread index + * \param output_func Functor called as output_func(original_idx, ordered_val, output_pos) for each + * selected element + */ +template +__device__ __forceinline__ void RadixCollectIndicesDeterministic( + const OrderedType* shared_ordered, uint32_t actual_chunk_size, uint32_t chunk_start, uint32_t k, + OrderedType ordered_pivot, uint32_t cta_local_gt_count, uint32_t cta_local_eq_count, + uint32_t* local_histogram, RadixRowState* state, RadixDeterministicCollectScratch* det_scratch, + int& barrier_phase, uint32_t ctas_per_group, uint32_t cta_in_group, uint32_t tx, + OutputFunc output_func) { +// Use local_histogram for counters: +// [0]: s_cta_local_gt_prefix - total >pivot count from earlier CTAs +// [1]: s_cta_local_eq_prefix - total ==pivot count from earlier CTAs +// [2]: s_row_total_gt_count - row-wide >pivot count across all CTAs +// [3]: s_row_eq_needed - number of ==pivot entries still needed after >pivot writes +// [4]: s_cta_local_eq_take - this CTA's assigned ==pivot quota +#define s_cta_local_gt_prefix local_histogram[0] +#define s_cta_local_eq_prefix local_histogram[1] +#define s_row_total_gt_count local_histogram[2] +#define s_row_eq_needed local_histogram[3] +#define s_cta_local_eq_take local_histogram[4] + uint32_t cta_local_eq_emit_limit = 0; + uint32_t cta_local_eq_output_base = 0; + if constexpr (SINGLE_CTA) { + if (tx == 0) { + s_cta_local_gt_prefix = 0; + s_cta_local_eq_prefix = 0; + s_row_total_gt_count = cta_local_gt_count; + s_row_eq_needed = (k > cta_local_gt_count) ? (k - cta_local_gt_count) : 0; + s_cta_local_eq_take = 0; + } + __syncthreads(); + // Single-CTA: keep the full ==pivot suffix contiguous after all >pivot entries. + cta_local_eq_emit_limit = s_row_eq_needed; + cta_local_eq_output_base = s_row_total_gt_count; + } else { + // Each CTA writes its local >pivot / ==pivot counts + if (tx == 0) { + s_cta_local_eq_prefix = 0; + s_cta_local_eq_take = 0; + det_scratch->gt_count[cta_in_group] = cta_local_gt_count; + det_scratch->eq_count[cta_in_group] = cta_local_eq_count; + } + AdvanceRadixGroupBarrier(state, barrier_phase, ctas_per_group, tx); + // Each CTA reads all >pivot / ==pivot counts + if (tx == 0) { + uint32_t cta_local_gt_prefix_accum = 0; + uint32_t row_total_gt = 0; + uint32_t cta_local_eq_prefix_accum = 0; + for (uint32_t c = 0; c < ctas_per_group; ++c) { + const uint32_t c_gt = det_scratch->gt_count[c]; + const uint32_t c_eq = det_scratch->eq_count[c]; + if (c < cta_in_group) { + cta_local_gt_prefix_accum += c_gt; + cta_local_eq_prefix_accum += c_eq; + } + row_total_gt += c_gt; + } + s_cta_local_gt_prefix = cta_local_gt_prefix_accum; + s_row_total_gt_count = row_total_gt; + s_row_eq_needed = (k > row_total_gt) ? (k - row_total_gt) : 0; + s_cta_local_eq_prefix = cta_local_eq_prefix_accum; + s_cta_local_eq_take = 0; + if (s_row_eq_needed > cta_local_eq_prefix_accum) { + s_cta_local_eq_take = min(cta_local_eq_count, s_row_eq_needed - cta_local_eq_prefix_accum); + } + } + __syncthreads(); + // Multi-CTA: only emit this CTA's assigned ==pivot quota at its deterministic output base. + cta_local_eq_emit_limit = s_cta_local_eq_take; + cta_local_eq_output_base = s_row_total_gt_count + s_cta_local_eq_prefix; + } + const uint32_t cta_local_gt_output_base = s_cta_local_gt_prefix; + const uint32_t cta_local_gt_emit_limit = + (k > cta_local_gt_output_base) ? (k - cta_local_gt_output_base) : 0; + +#undef s_cta_local_gt_prefix +#undef s_cta_local_eq_prefix +#undef s_row_total_gt_count +#undef s_row_eq_needed +#undef s_cta_local_eq_take + + using ScalarBlockScan = cub::BlockScan; + using PairBlockScan = + cub::BlockScan; + union DeterministicCollectScanTempStorage { + typename ScalarBlockScan::TempStorage scalar; + typename PairBlockScan::TempStorage pair; + }; + __shared__ DeterministicCollectScanTempStorage scan_temp_storage; + + if (cta_local_eq_emit_limit == 0) { // gt-only collect + DeterministicThreadStridedCollect( + tx, actual_chunk_size, scan_temp_storage.scalar, + [&](uint32_t i) { return shared_ordered[i] > ordered_pivot; }, cta_local_gt_emit_limit, + [&](uint32_t i, uint32_t local_pos) { + output_func(chunk_start + i, shared_ordered[i], cta_local_gt_output_base + local_pos); + }); + return; + } + + // Collect gt and eq elements + DeterministicCollectCountPair thread_local_counts = {0, 0}; + for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) { + const OrderedType ordered = shared_ordered[i]; + thread_local_counts.gt += static_cast(ordered > ordered_pivot); + thread_local_counts.eq += static_cast(ordered == ordered_pivot); + } + + DeterministicCollectCountPair thread_local_prefix = {0, 0}; + PairBlockScan(scan_temp_storage.pair) + .ExclusiveScan(thread_local_counts, thread_local_prefix, DeterministicCollectCountPair{0, 0}, + DeterministicCollectCountPairSum{}); + + DeterministicCollectCountPair thread_local_pos = thread_local_prefix; + for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) { + const OrderedType ordered = shared_ordered[i]; + if (ordered > ordered_pivot && thread_local_pos.gt < cta_local_gt_emit_limit) { + output_func(chunk_start + i, ordered, cta_local_gt_output_base + thread_local_pos.gt); + ++thread_local_pos.gt; + } else if (ordered == ordered_pivot && thread_local_pos.eq < cta_local_eq_emit_limit) { + output_func(chunk_start + i, ordered, cta_local_eq_output_base + thread_local_pos.eq); + ++thread_local_pos.eq; + } + } + __syncthreads(); +} + // ==================== Unified Radix Top-K Kernel with Epilogue Modes ==================== /*! @@ -818,12 +1072,13 @@ enum class RadixTopKMode { * \tparam BLOCK_THREADS Number of threads per block * \tparam VEC_SIZE Vector size for memory access * \tparam SINGLE_CTA True if single-CTA mode + * \tparam DETERMINISTIC True to use deterministic collect path * \tparam MODE Epilogue mode (Basic, PageTableTransform, or RaggedTransform) * \tparam DType Data type (float, half, nv_bfloat16) * \tparam IdType Index type */ -template +template __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified( DType* input, // [num_rows, stride] IdType* output_indices, // [num_rows, top_k] - indices or page table entries @@ -834,10 +1089,9 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified( const IdType* row_to_batch, // [num_rows] batch mapping for PageTable, nullptr otherwise int64_t aux_stride, // src_page_table stride for PageTable mode, 0 otherwise uint32_t top_k_val, uint32_t stride, uint32_t num_rows, RadixRowState* row_states, - uint32_t chunk_size, uint32_t ctas_per_group) { + RadixDeterministicCollectScratch* det_scratches, uint32_t chunk_size, uint32_t ctas_per_group) { using Traits = RadixTopKTraits; using OrderedType = typename Traits::OrderedType; - constexpr uint32_t RADIX = 256; const uint32_t global_cta_id = blockIdx.x; @@ -862,7 +1116,10 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified( if constexpr (!SINGLE_CTA) { state = &row_states[group_id]; } - + RadixDeterministicCollectScratch* det_scratch = nullptr; + if constexpr (!SINGLE_CTA && DETERMINISTIC) { + det_scratch = &det_scratches[group_id]; + } uint32_t num_groups = gridDim.x / ctas_per_group; uint32_t total_iterations = (num_rows + num_groups - 1) / num_groups; @@ -955,37 +1212,44 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified( const uint32_t chunk_end = min(chunk_start + chunk_size, length); const uint32_t actual_chunk_size = ((chunk_start < length) ? (chunk_end - chunk_start) : 0); - // Stage 1: Load and convert to ordered representation - LoadToSharedOrdered( - input + row_idx * stride, shared_ordered, chunk_start, actual_chunk_size, tx); - - // Stage 2: Radix select to find k-th largest element (also computes local_gt_count) - uint32_t local_gt_count = 0; - OrderedType ordered_pivot = RadixSelectFromSharedMemory( - shared_ordered, actual_chunk_size, k, local_histogram, suffix_sum, shared_scalars, state, - barrier_phase, ctas_per_group, cta_in_group, tx, iter, local_gt_count); + // Stage 1: Load the chunk into shared memory, then radix-select the pivot. + uint32_t cta_local_gt_count = 0; + uint32_t cta_local_eq_count = 0; + OrderedType ordered_pivot = + RadixSelectFindPivot( + input + row_idx * stride, shared_ordered, local_histogram, suffix_sum, shared_scalars, + state, chunk_start, actual_chunk_size, k, barrier_phase, ctas_per_group, cta_in_group, + tx, iter, cta_local_gt_count, cta_local_eq_count); + + auto collect_indices = [&](auto&& output_func) { + if constexpr (DETERMINISTIC) { + RadixCollectIndicesDeterministic( + shared_ordered, actual_chunk_size, chunk_start, k, ordered_pivot, cta_local_gt_count, + cta_local_eq_count, local_histogram, state, det_scratch, barrier_phase, ctas_per_group, + cta_in_group, tx, output_func); + } else { + RadixCollectIndices( + shared_ordered, actual_chunk_size, chunk_start, k, ordered_pivot, cta_local_gt_count, + local_histogram, &shared_output_counter, state, barrier_phase, ctas_per_group, tx, + output_func); + } + }; - // Stage 3: Collect indices with mode-specific epilogue (single pass) + // Stage 2: Collect indices with mode-specific epilogue (single pass) if constexpr (MODE == RadixTopKMode::Basic) { DType* row_output_values = output_values + row_idx * top_k_val; - RadixCollectIndices( - shared_ordered, actual_chunk_size, chunk_start, k, ordered_pivot, local_gt_count, - local_histogram, &shared_output_counter, state, barrier_phase, ctas_per_group, tx, - [&](uint32_t original_idx, OrderedType ordered_val, int pos) { - row_output[pos] = static_cast(original_idx); - row_output_values[pos] = Traits::FromOrdered(ordered_val); - }); + collect_indices([&](uint32_t original_idx, OrderedType ordered_val, int pos) { + row_output[pos] = static_cast(original_idx); + row_output_values[pos] = Traits::FromOrdered(ordered_val); + }); } else if constexpr (MODE == RadixTopKMode::PageTableTransform) { uint32_t batch_idx = (row_to_batch != nullptr) ? row_to_batch[row_idx] : row_idx; const IdType* src_page_entry = aux_data + batch_idx * aux_stride; // Collect raw indices first - RadixCollectIndices( - shared_ordered, actual_chunk_size, chunk_start, k, ordered_pivot, local_gt_count, - local_histogram, &shared_output_counter, state, barrier_phase, ctas_per_group, tx, - [&](uint32_t original_idx, OrderedType /*ordered_val*/, int pos) { - row_output[pos] = static_cast(original_idx); - }); + collect_indices([&](uint32_t original_idx, OrderedType /*ordered_val*/, int pos) { + row_output[pos] = static_cast(original_idx); + }); if constexpr (SINGLE_CTA) { __syncthreads(); @@ -996,13 +1260,7 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified( } } else { // Barrier to ensure all CTAs finished writing indices - if (tx == 0) { - red_release(&state->arrival_counter, 1); - } - int target = (barrier_phase + 1) * ctas_per_group; - wait_ge(&state->arrival_counter, target, tx); - barrier_phase++; - __syncthreads(); + AdvanceRadixGroupBarrier(state, barrier_phase, ctas_per_group, tx); // All CTAs participate in page table transform (coalesced access) uint32_t elems_per_cta = (k + ctas_per_group - 1) / ctas_per_group; @@ -1015,16 +1273,13 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified( } } else { // RaggedTransform IdType offset = aux_data[row_idx]; - RadixCollectIndices( - shared_ordered, actual_chunk_size, chunk_start, k, ordered_pivot, local_gt_count, - local_histogram, &shared_output_counter, state, barrier_phase, ctas_per_group, tx, - [&](uint32_t original_idx, OrderedType /*ordered_val*/, int pos) { - row_output[pos] = static_cast(original_idx) + offset; - }); + collect_indices([&](uint32_t original_idx, OrderedType /*ordered_val*/, int pos) { + row_output[pos] = static_cast(original_idx) + offset; + }); } } - // Clear histogram buffers and reset arrival counter + // Clear histogram buffers and reset arrival counter for next kernel launch (only for multi-CTA) if constexpr (!SINGLE_CTA) { if (cta_in_group == 0) { for (uint32_t buf = 0; buf < 3; ++buf) { @@ -1032,6 +1287,14 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified( state->histogram[buf][i] = 0; } } + if constexpr (DETERMINISTIC) { + static_assert(sizeof(RadixDeterministicCollectScratch) % sizeof(uint32_t) == 0); + uint32_t* det_words = reinterpret_cast(det_scratch); + constexpr uint32_t DET_WORDS = sizeof(RadixDeterministicCollectScratch) / sizeof(uint32_t); + for (uint32_t i = tx; i < DET_WORDS; i += BLOCK_THREADS) { + det_words[i] = 0; + } + } if (tx == 0) { st_release(&state->arrival_counter, 0); } @@ -1133,19 +1396,18 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKMaskLogitsKernel_Multi continue; } - // ========== Stage 1: Load and convert to ordered representation ========== - LoadToSharedOrdered( - logits + row_idx * vocab_size, shared_ordered, chunk_start, actual_chunk_size, tx); - - // ========== Stage 2: Radix select to find pivot ========== + // Stage 1: Load the chunk into shared memory, then radix-select the pivot. uint32_t local_gt_count = 0; // Not used in this kernel - OrderedType ordered_pivot = RadixSelectFromSharedMemory( - shared_ordered, actual_chunk_size, k, local_histogram, suffix_sum, shared_scalars, state, - barrier_phase, ctas_per_group, cta_in_group, tx, iter, local_gt_count); + uint32_t local_eq_count = 0; // Not used in this kernel + OrderedType ordered_pivot = + RadixSelectFindPivot( + logits + row_idx * vocab_size, shared_ordered, local_histogram, suffix_sum, + shared_scalars, state, chunk_start, actual_chunk_size, k, barrier_phase, ctas_per_group, + cta_in_group, tx, iter, local_gt_count, local_eq_count); pivot = Traits::FromOrdered(ordered_pivot); - // ========== Stage 3: Final masking pass ========== + // Stage 2: Final masking pass const DType neg_inf = Traits::NegInf(); const uint32_t aligned_size = (actual_chunk_size / VEC_SIZE) * VEC_SIZE; vec_t logits_vec; @@ -1207,7 +1469,11 @@ cudaError_t RadixTopKMaskLogitsMultiCTA(DType* logits, DType* masked_logits, IdT constexpr size_t fixed_smem_aligned = round_up(fixed_smem_size, 16); // Calculate max chunk size that fits in shared memory - const size_t available_for_ordered = max_smem_per_block - fixed_smem_aligned; + const size_t available_for_ordered = GetRadixTopKAvailableOrderedSmemBytes( + max_smem_per_block, fixed_smem_aligned, false); + if (available_for_ordered == 0) { + return cudaErrorInvalidValue; + } uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType); max_chunk_elements = round_down(max_chunk_elements, vec_size); const uint32_t min_chunk_size = vec_size * BLOCK_THREADS; @@ -1261,7 +1527,7 @@ cudaError_t RadixTopKMaskLogitsMultiCTA(DType* logits, DType* masked_logits, IdT * \brief Multi-CTA Radix Top-K RenormProb kernel with unified single/multi-CTA paths. * * Finds the k-th largest probability, then normalizes all probs >= pivot to sum to 1, - * setting all others to 0. Uses the shared RadixSelectFindPivot function. + * setting all others to 0. Reuses the shared load+radix-select helper. */ template @@ -1277,7 +1543,7 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKRenormProbKernel_Multi using Traits = RadixTopKTraits; using OrderedType = typename Traits::OrderedType; - constexpr uint32_t RADIX = 256; + constexpr uint32_t RADIX = 256; // 8-bit radix const uint32_t global_cta_id = blockIdx.x; const uint32_t group_id = global_cta_id / ctas_per_group; @@ -1361,27 +1627,14 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKRenormProbKernel_Multi } } // Barrier for initialization - if (tx == 0) { - red_release(&state->arrival_counter, 1); - } - int target = (barrier_phase + 1) * ctas_per_group; - wait_ge(&state->arrival_counter, target, tx); - barrier_phase++; - __syncthreads(); + AdvanceRadixGroupBarrier(state, barrier_phase, ctas_per_group, tx); if (tx == 0 && block_sum > 0) { atomicAdd(&state->sum_topk, block_sum); } // Barrier to ensure all CTAs have contributed - if (tx == 0) { - red_release(&state->arrival_counter, 1); - } - target = (barrier_phase + 1) * ctas_per_group; - wait_ge(&state->arrival_counter, target, tx); - barrier_phase++; - __syncthreads(); - + AdvanceRadixGroupBarrier(state, barrier_phase, ctas_per_group, tx); normalizer = math::ptx_rcp(max(state->sum_topk, 1e-8f)); } else { // Single-CTA: use block_sum directly @@ -1423,11 +1676,14 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKRenormProbKernel_Multi continue; } - // ========== Stage 1: Find pivot using RadixSelectFindPivot ========== - pivot = RadixSelectFindPivot( + // ========== Stage 1: Find pivot ========== + uint32_t local_gt_count = 0; // Not used in this kernel + uint32_t local_eq_count = 0; // Not used in this kernel + auto ordered_pivot = RadixSelectFindPivot( probs + row_idx * vocab_size, shared_ordered, local_histogram, suffix_sum, shared_scalars, state, chunk_start, actual_chunk_size, k, barrier_phase, ctas_per_group, cta_in_group, tx, - iter); + iter, local_gt_count, local_eq_count); + pivot = Traits::FromOrdered(ordered_pivot); // ========== Stage 2: Compute sum of elements >= pivot ========== float thread_sum = 0.0f; @@ -1466,27 +1722,14 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKRenormProbKernel_Multi } } // Barrier for initialization - if (tx == 0) { - red_release(&state->arrival_counter, 1); - } - int target = (barrier_phase + 1) * ctas_per_group; - wait_ge(&state->arrival_counter, target, tx); - barrier_phase++; - __syncthreads(); + AdvanceRadixGroupBarrier(state, barrier_phase, ctas_per_group, tx); if (tx == 0 && block_sum > 0) { atomicAdd(&state->sum_topk, block_sum); } // Barrier to ensure all CTAs have contributed - if (tx == 0) { - red_release(&state->arrival_counter, 1); - } - target = (barrier_phase + 1) * ctas_per_group; - wait_ge(&state->arrival_counter, target, tx); - barrier_phase++; - __syncthreads(); - + AdvanceRadixGroupBarrier(state, barrier_phase, ctas_per_group, tx); normalizer = math::ptx_rcp(max(state->sum_topk, 1e-8f)); } else { // Single-CTA: use block_sum directly @@ -1555,7 +1798,11 @@ cudaError_t RadixTopKRenormProbMultiCTA(DType* probs, DType* renormed_prob, IdTy constexpr size_t fixed_smem_aligned = round_up(fixed_smem_size, 16); // Calculate max chunk size that fits in shared memory - const size_t available_for_ordered = max_smem_per_block - fixed_smem_aligned; + const size_t available_for_ordered = GetRadixTopKAvailableOrderedSmemBytes( + max_smem_per_block, fixed_smem_aligned, false); + if (available_for_ordered == 0) { + return cudaErrorInvalidValue; + } uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType); max_chunk_elements = round_down(max_chunk_elements, vec_size); const uint32_t min_chunk_size = vec_size * BLOCK_THREADS; @@ -1627,7 +1874,7 @@ cudaError_t RadixTopKPageTableTransformMultiCTA(DType* input, IdType* output_pag const IdType* row_to_batch, IdType* lengths, uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, RadixRowState* row_states_buffer, - cudaStream_t stream = 0) { + bool deterministic, cudaStream_t stream = 0) { using OrderedType = typename RadixTopKTraits::OrderedType; constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(DType), max_len); @@ -1642,14 +1889,21 @@ cudaError_t RadixTopKPageTableTransformMultiCTA(DType* input, IdType* output_pag constexpr size_t fixed_smem_size = sizeof(uint32_t) * (256 + 256 + 5); constexpr size_t fixed_smem_aligned = round_up(fixed_smem_size, 16); + const size_t available_for_ordered = GetRadixTopKAvailableOrderedSmemBytes( + max_smem_per_block, fixed_smem_aligned, deterministic); + if (available_for_ordered == 0) { + return cudaErrorInvalidValue; + } - const size_t available_for_ordered = max_smem_per_block - fixed_smem_aligned; uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType); max_chunk_elements = round_down(max_chunk_elements, vec_size); const uint32_t min_chunk_size = vec_size * BLOCK_THREADS; max_chunk_elements = std::max(max_chunk_elements, min_chunk_size); uint32_t ctas_per_group = ceil_div(max_len, max_chunk_elements); + if (deterministic && ctas_per_group > RADIX_TOPK_MAX_DETERMINISTIC_CTAS_PER_GROUP) { + return cudaErrorInvalidConfiguration; + } uint32_t chunk_size = ceil_div(max_len, ctas_per_group); chunk_size = round_up(chunk_size, vec_size); chunk_size = std::min(chunk_size, max_chunk_elements); @@ -1660,38 +1914,46 @@ cudaError_t RadixTopKPageTableTransformMultiCTA(DType* input, IdType* output_pag uint32_t num_groups = std::min(static_cast(num_sms) / ctas_per_group, num_rows); if (num_groups == 0) num_groups = 1; uint32_t total_ctas = num_groups * ctas_per_group; + RadixDeterministicCollectScratch* det_scratch_buffer = + MaybeGetRadixDeterministicCollectScratchBuffer(row_states_buffer, num_groups, single_cta, + deterministic); // Unified kernel parameters DType* output_values = nullptr; // Not used in PageTableTransform mode + dim3 nblks(total_ctas); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&input, &output_page_table, &output_values, &src_page_table, + &lengths, &row_to_batch, &src_stride, &top_k_val, + &max_len, &num_rows, &row_states_buffer, &det_scratch_buffer, + &chunk_size, &ctas_per_group}; + +#define LAUNCH_PAGE_TABLE_KERNEL(THREADS, SINGLE_CTA_FLAG, DET_FLAG) \ + do { \ + auto kernel = RadixTopKKernel_Unified; \ + FLASHINFER_CUDA_CALL( \ + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); \ + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); \ + } while (0) DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { if (single_cta) { - auto kernel = RadixTopKKernel_Unified; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - dim3 nblks(total_ctas); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&input, &output_page_table, &output_values, &src_page_table, - &lengths, &row_to_batch, &src_stride, &top_k_val, - &max_len, &num_rows, &row_states_buffer, &chunk_size, - &ctas_per_group}; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + if (!deterministic) { + LAUNCH_PAGE_TABLE_KERNEL(BLOCK_THREADS, true, false); + } else { + LAUNCH_PAGE_TABLE_KERNEL(BLOCK_THREADS, true, true); + } } else { - auto kernel = RadixTopKKernel_Unified; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - dim3 nblks(total_ctas); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&input, &output_page_table, &output_values, &src_page_table, - &lengths, &row_to_batch, &src_stride, &top_k_val, - &max_len, &num_rows, &row_states_buffer, &chunk_size, - &ctas_per_group}; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + if (!deterministic) { + LAUNCH_PAGE_TABLE_KERNEL(BLOCK_THREADS, false, false); + } else { + LAUNCH_PAGE_TABLE_KERNEL(BLOCK_THREADS, false, true); + } } }); +#undef LAUNCH_PAGE_TABLE_KERNEL + return cudaSuccess; } @@ -1716,7 +1978,7 @@ cudaError_t RadixTopKRaggedTransformMultiCTA(DType* input, IdType* output_indice const IdType* offsets, IdType* lengths, uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, RadixRowState* row_states_buffer, - cudaStream_t stream = 0) { + bool deterministic, cudaStream_t stream = 0) { using OrderedType = typename RadixTopKTraits::OrderedType; constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(DType), max_len); @@ -1731,14 +1993,21 @@ cudaError_t RadixTopKRaggedTransformMultiCTA(DType* input, IdType* output_indice constexpr size_t fixed_smem_size = sizeof(uint32_t) * (256 + 256 + 5); constexpr size_t fixed_smem_aligned = round_up(fixed_smem_size, 16); + const size_t available_for_ordered = GetRadixTopKAvailableOrderedSmemBytes( + max_smem_per_block, fixed_smem_aligned, deterministic); + if (available_for_ordered == 0) { + return cudaErrorInvalidValue; + } - const size_t available_for_ordered = max_smem_per_block - fixed_smem_aligned; uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType); max_chunk_elements = round_down(max_chunk_elements, vec_size); const uint32_t min_chunk_size = vec_size * BLOCK_THREADS; max_chunk_elements = std::max(max_chunk_elements, min_chunk_size); uint32_t ctas_per_group = ceil_div(max_len, max_chunk_elements); + if (deterministic && ctas_per_group > RADIX_TOPK_MAX_DETERMINISTIC_CTAS_PER_GROUP) { + return cudaErrorInvalidConfiguration; + } uint32_t chunk_size = ceil_div(max_len, ctas_per_group); chunk_size = round_up(chunk_size, vec_size); chunk_size = std::min(chunk_size, max_chunk_elements); @@ -1749,40 +2018,48 @@ cudaError_t RadixTopKRaggedTransformMultiCTA(DType* input, IdType* output_indice uint32_t num_groups = std::min(static_cast(num_sms) / ctas_per_group, num_rows); if (num_groups == 0) num_groups = 1; uint32_t total_ctas = num_groups * ctas_per_group; + RadixDeterministicCollectScratch* det_scratch_buffer = + MaybeGetRadixDeterministicCollectScratchBuffer(row_states_buffer, num_groups, single_cta, + deterministic); // Unified kernel parameters DType* output_values = nullptr; // Not used in RaggedTransform mode const IdType* row_to_batch = nullptr; // Not used in RaggedTransform mode int64_t aux_stride = 0; // Not used in RaggedTransform mode + dim3 nblks(total_ctas); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&input, &output_indices, &output_values, &offsets, + &lengths, &row_to_batch, &aux_stride, &top_k_val, + &max_len, &num_rows, &row_states_buffer, &det_scratch_buffer, + &chunk_size, &ctas_per_group}; + +#define LAUNCH_RAGGED_KERNEL(THREADS, SINGLE_CTA_FLAG, DET_FLAG) \ + do { \ + auto kernel = RadixTopKKernel_Unified; \ + FLASHINFER_CUDA_CALL( \ + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); \ + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); \ + } while (0) DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { if (single_cta) { - auto kernel = RadixTopKKernel_Unified; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - dim3 nblks(total_ctas); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&input, &output_indices, &output_values, &offsets, - &lengths, &row_to_batch, &aux_stride, &top_k_val, - &max_len, &num_rows, &row_states_buffer, &chunk_size, - &ctas_per_group}; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + if (!deterministic) { + LAUNCH_RAGGED_KERNEL(BLOCK_THREADS, true, false); + } else { + LAUNCH_RAGGED_KERNEL(BLOCK_THREADS, true, true); + } } else { - auto kernel = RadixTopKKernel_Unified; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - dim3 nblks(total_ctas); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&input, &output_indices, &output_values, &offsets, - &lengths, &row_to_batch, &aux_stride, &top_k_val, - &max_len, &num_rows, &row_states_buffer, &chunk_size, - &ctas_per_group}; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + if (!deterministic) { + LAUNCH_RAGGED_KERNEL(BLOCK_THREADS, false, false); + } else { + LAUNCH_RAGGED_KERNEL(BLOCK_THREADS, false, true); + } } }); +#undef LAUNCH_RAGGED_KERNEL + return cudaSuccess; } @@ -1803,7 +2080,7 @@ template cudaError_t RadixTopKMultiCTA(DType* input, IdType* output_indices, DType* output_values, IdType* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t vocab_size, RadixRowState* row_states_buffer, - cudaStream_t stream = 0) { + bool deterministic, cudaStream_t stream = 0) { using OrderedType = typename RadixTopKTraits::OrderedType; constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(DType), vocab_size); @@ -1820,14 +2097,21 @@ cudaError_t RadixTopKMultiCTA(DType* input, IdType* output_indices, DType* outpu // Scalars: 5 for single-CTA, 4 for multi-CTA constexpr size_t fixed_smem_size = sizeof(uint32_t) * (256 + 256 + 5); constexpr size_t fixed_smem_aligned = round_up(fixed_smem_size, 16); + const size_t available_for_ordered = GetRadixTopKAvailableOrderedSmemBytes( + max_smem_per_block, fixed_smem_aligned, deterministic); + if (available_for_ordered == 0) { + return cudaErrorInvalidValue; + } - const size_t available_for_ordered = max_smem_per_block - fixed_smem_aligned; uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType); max_chunk_elements = round_down(max_chunk_elements, vec_size); const uint32_t min_chunk_size = vec_size * BLOCK_THREADS; max_chunk_elements = std::max(max_chunk_elements, min_chunk_size); uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements); + if (deterministic && ctas_per_group > RADIX_TOPK_MAX_DETERMINISTIC_CTAS_PER_GROUP) { + return cudaErrorInvalidConfiguration; + } uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group); chunk_size = round_up(chunk_size, vec_size); chunk_size = std::min(chunk_size, max_chunk_elements); @@ -1842,40 +2126,48 @@ cudaError_t RadixTopKMultiCTA(DType* input, IdType* output_indices, DType* outpu uint32_t num_groups = std::min(static_cast(num_sms) / ctas_per_group, batch_size); if (num_groups == 0) num_groups = 1; uint32_t total_ctas = num_groups * ctas_per_group; + RadixDeterministicCollectScratch* det_scratch_buffer = + MaybeGetRadixDeterministicCollectScratchBuffer(row_states_buffer, num_groups, single_cta, + deterministic); // Unified kernel parameters IdType* lengths = nullptr; // Not used in Basic mode const IdType* row_to_batch = nullptr; // Not used in Basic mode int64_t aux_stride = 0; // Not used in Basic mode + dim3 nblks(total_ctas); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&input, &output_indices, &output_values, &top_k_arr, + &lengths, &row_to_batch, &aux_stride, &top_k_val, + &vocab_size, &batch_size, &row_states_buffer, &det_scratch_buffer, + &chunk_size, &ctas_per_group}; + +#define LAUNCH_BASIC_KERNEL(THREADS, SINGLE_CTA_FLAG, DET_FLAG) \ + do { \ + auto kernel = RadixTopKKernel_Unified; \ + FLASHINFER_CUDA_CALL( \ + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); \ + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); \ + } while (0) DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { if (single_cta) { - auto kernel = RadixTopKKernel_Unified; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - dim3 nblks(total_ctas); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&input, &output_indices, &output_values, &top_k_arr, - &lengths, &row_to_batch, &aux_stride, &top_k_val, - &vocab_size, &batch_size, &row_states_buffer, &chunk_size, - &ctas_per_group}; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + if (!deterministic) { + LAUNCH_BASIC_KERNEL(BLOCK_THREADS, true, false); + } else { + LAUNCH_BASIC_KERNEL(BLOCK_THREADS, true, true); + } } else { - auto kernel = RadixTopKKernel_Unified; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - dim3 nblks(total_ctas); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&input, &output_indices, &output_values, &top_k_arr, - &lengths, &row_to_batch, &aux_stride, &top_k_val, - &vocab_size, &batch_size, &row_states_buffer, &chunk_size, - &ctas_per_group}; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + if (!deterministic) { + LAUNCH_BASIC_KERNEL(BLOCK_THREADS, false, false); + } else { + LAUNCH_BASIC_KERNEL(BLOCK_THREADS, false, true); + } } }); +#undef LAUNCH_BASIC_KERNEL + return cudaSuccess; } // ==================== FilteredTopK Implementation ==================== @@ -1972,7 +2264,7 @@ enum class FilteredTopKMode { Plain, PageTable, Ragged }; * - PageTable: output = dst_page_table, aux_input = src_page_table, aux_stride = src_stride * - Ragged: output = indices, aux_input = offsets, aux_output/aux_stride/row_to_batch unused */ -template +template __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) FilteredTopKUnifiedKernel(const DType* __restrict__ input, IdType* __restrict__ output, DType* __restrict__ aux_output, // values for Plain mode @@ -1984,6 +2276,7 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) constexpr uint32_t BLOCK_SIZE = FILTERED_TOPK_BLOCK_THREADS; constexpr int RADIX = 256; constexpr int SMEM_INPUT_SIZE = FILTERED_TOPK_SMEM_INPUT_SIZE; + static_assert(BLOCK_SIZE % 32 == 0, "BLOCK_SIZE must be a multiple of warp size"); const uint32_t bid = blockIdx.x; const int tx = threadIdx.x; @@ -2011,11 +2304,7 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) // Trivial case: length <= top_k if (length <= static_cast(top_k)) { for (int i = tx; i < static_cast(top_k); i += BLOCK_SIZE) { - if constexpr (MODE == FilteredTopKMode::PageTable) { - dst[i] = (i < length) ? src_page_entry[i] : static_cast(-1); - } else if constexpr (MODE == FilteredTopKMode::Ragged) { - dst[i] = (i < length) ? static_cast(i) + offset_val : static_cast(-1); - } else { // Plain + if constexpr (MODE == FilteredTopKMode::Plain) { if (i < length) { dst[i] = static_cast(i); dst_values[i] = score[i]; @@ -2023,6 +2312,13 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) dst[i] = static_cast(-1); dst_values[i] = DType(0); } + } else if constexpr (DETERMINISTIC) { + // In deterministic mode the page-table/ragged transform happens in SortTopKByIndexKernel + dst[i] = (i < length) ? static_cast(i) : static_cast(-1); + } else if constexpr (MODE == FilteredTopKMode::PageTable) { + dst[i] = (i < length) ? src_page_entry[i] : static_cast(-1); + } else { // Ragged + dst[i] = (i < length) ? static_cast(i) + offset_val : static_cast(-1); } } return; @@ -2030,10 +2326,13 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) // Static shared memory alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; - alignas(128) __shared__ int s_counter; - alignas(128) __shared__ int s_threshold_bin_id; - alignas(128) __shared__ int s_num_input[2]; + __shared__ int s_counter; + __shared__ int s_threshold_bin_id; + // Per-round copies of s_threshold_bin_id for deterministic pivot rebuild. + __shared__ int s_refine_thresholds[4]; + __shared__ int s_num_input[2]; alignas(128) __shared__ int s_indices[FILTERED_TOPK_MAX_K]; + // Set 1 when s_input_idx overflows in tie-heavy workload __shared__ int s_refine_overflow; __shared__ int s_last_remain; @@ -2043,13 +2342,20 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; using Traits = FilteredTopKTraits; + using OrderedType = typename Traits::OrderedType; int topk = top_k; if (tx == 0) s_refine_overflow = 0; - - // Stage 1: 8-bit coarse histogram with vectorized loads + if constexpr (DETERMINISTIC) { + if (tx < 4) { + s_refine_thresholds[tx] = 0xFF; + } + } if (tx < RADIX + 1) s_histogram[tx] = 0; __syncthreads(); + // Stage 1: (shared by deterministic and non-deterministic modes) + // build a coarse histogram and identify the threshold bin. + // The modes diverge later when collecting == pivot elements. vec_t score_vec; const int aligned_length = (length / VEC_SIZE) * VEC_SIZE; @@ -2120,6 +2426,22 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) constexpr int NUM_ROUNDS = Traits::NUM_REFINE_ROUNDS; constexpr int FIRST_SHIFT = Traits::FIRST_REFINE_SHIFT; + // fp16/bf16: stop_round = 0; fp32: stop_round = 0,1,2,3 + auto build_det_pivot = [&](int stop_round) -> OrderedType { + if constexpr (sizeof(OrderedType) == 2) { + return static_cast((static_cast(threshold_bin) << 8) | + static_cast(s_refine_thresholds[0])); + } else { // fp32 + uint32_t pivot = 0; + for (int round = 0; round < NUM_ROUNDS; ++round) { + uint32_t byte = + (round <= stop_round) ? static_cast(s_refine_thresholds[round]) : 0xFFu; + pivot |= (byte << (FIRST_SHIFT - round * 8)); + } + return static_cast(pivot); + } + }; + if (topk == 0) { // Collect indices where bin > threshold auto collect_coarse_gt = [&](auto raw_input, int index) { @@ -2136,6 +2458,38 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) if (tx < RADIX + 1) s_histogram[tx] = 0; __syncthreads(); + // Both non-det and det modes use atomicAdd to append >threshold winners here; + // only ==threshold handling diverges between the two modes. + auto collect_gt_and_nondet_eq_threshold = [&](auto value, auto threshold, int idx, + bool collect_eq) { + if (value > threshold) { + const int pos = atomicAdd(&s_counter, 1); + s_indices[pos] = idx; + } else if constexpr (!DETERMINISTIC) { + if (collect_eq && value == threshold) { + const int pos = atomicAdd(&s_last_remain, -1); + if (pos > 0) { + s_indices[static_cast(top_k) - pos] = idx; + } + } + } + }; + + auto collect_det_eq_pivot = [&](OrderedType pivot, int eq_needed) { + if (eq_needed > 0) { + using DetCollectBlockScan = + cub::BlockScan; + __shared__ typename DetCollectBlockScan::TempStorage temp_storage; + DeterministicThreadStridedCollect( + tx, length, temp_storage, + [&](uint32_t idx) { return Traits::ToOrdered(score[idx]) == pivot; }, eq_needed, + [&](uint32_t idx, uint32_t local_pos) { + s_indices[static_cast(top_k) - eq_needed + static_cast(local_pos)] = + static_cast(idx); + }); + } + }; + // Filter + histogram for refinement auto filter_and_add_to_histogram = [&](auto raw_input, int index) { const auto bin = static_cast(Traits::ToCoarseKey(raw_input)); @@ -2167,15 +2521,8 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) const auto idx = s_input_idx[r_idx][i]; const auto raw_input = score[idx]; const auto bin = (Traits::ToOrdered(raw_input) >> offset) & 0xFF; - if (static_cast(bin) > threshold) { - const auto pos = atomicAdd(&s_counter, 1); - s_indices[pos] = idx; - } else if (static_cast(bin) == threshold) { - const auto pos = atomicAdd(&s_last_remain, -1); - if (pos > 0) { - s_indices[top_k - pos] = idx; - } - } + collect_gt_and_nondet_eq_threshold(static_cast(bin), threshold, idx, + /*allow_eq_claim=*/true); } __syncthreads(); }; @@ -2206,6 +2553,8 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) } __syncthreads(); }; + // Returns true if this round fully resolves the pivot, i.e. no ==threshold + // elements need to be carried into another refine round. auto run_refine_round = [&](int r_idx, int offset, auto is_last_round_tag) { constexpr bool IS_LAST_ROUND = decltype(is_last_round_tag)::value; const auto raw_num_input = s_num_input[r_idx]; @@ -2214,8 +2563,12 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) update_refine_threshold(r_idx ^ 1, std::true_type{}); const auto threshold = s_threshold_bin_id; + if constexpr (DETERMINISTIC) { + if (tx == 0) { + s_refine_thresholds[(FIRST_SHIFT - offset) / 8] = threshold; + } + } topk -= s_histogram[threshold + 1]; - if (topk == 0) { // Final round reached: only collect bins strictly greater than threshold. for (int i = tx; i < num_input; i += BLOCK_SIZE) { @@ -2237,7 +2590,7 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) } return false; }; - if constexpr (NUM_ROUNDS == 1) { + if constexpr (NUM_ROUNDS == 1) { // fast path for 1-round refine. if (s_refine_overflow) { if (tx < RADIX + 1) s_histogram[tx] = 0; __syncthreads(); @@ -2273,30 +2626,32 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) return; } const auto sub_bin = Traits::ToOrdered(raw_input) & 0xFF; - if (static_cast(sub_bin) > threshold) { - const auto pos = atomicAdd(&s_counter, 1); - s_indices[pos] = index; - } else if (static_cast(sub_bin) == threshold) { - const auto pos = atomicAdd(&s_last_remain, -1); - if (pos > 0) { - s_indices[top_k - pos] = index; - } - } + collect_gt_and_nondet_eq_threshold(static_cast(sub_bin), threshold, index, + /*allow_eq_claim=*/true); }; for_each_score_full(collect_from_full_threshold_bin); __syncthreads(); + if constexpr (DETERMINISTIC) { + int eq_needed = s_last_remain; + collect_det_eq_pivot(static_cast((static_cast(threshold_bin) << 8) | + static_cast(threshold)), + eq_needed); + } } else { - // fast path for 1-round refine. const int round = 0; const auto r_idx = round % 2; const int offset = FIRST_SHIFT; run_refine_round(r_idx, offset, std::true_type{}); + if constexpr (DETERMINISTIC) { + collect_det_eq_pivot(build_det_pivot(/*stop_round=*/0), topk); + } } } else { // Multi-round refine path (float32): if any refine-buffer overflow is detected, // switch to a correctness-first full rebuild of the threshold-bin selection. // This fallback may be slower than the fast path, but avoids partial-state corruption. + int det_stop_round = NUM_ROUNDS - 1; if (!s_refine_overflow) { #pragma unroll for (int round = 0; round < NUM_ROUNDS; ++round) { @@ -2304,10 +2659,12 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) const int offset = FIRST_SHIFT - round * 8; if (round == NUM_ROUNDS - 1) { if (run_refine_round(r_idx, offset, std::true_type{})) { + det_stop_round = round; break; } } else { if (run_refine_round(r_idx, offset, std::false_type{})) { + det_stop_round = round; break; } } @@ -2316,10 +2673,14 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) } } } + if constexpr (DETERMINISTIC) { + if (!s_refine_overflow) { + collect_det_eq_pivot(build_det_pivot(det_stop_round), topk); + } + } // run_refine_round can set s_refine_overflow during the loop above, so this // check is intentionally separate from the first if (!s_refine_overflow). if (s_refine_overflow) { - using OrderedType = typename Traits::OrderedType; static_assert(sizeof(OrderedType) == 4, "Multi-round overflow fallback expects 32-bit ordered keys."); @@ -2404,26 +2765,21 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) auto collect_by_pivot = [&](auto raw_input, int index) { const auto coarse_bin = static_cast(Traits::ToCoarseKey(raw_input)); if (coarse_bin > threshold_bin) { - const auto pos = atomicAdd(&s_counter, 1); - s_indices[pos] = index; + collect_gt_and_nondet_eq_threshold(coarse_bin, threshold_bin, index, + /*allow_eq_claim=*/false); return; } if (coarse_bin != threshold_bin) { return; } const auto ordered = static_cast(Traits::ToOrdered(raw_input)); - if (ordered > pivot) { - const auto pos = atomicAdd(&s_counter, 1); - s_indices[pos] = index; - } else if (eq_needed > 0 && ordered == pivot) { - const auto pos = atomicAdd(&s_last_remain, -1); - if (pos > 0) { - s_indices[top_k - pos] = index; - } - } + collect_gt_and_nondet_eq_threshold(ordered, pivot, index, eq_needed > 0); }; for_each_score_full(collect_by_pivot); __syncthreads(); + if constexpr (DETERMINISTIC) { + collect_det_eq_pivot(static_cast(pivot), eq_needed); + } } } } @@ -2432,13 +2788,15 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) #pragma unroll 2 for (int base = tx; base < static_cast(top_k); base += BLOCK_SIZE) { const int idx = s_indices[base]; - if constexpr (MODE == FilteredTopKMode::PageTable) { - dst[base] = src_page_entry[idx]; - } else if constexpr (MODE == FilteredTopKMode::Ragged) { - dst[base] = static_cast(idx) + offset_val; - } else { // Plain + if constexpr (MODE == FilteredTopKMode::Plain) { dst[base] = static_cast(idx); dst_values[base] = score[idx]; + } else if constexpr (DETERMINISTIC) { // transform in SortTopKByIndexKernel + dst[base] = static_cast(idx); + } else if constexpr (MODE == FilteredTopKMode::PageTable) { + dst[base] = src_page_entry[idx]; + } else { // Ragged + dst[base] = static_cast(idx) + offset_val; } } } @@ -2463,105 +2821,248 @@ constexpr int ComputeFilteredTopKVecSize(uint32_t max_len) { return static_cast(g); } -// Launch functions with VEC_SIZE dispatch - using unified kernel -template -cudaError_t FilteredTopKPageTableTransform(DType* input, IdType* output_page_table, - const IdType* src_page_table, int64_t src_stride, - const IdType* row_to_batch, IdType* lengths, - uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, - cudaStream_t stream = 0) { - constexpr size_t smem_size = FILTERED_TOPK_SMEM_DYNAMIC; - constexpr int MAX_VEC = 16 / sizeof(DType); +template +struct SortTopKByIndexBlockRadixSort; - dim3 grid(num_rows); - dim3 block(FILTERED_TOPK_BLOCK_THREADS); - DType* aux_output = nullptr; // Not used for PageTable mode - void* args[] = {&input, &output_page_table, &aux_output, &src_page_table, &src_stride, - &row_to_batch, &lengths, &num_rows, &top_k_val, &max_len}; +template +struct SortTopKByIndexBlockRadixSort { + using Type = cub::BlockRadixSort; +}; - const int vec_size = ComputeFilteredTopKVecSize(max_len); +template +struct SortTopKByIndexBlockRadixSort { + using Type = cub::BlockRadixSort; +}; -#define DISPATCH_VEC_SIZE(VS) \ - if (vec_size == VS) { \ - auto kernel = FilteredTopKUnifiedKernel; \ - FLASHINFER_CUDA_CALL( \ - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); \ - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, grid, block, args, smem_size, stream)); \ - return cudaSuccess; \ +template +__global__ void __launch_bounds__(BLOCK_THREADS) + SortTopKByIndexKernel(IdType* output_indices, DType* output_values, const IdType* aux_input, + int64_t aux_stride, const IdType* row_to_batch, uint32_t top_k, + uint32_t max_len) { + constexpr bool WITH_VALUES = (MODE == FilteredTopKMode::Plain); + using BlockRadixSortT = typename SortTopKByIndexBlockRadixSort::Type; + __shared__ typename BlockRadixSortT::TempStorage temp_storage; + + const uint32_t row = blockIdx.x; + const uint32_t tx = threadIdx.x; + IdType* row_output = output_indices + static_cast(row) * top_k; + + uint32_t keys[ITEMS_PER_THREAD]; + DType values[ITEMS_PER_THREAD]; + +#pragma unroll + for (uint32_t i = 0; i < ITEMS_PER_THREAD; ++i) { + uint32_t pos = tx * ITEMS_PER_THREAD + i; + if (pos < top_k) { + IdType idx = row_output[pos]; + keys[i] = (idx >= 0) ? static_cast(idx) : ~0u; + if constexpr (MODE == FilteredTopKMode::Plain) { + values[i] = output_values[static_cast(row) * top_k + pos]; + } + } else { + keys[i] = ~0u; + if constexpr (MODE == FilteredTopKMode::Plain) { + values[i] = DType(0); + } + } } - DISPATCH_VEC_SIZE(1) - DISPATCH_VEC_SIZE(2) - DISPATCH_VEC_SIZE(4) - if constexpr (MAX_VEC >= 8) { - DISPATCH_VEC_SIZE(8) + int end_bit = 32 - __clz(max_len); + if constexpr (MODE == FilteredTopKMode::Plain) { + BlockRadixSortT(temp_storage).Sort(keys, values, 0, end_bit); + } else { + BlockRadixSortT(temp_storage).Sort(keys, 0, end_bit); } -#undef DISPATCH_VEC_SIZE - return cudaSuccess; + const IdType* src_page_entry = nullptr; + IdType offset = 0; + if constexpr (MODE == FilteredTopKMode::PageTable) { + const uint32_t batch_idx = (row_to_batch != nullptr) ? row_to_batch[row] : row; + src_page_entry = aux_input + static_cast(batch_idx) * aux_stride; + } else if constexpr (MODE == FilteredTopKMode::Ragged) { + offset = aux_input[row]; + } + +#pragma unroll + for (uint32_t i = 0; i < ITEMS_PER_THREAD; ++i) { + uint32_t pos = tx * ITEMS_PER_THREAD + i; + if (pos < top_k) { + uint32_t idx = keys[i]; + if constexpr (MODE == FilteredTopKMode::Plain) { + row_output[pos] = static_cast(idx); + output_values[static_cast(row) * top_k + pos] = values[i]; + } else if constexpr (MODE == FilteredTopKMode::PageTable) { + row_output[pos] = (idx != ~0u) ? src_page_entry[idx] : static_cast(-1); + } else { // Ragged + row_output[pos] = + (idx != ~0u) ? static_cast(idx) + offset : static_cast(-1); + } + } + } } -template -cudaError_t FilteredTopKRaggedTransform(DType* input, IdType* output_indices, const IdType* offsets, - IdType* lengths, uint32_t num_rows, uint32_t top_k_val, - uint32_t max_len, cudaStream_t stream = 0) { - constexpr size_t smem_size = FILTERED_TOPK_SMEM_DYNAMIC; - constexpr int MAX_VEC = 16 / sizeof(DType); +template +cudaError_t LaunchSortTopKByIndex(IdType* output_indices, DType* output_values, + const IdType* aux_input, int64_t aux_stride, + const IdType* row_to_batch, uint32_t num_rows, uint32_t top_k_val, + uint32_t max_len, cudaStream_t stream = 0) { + // Block-local sort variants cover at most 256 * 8 = 2048 elements. + if (top_k_val > 2048) { + return cudaErrorInvalidValue; + } + if constexpr (MODE == FilteredTopKMode::Plain) { + if (top_k_val <= 1) { + return cudaSuccess; + } + } + if (top_k_val == 0) { + return cudaSuccess; + } dim3 grid(num_rows); - dim3 block(FILTERED_TOPK_BLOCK_THREADS); - DType* aux_output = nullptr; // Not used for Ragged mode - int64_t aux_stride = 0; // Not used for Ragged mode - const IdType* row_to_batch = nullptr; // Not used for Ragged mode - void* args[] = {&input, &output_indices, &aux_output, &offsets, &aux_stride, - &row_to_batch, &lengths, &num_rows, &top_k_val, &max_len}; - - const int vec_size = ComputeFilteredTopKVecSize(max_len); + void* args[] = {&output_indices, &output_values, &aux_input, &aux_stride, + &row_to_batch, &top_k_val, &max_len}; + auto launch_sort = [&](auto kernel, uint32_t threads) -> cudaError_t { + dim3 block(threads); + return cudaLaunchKernel((void*)kernel, grid, block, args, 0, stream); + }; -#define DISPATCH_VEC_SIZE(VS) \ - if (vec_size == VS) { \ - auto kernel = FilteredTopKUnifiedKernel; \ - FLASHINFER_CUDA_CALL( \ - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); \ - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, grid, block, args, smem_size, stream)); \ - return cudaSuccess; \ + cudaError_t status; + if (top_k_val <= 128) { + status = launch_sort(SortTopKByIndexKernel, 32); + } else if (top_k_val <= 256) { + status = launch_sort(SortTopKByIndexKernel, 32); + } else if (top_k_val <= 512) { + status = launch_sort(SortTopKByIndexKernel, 64); + } else if (top_k_val <= 576) { + status = launch_sort(SortTopKByIndexKernel, 64); + } else if (top_k_val <= 1024) { + status = launch_sort(SortTopKByIndexKernel, 128); + } else { + status = launch_sort(SortTopKByIndexKernel, 256); } + return status; +} - DISPATCH_VEC_SIZE(1) - DISPATCH_VEC_SIZE(2) - DISPATCH_VEC_SIZE(4) - if constexpr (MAX_VEC >= 8) { - DISPATCH_VEC_SIZE(8) +/*! + * \brief CUB stable radix sort: sorts top-k by value descending, carrying indices. + * + * Uses 32-bit flipped ordered value as key and 32-bit index as satellite data. + * Since radix sort is stable, equal values preserve their prior relative order. + * When preceded by an index sort, this yields (value desc, index asc) ordering. + */ +template +__global__ void __launch_bounds__(BLOCK_THREADS) + StableSortTopKByValueKernel(IdType* output_indices, DType* output_values, uint32_t k, + uint32_t /*max_len*/) { + using Traits = RadixTopKTraits; + using OrderedType = typename Traits::OrderedType; + using BlockRadixSortT = cub::BlockRadixSort; + __shared__ typename BlockRadixSortT::TempStorage temp_storage; + + const uint32_t row = blockIdx.x; + const uint32_t tx = threadIdx.x; + + IdType* row_indices = output_indices + static_cast(row) * k; + DType* row_values = output_values + static_cast(row) * k; + + uint32_t keys[ITEMS_PER_THREAD]; + uint32_t indices[ITEMS_PER_THREAD]; + +#pragma unroll + for (uint32_t i = 0; i < ITEMS_PER_THREAD; i++) { + uint32_t pos = tx * ITEMS_PER_THREAD + i; + if (pos < k) { + OrderedType ordered = Traits::ToOrdered(row_values[pos]); + keys[i] = static_cast(static_cast(~ordered)); + indices[i] = static_cast(row_indices[pos]); + } else { + keys[i] = ~0u; + indices[i] = ~0u; + } } -#undef DISPATCH_VEC_SIZE - return cudaSuccess; + constexpr int end_bit = sizeof(OrderedType) * 8; + BlockRadixSortT(temp_storage).Sort(keys, indices, 0, end_bit); + +#pragma unroll + for (uint32_t i = 0; i < ITEMS_PER_THREAD; i++) { + uint32_t pos = tx * ITEMS_PER_THREAD + i; + if (pos < k) { + row_indices[pos] = static_cast(indices[i]); + OrderedType ordered = static_cast(~static_cast(keys[i])); + row_values[pos] = Traits::FromOrdered(ordered); + } + } } template -cudaError_t FilteredTopK(DType* input, IdType* output_indices, DType* output_values, - const IdType* lengths, uint32_t num_rows, uint32_t top_k_val, - uint32_t max_len, cudaStream_t stream = 0) { +cudaError_t StableSortTopKByValue(IdType* output_indices, DType* output_values, uint32_t num_rows, + uint32_t top_k_val, uint32_t max_len, cudaStream_t stream = 0) { + // Block-local sort variants cover at most 256 * 8 = 2048 elements. + if (top_k_val > 2048) { + return cudaErrorInvalidValue; + } + if (top_k_val <= 1) { + return cudaSuccess; + } + + dim3 grid(num_rows); + void* args[] = {&output_indices, &output_values, &top_k_val, &max_len}; + auto launch_sort = [&](auto kernel, uint32_t threads) -> cudaError_t { + dim3 block(threads); + return cudaLaunchKernel((void*)kernel, grid, block, args, 0, stream); + }; + + cudaError_t status; + if (top_k_val <= 128) { + status = launch_sort(StableSortTopKByValueKernel<32, 4, IdType, DType>, 32); + } else if (top_k_val <= 256) { + status = launch_sort(StableSortTopKByValueKernel<32, 8, IdType, DType>, 32); + } else if (top_k_val <= 512) { + status = launch_sort(StableSortTopKByValueKernel<64, 8, IdType, DType>, 64); + } else if (top_k_val <= 576) { + status = launch_sort(StableSortTopKByValueKernel<64, 9, IdType, DType>, 64); + } else if (top_k_val <= 1024) { + status = launch_sort(StableSortTopKByValueKernel<128, 8, IdType, DType>, 128); + } else { + status = launch_sort(StableSortTopKByValueKernel<256, 8, IdType, DType>, 256); + } + return status; +} + +template +cudaError_t LaunchFilteredTopKUnified(DType* input, IdType* output, DType* aux_output, + const IdType* aux_input, int64_t aux_stride, + const IdType* row_to_batch, const IdType* lengths, + uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, + bool deterministic = false, cudaStream_t stream = 0) { constexpr size_t smem_size = FILTERED_TOPK_SMEM_DYNAMIC; constexpr int MAX_VEC = 16 / sizeof(DType); dim3 grid(num_rows); dim3 block(FILTERED_TOPK_BLOCK_THREADS); - const IdType* aux_input = nullptr; // Not used for Plain mode - int64_t aux_stride = 0; // Not used for Plain mode - const IdType* row_to_batch = nullptr; // Not used for Plain mode - void* args[] = {&input, &output_indices, &output_values, &aux_input, &aux_stride, - &row_to_batch, &lengths, &num_rows, &top_k_val, &max_len}; + void* args[] = {&input, &output, &aux_output, &aux_input, &aux_stride, + &row_to_batch, &lengths, &num_rows, &top_k_val, &max_len}; const int vec_size = ComputeFilteredTopKVecSize(max_len); -#define DISPATCH_VEC_SIZE(VS) \ - if (vec_size == VS) { \ - auto kernel = FilteredTopKUnifiedKernel; \ - FLASHINFER_CUDA_CALL( \ - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); \ - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, grid, block, args, smem_size, stream)); \ - return cudaSuccess; \ +#define DISPATCH_VEC_SIZE(VS) \ + if (vec_size == VS) { \ + if (!deterministic) { \ + auto kernel = FilteredTopKUnifiedKernel; \ + FLASHINFER_CUDA_CALL( \ + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); \ + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, grid, block, args, smem_size, stream)); \ + } else { \ + auto kernel = FilteredTopKUnifiedKernel; \ + FLASHINFER_CUDA_CALL( \ + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); \ + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, grid, block, args, smem_size, stream)); \ + } \ + return cudaSuccess; \ } DISPATCH_VEC_SIZE(1) @@ -2575,6 +3076,44 @@ cudaError_t FilteredTopK(DType* input, IdType* output_indices, DType* output_val return cudaSuccess; } +// Launch functions with VEC_SIZE and BLOCK_THREADS dispatch - using unified kernel +template +cudaError_t FilteredTopKPageTableTransform(DType* input, IdType* output_page_table, + const IdType* src_page_table, int64_t src_stride, + const IdType* row_to_batch, IdType* lengths, + uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, + bool deterministic = false, cudaStream_t stream = 0) { + DType* aux_output = nullptr; // Not used for PageTable mode + return LaunchFilteredTopKUnified( + input, output_page_table, aux_output, src_page_table, src_stride, row_to_batch, lengths, + num_rows, top_k_val, max_len, deterministic, stream); +} + +template +cudaError_t FilteredTopKRaggedTransform(DType* input, IdType* output_indices, const IdType* offsets, + IdType* lengths, uint32_t num_rows, uint32_t top_k_val, + uint32_t max_len, bool deterministic = false, + cudaStream_t stream = 0) { + DType* aux_output = nullptr; // Not used for Ragged mode + int64_t aux_stride = 0; // Not used for Ragged mode + const IdType* row_to_batch = nullptr; // Not used for Ragged mode + return LaunchFilteredTopKUnified( + input, output_indices, aux_output, offsets, aux_stride, row_to_batch, lengths, num_rows, + top_k_val, max_len, deterministic, stream); +} + +template +cudaError_t FilteredTopK(DType* input, IdType* output_indices, DType* output_values, + const IdType* lengths, uint32_t num_rows, uint32_t top_k_val, + uint32_t max_len, bool deterministic = false, cudaStream_t stream = 0) { + const IdType* aux_input = nullptr; // Not used for Plain mode + int64_t aux_stride = 0; // Not used for Plain mode + const IdType* row_to_batch = nullptr; // Not used for Plain mode + return LaunchFilteredTopKUnified( + input, output_indices, output_values, aux_input, aux_stride, row_to_batch, lengths, num_rows, + top_k_val, max_len, deterministic, stream); +} + /*! * \brief Check if the GPU supports enough shared memory for FilteredTopK algorithm. * @@ -2612,19 +3151,12 @@ inline TopKAlgoOverride GetTopKAlgoOverride() { * \param num_rows Number of rows (batch size) * \param top_k_val Number of top elements to select * \param max_len Maximum sequence length + * \param deterministic Whether deterministic top-k path is requested * \return true if FilteredTopK should be used, false for Multi-CTA RadixTopK - * - * Heuristics: - * - 16-bit types (fp16/bf16): FilteredTopK for seq <= 16K - * - 32-bit types (fp32): FilteredTopK for seq <= 32K, or larger seq with batch > seq/16K - * - * Note: - * - For tie-heavy long-sequence workloads, FilteredTopK can hit threshold-bin overflow - * in refinement and fall back to slower handling, which may cause visible performance - * degradation. In such cases, users can force multi-CTA via FLASHINFER_TOPK_ALGO=multi_cta. */ template -inline bool ShouldUseFilteredTopK(uint32_t num_rows, uint32_t top_k_val, uint32_t max_len) { +inline bool ShouldUseFilteredTopK(uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, + bool deterministic) { // Check if GPU supports enough shared memory for FilteredTopK const bool gpu_supports_filtered = CanImplementFilteredTopK(); const bool k_fits_filtered = (top_k_val <= FILTERED_TOPK_MAX_K) && (max_len > top_k_val); @@ -2638,12 +3170,24 @@ inline bool ShouldUseFilteredTopK(uint32_t num_rows, uint32_t top_k_val, uint32_ if (algo_override == TopKAlgoOverride::FILTERED) return true; if (algo_override == TopKAlgoOverride::MULTI_CTA) return false; - // Auto heuristics based on dtype + // 16-bit types: simpler threshold + // 32-bit types: more nuanced heuristic + if (deterministic) { + if constexpr (sizeof(DType) <= 2) { + return num_rows > (max_len / 256); + } else { + if (max_len <= 16384) { + return true; + } else { + const uint32_t batch_threshold = std::min(64u, std::max(16u, max_len / 4096)); + return num_rows >= batch_threshold; + } + } + } + if constexpr (sizeof(DType) <= 2) { - // 16-bit types: simpler threshold at 16K return (max_len <= 16384); } else { - // 32-bit types: more nuanced heuristic if (max_len <= 32768) { return true; } else { @@ -2659,42 +3203,69 @@ cudaError_t TopKPageTableTransformDispatch(DType* input, IdType* output_page_tab const IdType* src_page_table, int64_t src_stride, const IdType* row_to_batch, IdType* lengths, uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, - RadixRowState* row_states_buffer, + RadixRowState* row_states_buffer, bool deterministic, cudaStream_t stream = 0) { - if (ShouldUseFilteredTopK(num_rows, top_k_val, max_len)) { - return FilteredTopKPageTableTransform(input, output_page_table, src_page_table, - src_stride, row_to_batch, lengths, - num_rows, top_k_val, max_len, stream); + if (ShouldUseFilteredTopK(num_rows, top_k_val, max_len, deterministic)) { + FLASHINFER_CUDA_CALL((FilteredTopKPageTableTransform( + input, output_page_table, src_page_table, src_stride, row_to_batch, lengths, num_rows, + top_k_val, max_len, deterministic, stream))); + if (deterministic) { + FLASHINFER_CUDA_CALL((LaunchSortTopKByIndex( + output_page_table, static_cast(nullptr), src_page_table, src_stride, + row_to_batch, num_rows, top_k_val, max_len, stream))); + } + return cudaSuccess; } return RadixTopKPageTableTransformMultiCTA( input, output_page_table, src_page_table, src_stride, row_to_batch, lengths, num_rows, - top_k_val, max_len, row_states_buffer, stream); + top_k_val, max_len, row_states_buffer, deterministic, stream); } template cudaError_t TopKRaggedTransformDispatch(DType* input, IdType* output_indices, const IdType* offsets, IdType* lengths, uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, RadixRowState* row_states_buffer, - cudaStream_t stream = 0) { - if (ShouldUseFilteredTopK(num_rows, top_k_val, max_len)) { - return FilteredTopKRaggedTransform(input, output_indices, offsets, lengths, - num_rows, top_k_val, max_len, stream); + bool deterministic, cudaStream_t stream = 0) { + if (ShouldUseFilteredTopK(num_rows, top_k_val, max_len, deterministic)) { + FLASHINFER_CUDA_CALL((FilteredTopKRaggedTransform( + input, output_indices, offsets, lengths, num_rows, top_k_val, max_len, deterministic, + stream))); + if (deterministic) { + FLASHINFER_CUDA_CALL((LaunchSortTopKByIndex( + output_indices, static_cast(nullptr), offsets, 0, nullptr, num_rows, top_k_val, + max_len, stream))); + } + return cudaSuccess; } return RadixTopKRaggedTransformMultiCTA(input, output_indices, offsets, lengths, num_rows, top_k_val, max_len, - row_states_buffer, stream); + row_states_buffer, deterministic, stream); } template cudaError_t TopKDispatch(DType* input, IdType* output_indices, DType* output_values, uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, - RadixRowState* row_states_buffer, cudaStream_t stream = 0) { - if (ShouldUseFilteredTopK(num_rows, top_k_val, max_len)) { - return FilteredTopK(input, output_indices, output_values, nullptr, num_rows, - top_k_val, max_len, stream); + RadixRowState* row_states_buffer, bool sorted_output = false, + bool deterministic = false, cudaStream_t stream = 0) { + if (ShouldUseFilteredTopK(num_rows, top_k_val, max_len, deterministic)) { + FLASHINFER_CUDA_CALL( + (FilteredTopK(input, output_indices, output_values, nullptr, num_rows, + top_k_val, max_len, deterministic, stream))); + if (deterministic) { + FLASHINFER_CUDA_CALL((LaunchSortTopKByIndex( + output_indices, output_values, nullptr, 0, nullptr, num_rows, top_k_val, max_len, + stream))); + } + } else { + FLASHINFER_CUDA_CALL((RadixTopKMultiCTA( + input, output_indices, output_values, nullptr, num_rows, top_k_val, max_len, + row_states_buffer, deterministic, stream))); + } + if (sorted_output) { + FLASHINFER_CUDA_CALL((StableSortTopKByValue( + output_indices, output_values, num_rows, top_k_val, max_len, stream))); } - return RadixTopKMultiCTA(input, output_indices, output_values, nullptr, num_rows, - top_k_val, max_len, row_states_buffer, stream); + return cudaSuccess; } } // namespace sampling diff --git a/tests/utils/test_topk.py b/tests/utils/test_topk.py index 094459cea1..ae622b9e7a 100644 --- a/tests/utils/test_topk.py +++ b/tests/utils/test_topk.py @@ -20,6 +20,7 @@ import torch import flashinfer +import flashinfer.utils as flashinfer_utils from flashinfer.topk import can_implement_filtered_topk from flashinfer.utils import get_compute_capability @@ -77,6 +78,27 @@ def verify_topk_correctness(logits, values, indices, k): return True +def _get_cached_topk_row_states_buffer(device: torch.device): + if device.type == "cuda" and device.index is None: + device = torch.device("cuda", torch.cuda.current_device()) + key = (f"radix_topk_row_states_{device}", device) + return flashinfer_utils._cache_buf.get(key) + + +def _clear_cached_topk_row_states_buffer(device: torch.device): + if device.type == "cuda" and device.index is None: + device = torch.device("cuda", torch.cuda.current_device()) + key = (f"radix_topk_row_states_{device}", device) + flashinfer_utils._cache_buf.pop(key, None) + + +def _build_strictly_descending_logits( + num_rows: int, vocab_size: int, device: torch.device +) -> torch.Tensor: + base = torch.arange(vocab_size, 0, -1, device=device, dtype=torch.float32) + return base.unsqueeze(0).repeat(num_rows, 1).contiguous() + + @pytest.mark.parametrize("batch_size", [1, 16, 64]) @pytest.mark.parametrize("vocab_size", [32000, 65536, 128512]) @pytest.mark.parametrize("k", [256, 512, 1024]) @@ -179,13 +201,14 @@ def test_top_k_single_batch(vocab_size, k): @pytest.mark.parametrize("batch_size", [64, 128]) @pytest.mark.parametrize("vocab_size", [65536, 128512]) @pytest.mark.parametrize("k", [256]) -def test_top_k_large_batch(batch_size, vocab_size, k): +@pytest.mark.parametrize("det", [True, False]) +def test_top_k_large_batch(batch_size, vocab_size, k, det): """Test top_k with large batch sizes (multi-CTA path).""" torch.manual_seed(42) logits = torch.randn(batch_size, vocab_size, device="cuda", dtype=torch.float32) # flashinfer top_k (should use multi-CTA path for large vocab) - values, indices = flashinfer.top_k(logits, k) + values, indices = flashinfer.top_k(logits, k, deterministic=det) # Reference: torch.topk ref_values, ref_indices = torch.topk(logits, k, dim=-1) @@ -199,6 +222,87 @@ def test_top_k_large_batch(batch_size, vocab_size, k): assert accuracy >= 0.98, f"Accuracy {accuracy:.4f} < 0.98" +@pytest.mark.parametrize("api_kind", ["top_k", "page_table", "ragged"]) +@pytest.mark.parametrize( + ("first_deterministic", "second_deterministic"), + [(False, True), (True, False)], +) +def test_multi_cta_reuses_dirty_cached_row_states_buffer_across_mode_transitions( + api_kind, set_topk_algo, first_deterministic, second_deterministic +): + set_topk_algo("multi_cta") + device = torch.device("cuda") + _clear_cached_topk_row_states_buffer(device) + + batch_size = 4 + vocab_size = 131072 + k = 512 + logits = _build_strictly_descending_logits(batch_size, vocab_size, device) + + if api_kind == "top_k": + expected_values = logits[:, :k] + expected_indices = ( + torch.arange(k, device=device, dtype=torch.int64) + .unsqueeze(0) + .expand(batch_size, -1) + ) + + values_a, indices_a = flashinfer.top_k( + logits, k, sorted=True, deterministic=first_deterministic + ) + torch.testing.assert_close(values_a, expected_values) + assert torch.equal(indices_a, expected_indices) + buf_a = _get_cached_topk_row_states_buffer(device) + assert buf_a is not None + + values_b, indices_b = flashinfer.top_k( + logits, k, sorted=True, deterministic=second_deterministic + ) + torch.testing.assert_close(values_b, expected_values) + assert torch.equal(indices_b, expected_indices) + else: + lengths = torch.full( + (batch_size,), vocab_size, device=device, dtype=torch.int32 + ) + expected = torch.arange(k, device=device, dtype=torch.int32).unsqueeze(0) + expected = expected.expand(batch_size, -1) + src_page_table = None + offsets = None + + if api_kind == "ragged": + offsets = torch.arange( + 0, batch_size * vocab_size, vocab_size, device=device, dtype=torch.int32 + ) + expected = offsets.unsqueeze(1) + expected + + output_a = _run_transform( + logits, + k, + api_kind, + lengths=lengths, + deterministic=first_deterministic, + src_page_table=src_page_table, + offsets=offsets, + ) + _assert_unordered_indices_match(output_a, expected) + buf_a = _get_cached_topk_row_states_buffer(device) + assert buf_a is not None + + output_b = _run_transform( + logits, + k, + api_kind, + lengths=lengths, + deterministic=second_deterministic, + src_page_table=src_page_table, + offsets=offsets, + ) + _assert_unordered_indices_match(output_b, expected) + + buf_b = _get_cached_topk_row_states_buffer(device) + assert buf_b is buf_a + + @pytest.mark.parametrize("k", [256, 1024, 2048]) def test_top_k_large_k(k): """Test top_k with larger k values.""" @@ -1380,23 +1484,175 @@ def _assert_unordered_indices_match(output, expected): ) -def _run_transform_with_identity_mapping(logits, k, transform_mode): - """Run transform API with identity mapping so output equals selected indices.""" +def _assert_top_k_matches_torch( + logits: torch.Tensor, k: int, *, deterministic: bool = False, sorted: bool = True +): + """Assert FlashInfer top_k matches torch.topk for exact-order cases.""" + values, indices = flashinfer.top_k( + logits, k, deterministic=deterministic, sorted=sorted + ) + ref_values, ref_indices = torch.topk(logits, k, dim=-1, sorted=sorted) + + assert values.shape == ref_values.shape + assert indices.shape == ref_indices.shape + torch.testing.assert_close(values, ref_values) + assert torch.equal(indices, ref_indices) + + +def _run_transform( + logits, + k, + transform_mode, + *, + lengths: torch.Tensor | None = None, + deterministic: bool = False, + src_page_table: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, +): + """Run a transform API with either explicit or default identity metadata.""" batch_size, vocab_size = logits.shape device = logits.device - lengths = torch.full((batch_size,), vocab_size, device=device, dtype=torch.int32) + if lengths is None: + lengths = torch.full( + (batch_size,), vocab_size, device=device, dtype=torch.int32 + ) if transform_mode == "page_table": - src_page_table = ( - torch.arange(vocab_size, device=device, dtype=torch.int32) - .unsqueeze(0) - .repeat(batch_size, 1) - .contiguous() + if src_page_table is None: + src_page_table = ( + torch.arange(vocab_size, device=device, dtype=torch.int32) + .unsqueeze(0) + .repeat(batch_size, 1) + .contiguous() + ) + return flashinfer.top_k_page_table_transform( + logits, src_page_table, lengths, k, deterministic=deterministic ) - return flashinfer.top_k_page_table_transform(logits, src_page_table, lengths, k) - offsets = torch.zeros((batch_size,), device=device, dtype=torch.int32) - return flashinfer.top_k_ragged_transform(logits, offsets, lengths, k) + if offsets is None: + offsets = torch.zeros((batch_size,), device=device, dtype=torch.int32) + return flashinfer.top_k_ragged_transform( + logits, offsets, lengths, k, deterministic=deterministic + ) + + +def _run_transform_with_identity_mapping( + logits, k, transform_mode, deterministic: bool = False +): + """Run transform API with identity mapping so output equals selected indices.""" + return _run_transform(logits, k, transform_mode, deterministic=deterministic) + + +def _assert_transform_identity_matches_torch( + logits, k, transform_mode, deterministic: bool = False +): + """Assert transform output matches torch.topk indices under identity mapping.""" + output = _run_transform_with_identity_mapping( + logits, k, transform_mode, deterministic=deterministic + ) + ref_indices = torch.topk(logits, k, dim=-1, sorted=True).indices.to(torch.int32) + _assert_unordered_indices_match(output, ref_indices) + + +def _assert_repeatable_transform_output( + logits, + k, + transform_mode, + *, + num_runs: int, + deterministic: bool = True, + lengths: torch.Tensor | None = None, + src_page_table: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, +): + """Assert a transform API produces bitwise-identical output across repeated runs.""" + ref = _run_transform( + logits, + k, + transform_mode, + lengths=lengths, + deterministic=deterministic, + src_page_table=src_page_table, + offsets=offsets, + ) + for _ in range(num_runs - 1): + out = _run_transform( + logits, + k, + transform_mode, + lengths=lengths, + deterministic=deterministic, + src_page_table=src_page_table, + offsets=offsets, + ) + assert torch.equal(out, ref) + return ref + + +def _assert_repeatable_valid_identity_transform_selection( + output_a: torch.Tensor, + output_b: torch.Tensor, + vocab_size: int, + k: int, + gt_count: int = 0, +): + """Assert deterministic transform outputs are repeatable and form a valid top-k set.""" + assert torch.equal(output_a, output_b) + output = output_a[0] + assert output.numel() == k + assert torch.unique(output).numel() == k + assert torch.all((output >= 0) & (output < vocab_size)) + + if gt_count > 0: + gt_indices = torch.arange( + vocab_size - gt_count, + vocab_size, + device=output.device, + dtype=torch.int32, + ) + gt_mask = torch.isin(output, gt_indices) + assert gt_mask.sum().item() == gt_count + assert torch.all(torch.isin(gt_indices, output)) + tie_selected = output[~gt_mask] + assert tie_selected.numel() == k - gt_count + assert torch.all(tie_selected < vocab_size - gt_count) + + +def _assert_repeatable_valid_topk_selection( + logits: torch.Tensor, + values_a: torch.Tensor, + indices_a: torch.Tensor, + values_b: torch.Tensor, + indices_b: torch.Tensor, + k: int, + gt_count: int = 0, +): + """Assert deterministic top-k outputs are repeatable and form a valid selected set.""" + assert torch.equal(values_a, values_b) + assert torch.equal(indices_a, indices_b) + + gathered_values = torch.gather(logits, 1, indices_a) + torch.testing.assert_close(values_a, gathered_values) + + vocab_size = logits.size(1) + for output in indices_a: + assert output.numel() == k + assert torch.unique(output).numel() == k + assert torch.all((output >= 0) & (output < vocab_size)) + + if gt_count > 0: + gt_indices = torch.arange( + vocab_size - gt_count, + vocab_size, + device=output.device, + dtype=output.dtype, + ) + gt_mask = torch.isin(output, gt_indices) + assert gt_mask.sum().item() == gt_count + assert torch.all(torch.isin(gt_indices, output)) + tie_selected = output[~gt_mask] + assert tie_selected.numel() == k - gt_count + assert torch.all(tie_selected < vocab_size - gt_count) @pytest.mark.parametrize("transform_mode", ["page_table", "ragged"]) @@ -1430,81 +1686,550 @@ def test_bf16_long_seq_transform_regression_filtered(transform_mode, set_topk_al _assert_unordered_indices_match(output, expected) -@pytest.mark.parametrize("algo", ["auto", "multi_cta", "filtered"]) -def test_fp32_long_seq_refine_overflow_regression_across_algorithms( - algo, set_topk_algo -): - """Regression for float32 long-seq refine overflow across algorithms.""" +@pytest.mark.parametrize( + ("builder", "algo"), + [ + (_build_fp32_long_seq_overflow_inputs, "auto"), + (_build_fp32_long_seq_overflow_inputs, "multi_cta"), + (_build_fp32_long_seq_overflow_inputs, "filtered"), + (_build_fp32_long_seq_pivot_mismatch_inputs, "filtered"), + ], + ids=[ + "refine_overflow-auto", + "refine_overflow-multi_cta", + "refine_overflow-filtered", + "pivot_rebuild-filtered", + ], +) +@pytest.mark.parametrize("api_kind", ["top_k", "page_table", "ragged"]) +def test_fp32_long_seq_regression_matrix(builder, algo, api_kind, set_topk_algo): + """Long-sequence fp32 regressions should remain exact across supported APIs.""" if algo == "filtered" and not can_implement_filtered_topk(): pytest.skip("Filtered top-k not supported on this device") set_topk_algo(algo) - logits, batch_size, _, k = _build_fp32_long_seq_overflow_inputs() + logits, _, _, k = builder() + if api_kind == "top_k": + _assert_top_k_matches_torch(logits, k, sorted=True) + else: + _assert_transform_identity_matches_torch(logits, k, api_kind) + + +@pytest.mark.parametrize( + ("builder", "case_name"), + [ + (_build_fp32_long_seq_overflow_inputs, "refine_overflow"), + (_build_fp32_long_seq_pivot_mismatch_inputs, "pivot_rebuild"), + ], +) +@pytest.mark.parametrize("api_kind", ["top_k", "page_table", "ragged"]) +def test_fp32_long_seq_filtered_deterministic_regression_matrix( + builder, case_name, api_kind, set_topk_algo +): + """Filtered deterministic long-sequence fallback paths should remain exact.""" + if not can_implement_filtered_topk(): + pytest.skip("Filtered top-k not supported on this device") - values, indices = flashinfer.top_k(logits, k, sorted=True) - ref_values, ref_indices = torch.topk(logits, k, dim=-1, sorted=True) + set_topk_algo("filtered") + logits, _, _, k = builder() + if api_kind == "top_k": + _assert_top_k_matches_torch(logits, k, deterministic=True, sorted=True) + else: + _assert_transform_identity_matches_torch( + logits, k, api_kind, deterministic=True + ) - assert values.shape == (batch_size, k) - assert indices.shape == (batch_size, k) - torch.testing.assert_close(values, ref_values) - assert torch.equal(indices, ref_indices) +def test_top_k_deterministic_across_streams(): + """deterministic=True should be repeatable across CUDA streams. -@pytest.mark.parametrize("algo", ["auto", "multi_cta", "filtered"]) -@pytest.mark.parametrize("transform_mode", ["page_table", "ragged"]) -def test_fp32_long_seq_refine_overflow_transform_regression_across_algorithms( - algo, transform_mode, set_topk_algo + This runs the same deterministic top-k on two non-default streams (sequentially) + and checks for bitwise-identical results. + """ + batch_size = 4 + vocab_size = 16384 + k = 256 + device = "cuda" + + torch.manual_seed(0) + logits = torch.randn(batch_size, vocab_size, device=device, dtype=torch.float32) + + s1 = torch.cuda.Stream() + s2 = torch.cuda.Stream() + + with torch.cuda.stream(s1): + values_a, indices_a = flashinfer.top_k( + logits, k, deterministic=True, sorted=False + ) + s1.synchronize() + + with torch.cuda.stream(s2): + values_b, indices_b = flashinfer.top_k( + logits, k, deterministic=True, sorted=False + ) + s2.synchronize() + + assert torch.equal(values_a, values_b) + assert torch.equal(indices_a, indices_b) + + +@pytest.mark.parametrize( + ("algo", "batch_size", "vocab_size", "k", "dtype", "pattern_mod"), + [ + ("auto", 4, 16384, 256, torch.float32, 32), + # A 4096-wide fp32 row keeps ctas_per_group == 1 even under the multi_cta + # override, so this still exercises the radix single-CTA branch. + ("multi_cta", 4, 4096, 256, torch.float32, 32), + ("multi_cta", 1, 131072, 1024, torch.bfloat16, 64), + ("filtered", 4, 16384, 256, torch.float32, 32), + ], +) +def test_top_k_deterministic_repeatability_matrix( + algo, batch_size, vocab_size, k, dtype, pattern_mod, set_topk_algo ): - """Regression for fp32 long-seq overflow on transform APIs.""" + """deterministic=True should be bitwise identical across routing modes.""" if algo == "filtered" and not can_implement_filtered_topk(): pytest.skip("Filtered top-k not supported on this device") + if dtype == torch.bfloat16: + _require_sm80_for_bf16() set_topk_algo(algo) - logits, _, _, k = _build_fp32_long_seq_overflow_inputs() - output = _run_transform_with_identity_mapping(logits, k, transform_mode) - ref_indices = torch.topk(logits, k, dim=-1, sorted=True).indices.to(torch.int32) - _assert_unordered_indices_match(output, ref_indices) + num_runs = 20 + device = "cuda" + pattern = ( + torch.arange(vocab_size, device=device, dtype=torch.float32) % pattern_mod + ) / float(pattern_mod) + logits = pattern.unsqueeze(0).repeat(batch_size, 1).to(dtype).contiguous() + + ref_values, ref_indices = flashinfer.top_k( + logits, k, deterministic=True, sorted=False + ) + for _ in range(num_runs - 1): + values, indices = flashinfer.top_k(logits, k, deterministic=True, sorted=False) + assert torch.equal(values, ref_values) + assert torch.equal(indices, ref_indices) + + +@pytest.mark.parametrize( + ("algo", "batch_size", "vocab_size", "k"), + [ + ("auto", 4, 16384, 256), + # A 4096-wide fp32 row keeps ctas_per_group == 1 even under the multi_cta + # override, so this still exercises the radix single-CTA branch. + ("multi_cta", 4, 4096, 256), + ("multi_cta", 1, 131072, 1024), + ("filtered", 4, 16384, 256), + ], +) +def test_top_k_deterministic_sorted_matches_stable_sort( + algo, batch_size, vocab_size, k, set_topk_algo +): + """sorted=True should be repeatable, valid, and descending.""" + if algo == "filtered" and not can_implement_filtered_topk(): + pytest.skip("Filtered top-k not supported on this device") + + set_topk_algo(algo) + device = "cuda" + pattern = (torch.arange(vocab_size, device=device, dtype=torch.float32) % 32) / 32.0 + logits = pattern.unsqueeze(0).repeat(batch_size, 1).contiguous() + sorted_values_a, sorted_indices_a = flashinfer.top_k( + logits, k, deterministic=True, sorted=True + ) + sorted_values_b, sorted_indices_b = flashinfer.top_k( + logits, k, deterministic=True, sorted=True + ) -def test_fp32_long_seq_pivot_rebuild_regression_filtered(set_topk_algo): - """Regression for pivot reconstruction in float32 overflow fallback.""" - if not can_implement_filtered_topk(): + _assert_repeatable_valid_topk_selection( + logits, sorted_values_a, sorted_indices_a, sorted_values_b, sorted_indices_b, k + ) + assert torch.all(sorted_values_a[:, :-1] >= sorted_values_a[:, 1:]) + + +@pytest.mark.parametrize( + ("algo", "vocab_size"), + [ + ("auto", 16384), + # A 4096-wide fp32 row keeps ctas_per_group == 1 even under the multi_cta + # override, so this still exercises the radix single-CTA branch. + ("multi_cta", 4096), + ("multi_cta", 131072), + ("filtered", 16384), + ], +) +@pytest.mark.parametrize(("pattern", "k"), [("all_equal", 8), ("pivot_tie", 6)]) +def test_top_k_deterministic_sorted_repeatable_valid_selection_under_ties( + algo, vocab_size, pattern, k, set_topk_algo +): + """Deterministic sorted top-k should remain repeatable under tie pressure.""" + if algo == "filtered" and not can_implement_filtered_topk(): pytest.skip("Filtered top-k not supported on this device") - set_topk_algo("filtered") - logits, batch_size, _, k = _build_fp32_long_seq_pivot_mismatch_inputs() + set_topk_algo(algo) + device = "cuda" + logits = torch.ones((1, vocab_size), device=device, dtype=torch.float16) + gt_count = 0 - values, indices = flashinfer.top_k(logits, k, sorted=True) + if pattern == "all_equal": + expected_values = torch.ones((1, k), device=device, dtype=torch.float16) + else: + gt_count = 2 + logits[:, vocab_size - gt_count :] = 2.0 + expected_values = torch.cat( + [ + torch.full((1, gt_count), 2.0, device=device, dtype=torch.float16), + torch.ones((1, k - gt_count), device=device, dtype=torch.float16), + ], + dim=-1, + ) + + values_a, indices_a = flashinfer.top_k(logits, k, deterministic=True, sorted=True) + values_b, indices_b = flashinfer.top_k(logits, k, deterministic=True, sorted=True) + + torch.testing.assert_close(values_a, expected_values) + _assert_repeatable_valid_topk_selection( + logits, values_a, indices_a, values_b, indices_b, k, gt_count=gt_count + ) + + +@pytest.mark.parametrize( + ("algo", "vocab_size", "k"), + [ + ("auto", 131072, 4096), + ("multi_cta", 131072, 4096), + # Keep one filtered-specific large-k coverage row that still satisfies + # FILTERED_TOPK_MAX_K and therefore actually routes to FilteredTopK. + ("filtered", 131072, 2048), + ], + ids=[ + "auto_k4096", + "multi_cta_k4096", + "filtered_k2048", + ], +) +def test_top_k_deterministic_sorted_large_k_matches_torch_by_algo( + algo, vocab_size, k, set_topk_algo +): + """Deterministic sorted output should match torch.topk across routed large-k cases.""" + set_topk_algo(algo) + + if algo == "filtered" and not can_implement_filtered_topk(): + pytest.skip("GPU does not support filtered topk (requires 128KB shared memory)") + + batch_size = 1 + device = "cuda" + + torch.manual_seed(0) + logits = torch.randn(batch_size, vocab_size, device=device, dtype=torch.float32) + + values, indices = flashinfer.top_k(logits, k, deterministic=True, sorted=True) ref_values, ref_indices = torch.topk(logits, k, dim=-1, sorted=True) - assert values.shape == (batch_size, k) - assert indices.shape == (batch_size, k) torch.testing.assert_close(values, ref_values) assert torch.equal(indices, ref_indices) +@pytest.mark.parametrize("algo", ["auto", "multi_cta"]) +def test_top_k_deterministic_trivial_k_equals_length_by_algo(algo, set_topk_algo): + """Deterministic k==length fast paths should remain exact across auto/radix routing.""" + set_topk_algo(algo) + + batch_size = 2 + vocab_size = 131072 + k = vocab_size + device = "cuda" + + torch.manual_seed(0) + logits = torch.randn(batch_size, vocab_size, device=device, dtype=torch.float16) + + values, indices = flashinfer.top_k(logits, k, deterministic=True, sorted=False) + expected_indices = ( + torch.arange(vocab_size, device=device, dtype=torch.int64) + .unsqueeze(0) + .expand(batch_size, -1) + ) + + assert torch.equal(indices, expected_indices) + torch.testing.assert_close(values, logits) + + @pytest.mark.parametrize("transform_mode", ["page_table", "ragged"]) -def test_fp32_long_seq_pivot_rebuild_transform_regression_filtered( +def test_top_k_transform_multi_cta_deterministic_trivial_lengths( transform_mode, set_topk_algo ): - """Regression for fp32 pivot reconstruction in filtered transform APIs.""" + """Deterministic radix transform should handle length == k and length < k fast paths.""" + set_topk_algo("multi_cta") + + num_rows = 2 + max_len = 131072 + k = 256 + device = "cuda" + + torch.manual_seed(0) + scores = torch.randn(num_rows, max_len, device=device, dtype=torch.float16) + lengths = torch.tensor([k, k // 2], device=device, dtype=torch.int32) + + if transform_mode == "page_table": + src_page_table = ( + torch.arange(max_len, device=device, dtype=torch.int32) + .mul(3) + .add(7) + .unsqueeze(0) + .repeat(num_rows, 1) + .contiguous() + ) + output = flashinfer.top_k_page_table_transform( + scores, src_page_table, lengths, k, deterministic=True + ) + expected = torch.full((num_rows, k), -1, device=device, dtype=torch.int32) + expected[0] = src_page_table[0, :k] + expected[1, : k // 2] = src_page_table[1, : k // 2] + else: + offsets = torch.tensor([0, 1000], device=device, dtype=torch.int32) + output = flashinfer.top_k_ragged_transform( + scores, offsets, lengths, k, deterministic=True + ) + expected = torch.full((num_rows, k), -1, device=device, dtype=torch.int32) + expected[0] = torch.arange(k, device=device, dtype=torch.int32) + expected[1, : k // 2] = offsets[1] + torch.arange( + k // 2, device=device, dtype=torch.int32 + ) + + assert torch.equal(output, expected) + + +@pytest.mark.parametrize("transform_mode", ["page_table", "ragged"]) +@pytest.mark.parametrize(("pattern", "k"), [("all_equal", 8), ("pivot_tie", 6)]) +def test_top_k_transform_filtered_deterministic_valid_selection_under_ties( + transform_mode, pattern, k, set_topk_algo +): + """Filtered deterministic transform APIs should be repeatable and select a valid top-k set.""" if not can_implement_filtered_topk(): pytest.skip("Filtered top-k not supported on this device") set_topk_algo("filtered") - logits, _, _, k = _build_fp32_long_seq_pivot_mismatch_inputs() + device = "cuda" + vocab_size = 16384 + logits = torch.ones((1, vocab_size), device=device, dtype=torch.float16) + gt_count = 0 - output = _run_transform_with_identity_mapping(logits, k, transform_mode) - ref_indices = torch.topk(logits, k, dim=-1, sorted=True).indices.to(torch.int32) - _assert_unordered_indices_match(output, ref_indices) + if pattern == "all_equal": + pass + else: + gt_count = 2 + logits[:, vocab_size - gt_count :] = 2.0 + + output_a = _run_transform_with_identity_mapping( + logits, k, transform_mode, deterministic=True + ) + output_b = _run_transform_with_identity_mapping( + logits, k, transform_mode, deterministic=True + ) + _assert_repeatable_valid_identity_transform_selection( + output_a, output_b, vocab_size, k, gt_count=gt_count + ) + + +@pytest.mark.parametrize("transform_mode", ["page_table", "ragged"]) +def test_top_k_transform_deterministic_repeatability_multi_cta_all_equal( + transform_mode, set_topk_algo +): + """Force radix multi-CTA all-equal transform path where later CTAs take zero eq quota.""" + set_topk_algo("multi_cta") + + device = "cuda" + vocab_size = 131072 + k = 256 + logits = torch.ones((1, vocab_size), device=device, dtype=torch.float16) + + output_a = _run_transform_with_identity_mapping( + logits, k, transform_mode, deterministic=True + ) + output_b = _run_transform_with_identity_mapping( + logits, k, transform_mode, deterministic=True + ) + _assert_repeatable_valid_identity_transform_selection( + output_a, output_b, vocab_size, k + ) + + +@pytest.mark.parametrize("algo", ["auto", "filtered", "multi_cta"]) +@pytest.mark.parametrize("pattern", ["tie_heavy", "pivot_tie"]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_top_k_deterministic_repeatability_tie_cases_by_algo( + algo, pattern, dtype, set_topk_algo +): + """Deterministic top-k should be repeatable and valid under tie pressure.""" + if dtype == torch.bfloat16: + _require_sm80_for_bf16() + + if algo == "filtered" and not can_implement_filtered_topk(): + pytest.skip("Filtered top-k not supported on this device") + + set_topk_algo(algo) + batch_size = 4 + vocab_size = 16384 + k = 256 + num_runs = 20 + device = "cuda" + gt_count = 0 + + if pattern == "tie_heavy": + base = ( + torch.arange(vocab_size, device=device, dtype=torch.float32) % 32 + ) / 32.0 + logits = base.unsqueeze(0).repeat(batch_size, 1).to(dtype).contiguous() + else: # pivot_tie + logits = torch.ones(batch_size, vocab_size, device=device, dtype=dtype) + gt_count = max(1, min(k // 4, vocab_size // 8)) + logits[:, vocab_size - gt_count :] = 2.0 + + ref_values, ref_indices = flashinfer.top_k( + logits, k, deterministic=True, sorted=False + ) + for _ in range(num_runs - 1): + values, indices = flashinfer.top_k(logits, k, deterministic=True, sorted=False) + _assert_repeatable_valid_topk_selection( + logits, ref_values, ref_indices, values, indices, k, gt_count=gt_count + ) + + +@pytest.mark.parametrize("algo", ["filtered", "multi_cta"]) +@pytest.mark.parametrize("transform_mode", ["page_table", "ragged"]) +def test_top_k_transform_deterministic_repeatability_tie_heavy_by_algo( + algo, transform_mode, set_topk_algo +): + """Deterministic transform APIs should be repeatable under tie-heavy input.""" + _require_sm80_for_bf16() + + if algo == "filtered" and not can_implement_filtered_topk(): + pytest.skip("Filtered top-k not supported on this device") + + set_topk_algo(algo) + num_rows = 4 + max_len = 16384 + k = 256 + num_runs = 20 + device = "cuda" + + base = (torch.arange(max_len, device=device, dtype=torch.float32) % 32) / 32.0 + scores = base.unsqueeze(0).repeat(num_rows, 1).to(torch.bfloat16).contiguous() + lengths = torch.full((num_rows,), max_len, device=device, dtype=torch.int32) + offsets = None + if transform_mode == "ragged": + offsets = torch.arange( + 0, num_rows * max_len, max_len, device=device, dtype=torch.int32 + ) + + _assert_repeatable_transform_output( + scores, + k, + transform_mode, + num_runs=num_runs, + deterministic=True, + lengths=lengths, + offsets=offsets, + ) + + +def test_top_k_deterministic_bitwise_repeatability(): + """Deterministic top-k should be bitwise identical across repeated runs.""" + batch_size = 8 + vocab_size = 32768 + k = 512 + num_runs = 50 + device = "cuda" + + # Tie-heavy logits: repeated value buckets to stress tie handling. + pattern = (torch.arange(vocab_size, device=device, dtype=torch.float32) % 64) / 64.0 + logits = pattern.unsqueeze(0).repeat(batch_size, 1).contiguous() + + ref_values, ref_indices = flashinfer.top_k( + logits, k, deterministic=True, sorted=False + ) + for _ in range(num_runs - 1): + values, indices = flashinfer.top_k(logits, k, deterministic=True, sorted=False) + assert torch.equal(values, ref_values) + assert torch.equal(indices, ref_indices) + + +@pytest.mark.parametrize("transform_mode", ["page_table", "ragged"]) +def test_top_k_transform_deterministic_repeatability(transform_mode): + """Deterministic transform APIs should be bitwise identical across runs.""" + num_rows = 8 + max_len = 8192 + k = 512 + num_runs = 30 + device = "cuda" + + pattern = (torch.arange(max_len, device=device, dtype=torch.float32) % 32) / 32.0 + scores = pattern.unsqueeze(0).repeat(num_rows, 1).contiguous() + lengths = torch.full((num_rows,), max_len, device=device, dtype=torch.int32) + src_page_table = None + offsets = None + if transform_mode == "ragged": + offsets = torch.arange( + 0, num_rows * max_len, max_len, device=device, dtype=torch.int32 + ) + + _assert_repeatable_transform_output( + scores, + k, + transform_mode, + num_runs=num_runs, + deterministic=True, + lengths=lengths, + src_page_table=src_page_table, + offsets=offsets, + ) + + +@pytest.mark.parametrize("transform_mode", ["page_table", "ragged"]) +def test_top_k_transform_deterministic_k1_remap(transform_mode): + """Deterministic transform APIs must remap local top-1 positions correctly.""" + num_rows = 4 + max_len = 257 + device = "cuda" + + torch.manual_seed(0) + scores = torch.randn(num_rows, max_len, device=device, dtype=torch.float32) + lengths = torch.full((num_rows,), max_len, device=device, dtype=torch.int32) + ref_idx = torch.topk(scores, 1, dim=-1).indices.to(torch.int32) + src_page_table = None + offsets = None + + if transform_mode == "page_table": + src_page_table = ( + torch.arange(max_len, device=device, dtype=torch.int32) + .unsqueeze(0) + .repeat(num_rows, 1) + .mul(3) + .add(7) + .contiguous() + ) + ref = torch.gather(src_page_table, 1, ref_idx) + else: + offsets = torch.tensor([5, 1000, 2000, 3000], device=device, dtype=torch.int32) + ref = ref_idx + offsets.unsqueeze(1) + + out = _run_transform( + scores, + 1, + transform_mode, + lengths=lengths, + deterministic=True, + src_page_table=src_page_table, + offsets=offsets, + ) + assert torch.equal(out, ref) if __name__ == "__main__": # Basic tests test_top_k(4, 32000, 256, torch.float32) test_top_k_sorted(4, 32000, 256, torch.float32) - test_top_k_large_batch(64, 128512, 256) + test_top_k_large_batch(64, 128512, 256, False) # Fused transform tests print("Testing page table transform...") From f90300ae4276b40f69dc0c9736599b787441b139 Mon Sep 17 00:00:00 2001 From: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com> Date: Wed, 1 Apr 2026 13:42:51 +0200 Subject: [PATCH 2/2] fix: topK uint32 overflow Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com> --- include/flashinfer/topk.cuh | 10 +++++----- tests/utils/test_topk.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/include/flashinfer/topk.cuh b/include/flashinfer/topk.cuh index 62071097b1..79c8eda304 100644 --- a/include/flashinfer/topk.cuh +++ b/include/flashinfer/topk.cuh @@ -1154,7 +1154,7 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified( if (chunk_start + i < k) { row_output[chunk_start + i] = static_cast(chunk_start + i); output_values[row_idx * top_k_val + chunk_start + i] = - input[row_idx * stride + chunk_start + i]; + input[static_cast(row_idx) * stride + chunk_start + i]; } } // Clear histogram for next iteration (in case it's k < length) @@ -1217,9 +1217,9 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified( uint32_t cta_local_eq_count = 0; OrderedType ordered_pivot = RadixSelectFindPivot( - input + row_idx * stride, shared_ordered, local_histogram, suffix_sum, shared_scalars, - state, chunk_start, actual_chunk_size, k, barrier_phase, ctas_per_group, cta_in_group, - tx, iter, cta_local_gt_count, cta_local_eq_count); + input + static_cast(row_idx) * stride, shared_ordered, local_histogram, + suffix_sum, shared_scalars, state, chunk_start, actual_chunk_size, k, barrier_phase, + ctas_per_group, cta_in_group, tx, iter, cta_local_gt_count, cta_local_eq_count); auto collect_indices = [&](auto&& output_func) { if constexpr (DETERMINISTIC) { @@ -2284,7 +2284,7 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) if (bid >= num_rows) return; const int length = (lengths != nullptr) ? lengths[bid] : static_cast(max_len); - const DType* score = input + bid * max_len; + const DType* score = input + static_cast(bid) * max_len; IdType* dst = output + bid * top_k; // Mode-specific setup diff --git a/tests/utils/test_topk.py b/tests/utils/test_topk.py index ae622b9e7a..0f931e76f2 100644 --- a/tests/utils/test_topk.py +++ b/tests/utils/test_topk.py @@ -2225,6 +2225,43 @@ def test_top_k_transform_deterministic_k1_remap(transform_mode): assert torch.equal(out, ref) +def test_top_k_uint32_pointer_overflow(): + """Test top_k with batch*vocab > 2^32 bytes""" + batch_size = 32769 + vocab_size = 131072 + k = 256 + + required_bytes = batch_size * vocab_size * 2 # fp16 + free_mem = torch.cuda.mem_get_info("cuda")[0] + if free_mem < int(required_bytes * 1.15): + pytest.skip( + f"Insufficient GPU memory: {free_mem / 1e9:.1f}GB free, " + f"need ~{required_bytes / 1e9:.1f}GB" + ) + + torch.manual_seed(42) + logits = torch.randn(batch_size, vocab_size, device="cuda", dtype=torch.float16) + + values, indices = flashinfer.top_k(logits, k) + + assert values.shape == (batch_size, k) + assert indices.shape == (batch_size, k) + + # Only check the last row: its element offset (row_idx * vocab_size) + # exceeds 2^32, so a uint32 overflow bug would corrupt this region. + row_idx = batch_size - 1 + gathered = torch.gather( + logits[row_idx : row_idx + 1], -1, indices[row_idx : row_idx + 1] + ) + torch.testing.assert_close(values[row_idx : row_idx + 1], gathered) + + _, ref_indices = torch.topk(logits[row_idx : row_idx + 1], k, dim=-1) + accuracy = compute_topk_accuracy( + indices[row_idx : row_idx + 1].int(), ref_indices.int(), 1, k + ) + assert accuracy >= 0.98, f"Last row accuracy {accuracy:.4f} < 0.98" + + if __name__ == "__main__": # Basic tests test_top_k(4, 32000, 256, torch.float32)