diff --git a/benchmarks/benchmark_topk_topp.py b/benchmarks/benchmark_topk_topp.py new file mode 100644 index 000000000000..dae0458d01ee --- /dev/null +++ b/benchmarks/benchmark_topk_topp.py @@ -0,0 +1,448 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Benchmark comparing Triton vs PyTorch sort-based top-k/top-p implementations. + +Compares: +- apply_top_k_top_p_triton (Triton binary search) +- apply_top_k_top_p (PyTorch sort-based) + +Scenarios: +- top_k only (whole batch, partial batch) +- top_p only (whole batch, partial batch) +- mix of top_k and top_p +""" + +import argparse +import gc +from dataclasses import dataclass + +import torch + +from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch +from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton + + +@dataclass +class BenchmarkConfig: + """Configuration for a benchmark run.""" + + name: str + batch_size: int + vocab_size: int + # k and p can be tensors or None + k_values: torch.Tensor | None # [batch_size] or None + p_values: torch.Tensor | None # [batch_size] or None + description: str + ops_pct: float = 0.0 # Percentage of ops relative to batch size + + +def calculate_ops_pct( + k_values: torch.Tensor | None, + p_values: torch.Tensor | None, + vocab_size: int, + batch_size: int, +) -> float: + """ + Calculate the percentage of active top-k and top-p operations. + + Returns percentage where 100% = batch_size ops. + E.g., if all rows have both top-k and top-p active, returns 200%. 
+ """ + active_ops = 0 + + if k_values is not None: + # Count rows where k < vocab_size (active top-k filtering) + active_ops += (k_values < vocab_size).sum().item() + + if p_values is not None: + # Count rows where p < 1.0 (active top-p filtering) + active_ops += (p_values < 1.0).sum().item() + + return (active_ops / batch_size) * 100 if batch_size > 0 else 0.0 + + +def create_logits( + batch_size: int, vocab_size: int, device: str = "cuda" +) -> torch.Tensor: + """Create random logits tensor.""" + return torch.randn(batch_size, vocab_size, dtype=torch.float32, device=device) + + +def measure_memory() -> tuple[int, int]: + """Return (allocated, reserved) memory in bytes.""" + torch.cuda.synchronize() + return torch.cuda.memory_allocated(), torch.cuda.max_memory_allocated() + + +def reset_memory_stats(): + """Reset peak memory statistics.""" + torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + gc.collect() + + +def benchmark_function( + func, + logits: torch.Tensor, + k: torch.Tensor | None, + p: torch.Tensor | None, + warmup_iters: int = 5, + benchmark_iters: int = 20, +) -> tuple[float, int]: + """ + Benchmark a function and return (avg_time_ms, peak_memory_bytes). + + Returns average time in milliseconds and peak memory usage. 
+ """ + # Warmup + for _ in range(warmup_iters): + logits_copy = logits.clone() + func(logits_copy, k, p) + torch.cuda.synchronize() + + # Reset memory stats before benchmark + reset_memory_stats() + + # Benchmark + start_events = [ + torch.cuda.Event(enable_timing=True) for _ in range(benchmark_iters) + ] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(benchmark_iters)] + + for i in range(benchmark_iters): + logits_copy = logits.clone() + start_events[i].record() + func(logits_copy, k, p) + end_events[i].record() + + torch.cuda.synchronize() + + # Calculate timing + times = [ + start_events[i].elapsed_time(end_events[i]) for i in range(benchmark_iters) + ] + avg_time = sum(times) / len(times) + + # Get peak memory + _, peak_memory = measure_memory() + + return avg_time, peak_memory + + +def create_benchmark_configs( + batch_sizes: list[int], + vocab_sizes: list[int], + device: str = "cuda", +) -> list[BenchmarkConfig]: + """Create all benchmark configurations.""" + configs = [] + + for vocab_size in vocab_sizes: + for batch_size in batch_sizes: + # 1. Top-k only - whole batch (all rows have k < vocab_size) + k_all = torch.full((batch_size,), 50, dtype=torch.int32, device=device) + configs.append( + BenchmarkConfig( + name=f"topk_whole_b{batch_size}_v{vocab_size // 1000}k", + batch_size=batch_size, + vocab_size=vocab_size, + k_values=k_all, + p_values=None, + description=f"Top-k only (whole batch, k=50), " + f"batch={batch_size}, vocab={vocab_size}", + ops_pct=calculate_ops_pct(k_all, None, vocab_size, batch_size), + ) + ) + + # 2. 
Top-k only - partial batch (half have k=50, half have k=vocab_size) + k_partial = torch.full((batch_size,), 50, dtype=torch.int32, device=device) + k_partial[batch_size // 2 :] = vocab_size # No filtering for second half + configs.append( + BenchmarkConfig( + name=f"topk_partial_b{batch_size}_v{vocab_size // 1000}k", + batch_size=batch_size, + vocab_size=vocab_size, + k_values=k_partial, + p_values=None, + description=f"Top-k only (partial batch, 50% k=50, 50% k=vocab), " + f"batch={batch_size}, vocab={vocab_size}", + ops_pct=calculate_ops_pct(k_partial, None, vocab_size, batch_size), + ) + ) + + # 3. Top-p only - whole batch (all rows have p < 1.0) + p_all = torch.full((batch_size,), 0.9, dtype=torch.float32, device=device) + configs.append( + BenchmarkConfig( + name=f"topp_whole_b{batch_size}_v{vocab_size // 1000}k", + batch_size=batch_size, + vocab_size=vocab_size, + k_values=None, + p_values=p_all, + description=f"Top-p only (whole batch, p=0.9), " + f"batch={batch_size}, vocab={vocab_size}", + ops_pct=calculate_ops_pct(None, p_all, vocab_size, batch_size), + ) + ) + + # 4. Top-p only - partial batch (half have p=0.9, half have p=1.0) + p_partial = torch.full( + (batch_size,), 0.9, dtype=torch.float32, device=device + ) + p_partial[batch_size // 2 :] = 1.0 # No filtering for second half + configs.append( + BenchmarkConfig( + name=f"topp_partial_b{batch_size}_v{vocab_size // 1000}k", + batch_size=batch_size, + vocab_size=vocab_size, + k_values=None, + p_values=p_partial, + description=f"Top-p only (partial batch, 50% p=0.9, 50% p=1.0), " + f"batch={batch_size}, vocab={vocab_size}", + ops_pct=calculate_ops_pct(None, p_partial, vocab_size, batch_size), + ) + ) + + # 5. 
Mix of top-k and top-p (both applied to whole batch) + k_mix = torch.full((batch_size,), 100, dtype=torch.int32, device=device) + p_mix = torch.full((batch_size,), 0.9, dtype=torch.float32, device=device) + configs.append( + BenchmarkConfig( + name=f"topk_topp_whole_b{batch_size}_v{vocab_size // 1000}k", + batch_size=batch_size, + vocab_size=vocab_size, + k_values=k_mix, + p_values=p_mix, + description=f"Top-k + Top-p (whole batch, k=100, p=0.9), " + f"batch={batch_size}, vocab={vocab_size}", + ops_pct=calculate_ops_pct(k_mix, p_mix, vocab_size, batch_size), + ) + ) + + # 6. Mix with partial application (some rows k only, some p only, some both) + k_mixed = torch.full( + (batch_size,), vocab_size, dtype=torch.int32, device=device + ) + p_mixed = torch.full((batch_size,), 1.0, dtype=torch.float32, device=device) + # First third: k only + third = batch_size // 3 + k_mixed[:third] = 50 + # Second third: p only + p_mixed[third : 2 * third] = 0.9 + # Last third: both k and p + k_mixed[2 * third :] = 100 + p_mixed[2 * third :] = 0.9 + configs.append( + BenchmarkConfig( + name=f"mixed_partial_b{batch_size}_v{vocab_size // 1000}k", + batch_size=batch_size, + vocab_size=vocab_size, + k_values=k_mixed, + p_values=p_mixed, + description=f"Mixed partial (1/3 k=50, 1/3 p=0.9, 1/3 both), " + f"batch={batch_size}, vocab={vocab_size}", + ops_pct=calculate_ops_pct(k_mixed, p_mixed, vocab_size, batch_size), + ) + ) + + return configs + + +def format_memory(bytes_val: int) -> str: + """Format memory in human-readable form.""" + if bytes_val >= 1024**3: + return f"{bytes_val / (1024**3):.2f} GB" + elif bytes_val >= 1024**2: + return f"{bytes_val / (1024**2):.2f} MB" + elif bytes_val >= 1024: + return f"{bytes_val / 1024:.2f} KB" + return f"{bytes_val} B" + + +def run_benchmark( + configs: list[BenchmarkConfig], + warmup_iters: int = 5, + benchmark_iters: int = 20, + verbose: bool = True, +): + """Run all benchmarks and print results.""" + results = [] + + print("=" * 100) + 
print("Top-k/Top-p Benchmark: Triton vs PyTorch Sort-based") + print("=" * 100) + print() + + for config in configs: + if verbose: + print(f"Running: {config.description}") + + # Create fresh logits for this config + logits = create_logits(config.batch_size, config.vocab_size) + + # Benchmark Triton + reset_memory_stats() + triton_time, triton_mem = benchmark_function( + apply_top_k_top_p_triton, + logits, + config.k_values, + config.p_values, + warmup_iters, + benchmark_iters, + ) + + # Benchmark PyTorch + reset_memory_stats() + pytorch_time, pytorch_mem = benchmark_function( + apply_top_k_top_p_pytorch, + logits, + config.k_values, + config.p_values, + warmup_iters, + benchmark_iters, + ) + + speedup = pytorch_time / triton_time if triton_time > 0 else float("inf") + mem_ratio = pytorch_mem / triton_mem if triton_mem > 0 else float("inf") + + result = { + "config": config, + "triton_time_ms": triton_time, + "pytorch_time_ms": pytorch_time, + "triton_mem": triton_mem, + "pytorch_mem": pytorch_mem, + "speedup": speedup, + "mem_ratio": mem_ratio, + } + results.append(result) + + if verbose: + print(f" Triton: {triton_time:.3f} ms, {format_memory(triton_mem)}") + print(f" PyTorch: {pytorch_time:.3f} ms, {format_memory(pytorch_mem)}") + print(f" Speedup: {speedup:.2f}x, Memory ratio: {mem_ratio:.2f}x") + print() + + # Clean up + del logits + reset_memory_stats() + + return results + + +def print_summary_table(results: list[dict]): + """Print a summary table of results.""" + print() + print("=" * 130) + print("SUMMARY TABLE") + print("=" * 130) + print() + + # Header + header = ( + f"{'Scenario':<40} {'Batch':>6} {'Vocab':>7} {'Ops%':>6} " + f"{'Triton (ms)':>12} {'PyTorch (ms)':>13} {'Speedup':>8} " + f"{'Tri Mem':>10} {'Pyt Mem':>10}" + ) + print(header) + print("-" * 130) + + # Group by scenario type + current_vocab = None + for result in results: + config = result["config"] + + # Add separator between vocab sizes + if current_vocab != config.vocab_size: + if 
current_vocab is not None: + print("-" * 130) + current_vocab = config.vocab_size + + scenario = config.name.split("_b")[0] # Extract scenario name + print( + f"{scenario:<40} {config.batch_size:>6} {config.vocab_size:>7} " + f"{config.ops_pct:>5.0f}% " + f"{result['triton_time_ms']:>12.3f} {result['pytorch_time_ms']:>13.3f} " + f"{result['speedup']:>7.2f}x " + f"{format_memory(result['triton_mem']):>10} " + f"{format_memory(result['pytorch_mem']):>10}" + ) + + print("=" * 130) + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark Triton vs PyTorch sort-based top-k/top-p implementations" + ) + parser.add_argument( + "--batch-sizes", + type=int, + nargs="+", + default=[1, 4, 16, 24, 32, 48, 56, 64, 96, 128, 192, 256, 512, 1024], + help="Batch sizes to test (default: 1 4 16 64)", + ) + parser.add_argument( + "--vocab-sizes", + type=int, + nargs="+", + default=[32768, 131072], # 32k, 128k + help="Vocabulary sizes to test (default: 32768 131072)", + ) + parser.add_argument( + "--warmup-iters", + type=int, + default=5, + help="Number of warmup iterations (default: 5)", + ) + parser.add_argument( + "--benchmark-iters", + type=int, + default=20, + help="Number of benchmark iterations (default: 20)", + ) + parser.add_argument( + "--quiet", + action="store_true", + help="Only print summary table", + ) + + args = parser.parse_args() + + # Print configuration + print(f"Batch sizes: {args.batch_sizes}") + print(f"Vocab sizes: {args.vocab_sizes}") + print(f"Warmup iterations: {args.warmup_iters}") + print(f"Benchmark iterations: {args.benchmark_iters}") + print() + + # Check CUDA + if not torch.cuda.is_available(): + print("ERROR: CUDA is not available. 
This benchmark requires a GPU.") + return + + device_name = torch.cuda.get_device_name(0) + print(f"GPU: {device_name}") + print() + + # Create configs + configs = create_benchmark_configs( + args.batch_sizes, + args.vocab_sizes, + ) + + # Run benchmarks + results = run_benchmark( + configs, + warmup_iters=args.warmup_iters, + benchmark_iters=args.benchmark_iters, + verbose=not args.quiet, + ) + + # Print summary + print_summary_table(results) + + +if __name__ == "__main__": + main() diff --git a/tests/v1/sample/test_topk_topp_sampler.py b/tests/v1/sample/test_topk_topp_sampler.py index a61f5af423fa..00af7d198144 100644 --- a/tests/v1/sample/test_topk_topp_sampler.py +++ b/tests/v1/sample/test_topk_topp_sampler.py @@ -5,8 +5,9 @@ from torch import Generator from vllm.platforms import current_platform -from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p +from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch +CUDA_DEVICE = "cuda" if current_platform.is_cuda() else None DEVICE = current_platform.device_type BATCH_SIZE = 1024 @@ -39,11 +40,11 @@ def test_topk_impl_equivalence(): ) # Top-k only implementation - result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None) + result1 = apply_top_k_top_p_pytorch(logits=logits.clone(), k=k, p=None) # Top-p + top-k no_op_top_p = torch.tensor([1.0]) - result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p) + result2 = apply_top_k_top_p_pytorch(logits=logits.clone(), k=k, p=no_op_top_p) assert torch.allclose(result1, result2) @@ -94,7 +95,7 @@ def test_flashinfer_sampler(): torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool), 1.0 ) - python_logits = apply_top_k_top_p( + python_logits = apply_top_k_top_p_pytorch( logits=logits.clone(), k=k_values, p=p_values, @@ -116,3 +117,178 @@ def test_flashinfer_sampler(): assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), ( "FlashInfer and Python sampling implementations do not match!" 
) + + +# ============================================================================= +# Triton kernel tests +# ============================================================================= + + +@pytest.mark.skipif(CUDA_DEVICE is None, reason="CUDA not available") +class TestTritonTopkTopp: + """Tests for the Triton top-k/top-p kernel.""" + + @pytest.fixture(autouse=True) + def setup(self): + """Set up test fixtures.""" + torch.set_default_device(CUDA_DEVICE) + self.generator = Generator(device=CUDA_DEVICE).manual_seed(42) + + def _compare_results( + self, + logits: torch.Tensor, + k: torch.Tensor | None, + p: torch.Tensor | None, + ): + """Compare Triton kernel results with PyTorch sorting implementation. + + For top-k only, we expect exact match. + For top-p (with or without top-k), we allow small differences due to + floating-point precision in probability sum calculations. + """ + from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton + + # Clone logits for both implementations + logits_pytorch = logits.clone() + logits_triton = logits.clone().to(torch.float32) + + # Apply PyTorch sorting implementation + result_pytorch = apply_top_k_top_p_pytorch(logits_pytorch, k, p) + + # Apply Triton kernel + k_i32 = k.to(torch.int32) if k is not None else None + p_f32 = p.to(torch.float32) if p is not None else None + result_triton = apply_top_k_top_p_triton(logits_triton, k_i32, p_f32) + + # Compare kept counts per row + pytorch_kept = (result_pytorch != float("-inf")).sum(dim=-1) + triton_kept = (result_triton != float("-inf")).sum(dim=-1) + + if p is None: + # Top-k only: expect exact match + assert torch.equal(pytorch_kept, triton_kept), ( + f"Top-k mask mismatch: PyTorch kept {pytorch_kept.tolist()}, " + f"Triton kept {triton_kept.tolist()}" + ) + else: + # Top-p involved: allow small differences + # Either < 1% of kept values OR < 5 values absolute + max_diff = (pytorch_kept - triton_kept).abs().max().item() + max_kept = pytorch_kept.max().item() + 
if max_kept > 0 and max_diff > 3: + diff_pct = max_diff / max_kept * 100 + assert diff_pct < 0.5, ( + f"Top-p mask difference too large: {diff_pct:.2f}% " + f"(max diff {max_diff} values out of {max_kept})" + ) + + @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024]) + @pytest.mark.parametrize("vocab_size", [1024, 32000, 128256]) + def test_topk_only(self, batch_size: int, vocab_size: int): + """Test top-k only (p=None).""" + logits = torch.randn( + batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + ) + k = torch.randint( + 1, min(100, vocab_size), (batch_size,), generator=self.generator + ) + # Randomly disable top-k for some rows (~25%) + disable_mask = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0 + k.masked_fill_(disable_mask, vocab_size) + + self._compare_results(logits, k, p=None) + + @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024]) + @pytest.mark.parametrize("vocab_size", [1024, 32000, 128256]) + def test_topp_only(self, batch_size: int, vocab_size: int): + """Test top-p only (k=None).""" + logits = torch.randn( + batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + ) + p = torch.rand(batch_size, generator=self.generator) * 0.9 + 0.1 # [0.1, 1.0] + # Randomly disable top-p for some rows (~25%) + disable_mask = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0 + p.masked_fill_(disable_mask, 1.0) + + self._compare_results(logits, k=None, p=p) + + @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024]) + @pytest.mark.parametrize("vocab_size", [1024, 32000, 128256]) + def test_topk_and_topp(self, batch_size: int, vocab_size: int): + """Test combined top-k and top-p.""" + logits = torch.randn( + batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + ) + k = torch.randint( + 1, min(100, vocab_size), (batch_size,), generator=self.generator + ) + p = torch.rand(batch_size, generator=self.generator) * 0.9 + 0.1 # [0.1, 1.0] + + # 
Randomly disable top-k for some rows (~25%) + disable_k = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0 + k.masked_fill_(disable_k, vocab_size) + # Randomly disable top-p for some rows (~25%) + disable_p = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0 + p.masked_fill_(disable_p, 1.0) + + self._compare_results(logits, k, p) + + def test_both_disabled(self): + """Test when both k and p are None (should be no-op).""" + from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton + + logits = torch.randn(32, 1024, generator=self.generator, dtype=torch.float32) + logits_clone = logits.clone() + + result = apply_top_k_top_p_triton(logits_clone, k=None, p=None) + + assert torch.equal(result, logits), "Should be no-op when both k and p are None" + + def test_extreme_k_values(self): + """Test edge cases for k values.""" + batch_size, vocab_size = 16, 1024 + logits = torch.randn( + batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + ) + + # k=1 (keep only top 1) + k = torch.ones(batch_size, dtype=torch.int32) + self._compare_results(logits.clone(), k, p=None) + + # k=vocab_size (keep all) + k = torch.full((batch_size,), vocab_size, dtype=torch.int32) + self._compare_results(logits.clone(), k, p=None) + + # Mixed extreme values + k = torch.tensor([1, vocab_size, 2, vocab_size - 1] * 4, dtype=torch.int32) + self._compare_results(logits.clone(), k, p=None) + + def test_extreme_p_values(self): + """Test edge cases for p values.""" + batch_size, vocab_size = 16, 1024 + logits = torch.randn( + batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + ) + + # p close to 0 (very restrictive) + p = torch.full((batch_size,), 0.01, dtype=torch.float32) + self._compare_results(logits.clone(), k=None, p=p) + + # p=1.0 (keep all) + p = torch.ones(batch_size, dtype=torch.float32) + self._compare_results(logits.clone(), k=None, p=p) + + # Mixed values + p = torch.tensor([0.1, 0.5, 0.9, 1.0] * 4, 
dtype=torch.float32) + self._compare_results(logits.clone(), k=None, p=p) + + def test_large_batch(self): + """Test with a large batch size.""" + batch_size, vocab_size = 512, 32000 + logits = torch.randn( + batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + ) + k = torch.randint(1, 50, (batch_size,), generator=self.generator) + p = torch.rand(batch_size, generator=self.generator) * 0.5 + 0.5 + + self._compare_results(logits, k, p) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 03da3e565e49..eaf6dcc3c520 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -11,6 +11,10 @@ from vllm.config.model import LogprobsMode from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton logger = init_logger(__name__) @@ -87,8 +91,6 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None: else: self.forward = self.forward_native - self.apply_top_k_top_p = apply_top_k_top_p - def forward_native( self, logits: torch.Tensor, @@ -101,7 +103,7 @@ def forward_native( The logits tensor may be updated in-place. """ - logits = self.apply_top_k_top_p(logits, k, p) + logits = apply_top_k_top_p(logits, k, p) logits_to_return = None if self.logprobs_mode == "processed_logits": logits_to_return = logits @@ -149,7 +151,7 @@ def forward_cpu( The logits tensor may be updated in-place. 
""" - logits = self.apply_top_k_top_p(logits, k, p) + logits = apply_top_k_top_p_pytorch(logits, k, p, allow_cpu_sync=True) logits_to_return = None if self.logprobs_mode == "processed_logits": logits_to_return = logits @@ -158,14 +160,14 @@ def forward_cpu( if len(generators) != logits.shape[0]: return compiled_random_sample(logits), logits_to_return - else: - probs = logits.softmax(dim=-1, dtype=torch.float32) - q = torch.empty_like(probs) - q.exponential_() - for i, generator in generators.items(): - q[i].exponential_(generator=generator) - return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return + probs = logits.softmax(dim=-1, dtype=torch.float32) + q = torch.empty_like(probs) + q.exponential_() + for i, generator in generators.items(): + q[i].exponential_(generator=generator) + + return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return def forward_hip( self, @@ -241,9 +243,28 @@ def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor: def apply_top_k_top_p( + logits: torch.Tensor, k: torch.Tensor | None, p: torch.Tensor | None +) -> torch.Tensor: + if p is None and k is None: + return logits + + # Rough empirical heuristic + if HAS_TRITON: + batch_size, vocab_size = logits.shape + both_k_and_p = p is not None and k is not None + threshold = vocab_size // (1024 if both_k_and_p else 2048) + if batch_size >= threshold: + # Use pytorch sort implementation for smaller batch sizes. + return apply_top_k_top_p_triton(logits, k, p) + + return apply_top_k_top_p_pytorch(logits, k, p) + + +def apply_top_k_top_p_pytorch( logits: torch.Tensor, k: torch.Tensor | None, p: torch.Tensor | None, + allow_cpu_sync: bool = False, ) -> torch.Tensor: """Apply top-k and top-p masks to the logits. @@ -256,8 +277,9 @@ def apply_top_k_top_p( if k is None: return logits - # Avoid sorting vocab for top-k only case. - return apply_top_k_only(logits, k) + if allow_cpu_sync: + # Avoid sorting vocab for top-k only case. 
+ return apply_top_k_only(logits, k) logits_sort, logits_idx = logits.sort(dim=-1, descending=False) @@ -279,18 +301,16 @@ def apply_top_k_top_p( logits_sort.masked_fill_(top_p_mask, -float("inf")) # Re-sort the probabilities. - logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) - return logits + return logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) -def apply_top_k_only( - logits: torch.Tensor, - k: torch.Tensor, -) -> torch.Tensor: +def apply_top_k_only(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor: """ Apply top-k mask to the logits. This implementation doesn't involve sorting the entire vocab. + Note however that it involves a GPU->CPU sync which can be detrimental for + async scheduling performance. The logits tensor may be updated in-place. """ @@ -304,8 +324,7 @@ def apply_top_k_only( top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long()) # Handle non-topk rows. top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf")) - logits.masked_fill_(logits < top_k_mask, -float("inf")) - return logits + return logits.masked_fill_(logits < top_k_mask, -float("inf")) def random_sample( diff --git a/vllm/v1/sample/ops/topk_topp_triton.py b/vllm/v1/sample/ops/topk_topp_triton.py new file mode 100644 index 000000000000..ac759944bc80 --- /dev/null +++ b/vllm/v1/sample/ops/topk_topp_triton.py @@ -0,0 +1,383 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Combined Top-K and Top-P Triton kernels. + +These kernels apply top-k filtering first, then top-p on the remaining values. +This is more efficient than sorting the entire vocabulary. + +Algorithm: +1. Find k-th largest logit using binary search → top-k threshold +2. Mask logits below threshold, compute softmax (only k values contribute) +3. Find probability threshold for top-p using binary search +4. 
Apply final mask + +Complexity: O(vocab_size * (k_iters + p_iters)) where iters ≈ 16-20 +""" + +import torch + +from vllm.triton_utils import tl, triton + + +@triton.jit +def _topk_topp_kernel( + # Input/output + logits_ptr, + # Parameters per row + k_ptr, + p_ptr, + # Dimensions + logits_stride: tl.constexpr, + vocab_size: tl.constexpr, + # Mask value + mask_value: tl.constexpr, + # Algorithm parameters + BLOCK_SIZE: tl.constexpr, + K_ITERS: tl.constexpr, + P_ITERS: tl.constexpr, + # Feature flags (when False, use default values instead of loading) + TOPK_ENABLED: tl.constexpr, + TOPP_ENABLED: tl.constexpr, +): + """ + Combined top-k and top-p masking kernel. + + Applies top-k first (by logit value), then top-p (by probability). + Optimized to skip softmax computation when p >= 1.0. + """ + row_idx = tl.program_id(0) + + if TOPK_ENABLED: + k = tl.load(k_ptr + row_idx) + apply_topk = k < vocab_size + else: + # Default: keep all (no top-k filtering) + k = vocab_size + apply_topk = False + + if TOPP_ENABLED: + p = tl.load(p_ptr + row_idx) + apply_topp = p < 1.0 + else: + # Default: keep all (no top-p filtering) + p = 1.0 + apply_topp = False + + # Early exit if nothing to do + if (not apply_topk) and (not apply_topp): + return + + row_ptr = logits_ptr + row_idx * logits_stride + + # ========================================================================= + # Phase 1: Find top-k threshold using binary search on logits + # OPTIMIZATION: Fuse min/max finding with first binary search iteration + # by counting values > 0 during min/max pass (saves 1 memory pass) + # ========================================================================= + + topk_threshold = float("-inf") + + if apply_topk: + # Fused pass: find min/max AND count values > 0 (first binary search step) + max_logit = float("-inf") + min_logit = float("inf") + count_above_zero = tl.zeros([1], dtype=tl.int32) + + for i in range(0, vocab_size, BLOCK_SIZE): + offs = i + tl.arange(0, BLOCK_SIZE) + mask = offs < 
vocab_size + vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf")) + max_logit = tl.maximum(max_logit, tl.max(vals)) + vals_min = tl.where(mask, vals, float("inf")) + min_logit = tl.minimum(min_logit, tl.min(vals_min)) + # Count values > 0 (fused first binary search iteration) + count_above_zero += tl.sum((vals > 0.0).to(tl.int32)) + + # Use count_above_zero to set initial bounds (equivalent to first iteration) + # If count_above_zero >= k, the k-th largest is > 0, so raise lo to 0 + # Otherwise, the k-th largest is <= 0, so lower hi to 0 + if tl.sum(count_above_zero) >= k: + lo = 0.0 + hi = max_logit + else: + lo = min_logit + hi = 0.0 + + # Continue with remaining K_ITERS-1 binary search iterations + for _ in range(K_ITERS - 1): + mid = (lo + hi) * 0.5 + count_gt = tl.zeros([1], dtype=tl.int32) + for i in range(0, vocab_size, BLOCK_SIZE): + offs = i + tl.arange(0, BLOCK_SIZE) + mask = offs < vocab_size + vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf")) + count_gt += tl.sum((vals > mid).to(tl.int32)) + if tl.sum(count_gt) >= k: + lo = mid + else: + hi = mid + + # Refine to exact k-th largest value. + # After binary search: lo < k-th value <= hi (approximately). + # Find the actual logit values at these boundaries. 
    # -------------------------------------------------------------------------
    # Phase 1 (refinement): resolve the exact top-k logit threshold.
    # NOTE(review): `lo`/`hi` come from the binary search above (outside this
    # chunk); presumably count(vals > lo) >= k and count(vals > hi) < k, the
    # invariant mirrored by the top-p refinement below — confirm against the
    # preceding phase. One more pass computes, for this row:
    #   count_gt_lo        - how many logits are strictly greater than lo
    #   min_above_lo       - smallest logit strictly greater than lo
    #   max_at_or_below_hi - largest logit at or below hi
    count_gt_lo = tl.zeros([1], dtype=tl.int32)
    min_above_lo = float("inf")
    max_at_or_below_hi = float("-inf")
    for i in range(0, vocab_size, BLOCK_SIZE):
        offs = i + tl.arange(0, BLOCK_SIZE)
        mask = offs < vocab_size
        # Out-of-range lanes load -inf so they never affect counts or maxima.
        vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf"))
        count_gt_lo += tl.sum((vals > lo).to(tl.int32))
        # Excluded lanes become +inf so tl.min selects the smallest kept value.
        vals_above_lo = tl.where(vals > lo, vals, float("inf"))
        min_above_lo = tl.minimum(min_above_lo, tl.min(vals_above_lo))
        vals_at_or_below_hi = tl.where(vals <= hi, vals, float("-inf"))
        max_at_or_below_hi = tl.maximum(
            max_at_or_below_hi, tl.max(vals_at_or_below_hi)
        )

    # If exactly k logits sit strictly above lo, the k-th largest logit is the
    # smallest of them; otherwise the threshold is the largest logit <= hi.
    if tl.sum(count_gt_lo) == k:
        topk_threshold = min_above_lo
    else:
        topk_threshold = max_at_or_below_hi

    # =========================================================================
    # If no top-p, apply top-k mask and return early
    # =========================================================================

    if not apply_topp:
        for i in range(0, vocab_size, BLOCK_SIZE):
            offs = i + tl.arange(0, BLOCK_SIZE)
            mask = offs < vocab_size
            vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf"))
            # Keep logits >= threshold; everything else becomes mask_value.
            result = tl.where(vals >= topk_threshold, vals, mask_value)
            tl.store(row_ptr + offs, result, mask=mask)
        return

    # =========================================================================
    # Phase 2: Compute softmax using online softmax (single pass)
    # =========================================================================
    # Online softmax computes max and exp_sum in one pass by rescaling
    # the running sum when a new max is found.
    #
    # Key insight: We need to handle the case where softmax_max is -inf
    # (no valid values seen yet). In this case, -inf - (-inf) = nan,
    # so we must skip blocks with no valid values.

    softmax_max = float("-inf")
    exp_sum = 0.0

    for i in range(0, vocab_size, BLOCK_SIZE):
        offs = i + tl.arange(0, BLOCK_SIZE)
        mask = offs < vocab_size
        vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf"))
        # Apply top-k mask: values below the threshold contribute nothing
        # (exp(-inf) == 0) to the softmax normalizer.
        vals = tl.where(vals >= topk_threshold, vals, float("-inf"))

        # Find block max
        block_max = tl.max(vals)

        # Skip blocks with no valid values (all -inf)
        # This avoids nan from -inf - (-inf)
        if block_max > float("-inf"):
            # Update running max and rescale sum if needed
            new_max = tl.maximum(softmax_max, block_max)

            # Rescale previous sum: sum * exp(old_max - new_max)
            # When softmax_max is -inf (first valid block), exp(-inf - finite) = 0,
            # which is correct since exp_sum starts at 0.
            exp_sum = exp_sum * tl.exp(softmax_max - new_max)
            softmax_max = new_max

            # Add current block's contribution (normalized by new max)
            exp_sum += tl.sum(tl.exp(vals - softmax_max))

    log_exp_sum = tl.log(exp_sum)

    # =========================================================================
    # Phase 3: Find top-p threshold using binary search on probabilities
    # OPTIMIZATION: Fuse min/max finding with first binary search iteration
    # by computing prob mass > 0.5 during min/max pass (saves 1 memory pass)
    # =========================================================================

    # Fused pass: find min/max log-probs AND sum probs > 0.5 (first iteration)
    max_log_prob = float("-inf")
    min_log_prob = float("inf")
    log_half = -0.6931471805599453  # log(0.5)
    prob_sum_above_half = 0.0

    for i in range(0, vocab_size, BLOCK_SIZE):
        offs = i + tl.arange(0, BLOCK_SIZE)
        mask = offs < vocab_size
        vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf"))

        # Only consider top-k values
        is_topk = vals >= topk_threshold

        # log_prob = logit - softmax_max - log(exp_sum)
        log_probs = vals - softmax_max - log_exp_sum

        log_probs_masked = tl.where(is_topk, log_probs, float("-inf"))
        max_log_prob = tl.maximum(max_log_prob, tl.max(log_probs_masked))

        # `& mask` also excludes out-of-range lanes (their other=-inf load would
        # otherwise pass an `is_topk`-style test after the subtraction).
        log_probs_for_min = tl.where(is_topk & mask, log_probs, float("inf"))
        min_log_prob = tl.minimum(min_log_prob, tl.min(log_probs_for_min))

        # Sum probability mass above 0.5 (fused first binary search iteration)
        probs = tl.exp(log_probs)
        above_half = (log_probs > log_half) & is_topk
        prob_sum_above_half += tl.sum(tl.where(above_half, probs, 0.0))

    # Use prob_sum_above_half to set initial bounds (equivalent to first iteration)
    if prob_sum_above_half >= p:
        lo_lp = log_half
        hi_lp = max_log_prob
    else:
        lo_lp = min_log_prob
        hi_lp = log_half

    # Continue with remaining P_ITERS-1 binary search iterations
    for _ in range(P_ITERS - 1):
        mid_lp = (lo_lp + hi_lp) * 0.5

        # Sum probabilities strictly > mid_lp
        prob_sum_gt = 0.0
        for i in range(0, vocab_size, BLOCK_SIZE):
            offs = i + tl.arange(0, BLOCK_SIZE)
            mask = offs < vocab_size
            vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf"))

            is_topk = vals >= topk_threshold
            log_probs = vals - softmax_max - log_exp_sum
            probs = tl.exp(log_probs)

            # Only sum probs that are strictly > threshold and in top-k
            above_threshold = (log_probs > mid_lp) & is_topk
            prob_sum_gt += tl.sum(tl.where(above_threshold, probs, 0.0))

        # If sum of probs strictly above mid >= p, raise threshold
        if prob_sum_gt >= p:
            lo_lp = mid_lp
        else:
            hi_lp = mid_lp

    # Refine to exact threshold using combined approach (same as top-k).
    # After binary search: prob_sum(> lo_lp) >= p, prob_sum(> hi_lp) < p.
    # Count how many distinct log-probs are > lo_lp to determine which refinement.
    # Refinement pass for top-p, mirroring the Phase 1 top-k refinement but in
    # log-probability space, and restricted to values that survived top-k.
    count_gt_lo_lp = tl.zeros([1], dtype=tl.int32)
    min_lp_above_lo = float("inf")
    max_lp_at_or_below_hi = float("-inf")
    for i in range(0, vocab_size, BLOCK_SIZE):
        offs = i + tl.arange(0, BLOCK_SIZE)
        mask = offs < vocab_size
        vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf"))

        is_topk = vals >= topk_threshold
        log_probs = vals - softmax_max - log_exp_sum

        above_lo = is_topk & (log_probs > lo_lp)
        count_gt_lo_lp += tl.sum(above_lo.to(tl.int32))

        # +inf sentinel so tl.min picks the smallest log-prob above lo_lp.
        lp_above_lo = tl.where(above_lo, log_probs, float("inf"))
        min_lp_above_lo = tl.minimum(min_lp_above_lo, tl.min(lp_above_lo))

        at_or_below_hi = is_topk & (log_probs <= hi_lp)
        lp_at_or_below_hi = tl.where(at_or_below_hi, log_probs, float("-inf"))
        max_lp_at_or_below_hi = tl.maximum(
            max_lp_at_or_below_hi, tl.max(lp_at_or_below_hi)
        )

    # For top-p, use min if there are values > lo, otherwise use max.
    # This handles edge cases where lo/hi converge to the same side.
    if tl.sum(count_gt_lo_lp) > 0 and min_lp_above_lo < float("inf"):
        topp_log_threshold = min_lp_above_lo
    else:
        topp_log_threshold = max_lp_at_or_below_hi

    # =========================================================================
    # Phase 4: Apply combined mask
    # =========================================================================

    for i in range(0, vocab_size, BLOCK_SIZE):
        offs = i + tl.arange(0, BLOCK_SIZE)
        mask = offs < vocab_size
        vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf"))

        # Apply top-k mask
        keep = vals >= topk_threshold

        # Apply top-p mask (>= keeps the value that sits exactly on the
        # threshold, so the retained mass is at least p)
        log_probs = vals - softmax_max - log_exp_sum
        keep = keep & (log_probs >= topp_log_threshold)

        result = tl.where(keep, vals, mask_value)
        tl.store(row_ptr + offs, result, mask=mask)


def apply_top_k_top_p_triton(
    logits: torch.Tensor,
    k: torch.Tensor | None,
    p: torch.Tensor | None,
    mask_value: float = float("-inf"),
) -> torch.Tensor:
    """
    Apply combined top-k and top-p masking using Triton.

    Top-k is applied first (by logit value), then top-p is applied
    to the remaining k values (by probability).

    Args:
        logits: [n, vocab_size] float32 tensor, modified in-place
        k: [n] int32 tensor of top-k values per row, or None to disable top-k
        p: [n] float32 tensor of top-p values per row (0 to 1),
            or None to disable top-p
        mask_value: Value for masked positions (default: -inf)

    Returns:
        The logits tensor (modified in-place)
    """
    assert logits.ndim == 2
    assert logits.dtype == torch.float32
    assert logits.is_cuda

    n, vocab_size = logits.shape

    topk_enabled = k is not None
    topp_enabled = p is not None

    # Nothing to do for an empty batch or when both filters are disabled.
    if n == 0 or not (topk_enabled or topp_enabled):
        return logits

    if k is not None:
        assert k.ndim == 1 and k.shape[0] == n and k.is_cuda
        # .to() is a no-op (same tensor) when k is already int32.
        k_ptr = k.to(torch.int32)
    else:
        k_ptr = logits  # Dummy pointer (won't be read)

    if p is not None:
        assert p.ndim == 1 and p.shape[0] == n and p.is_cuda
        p_ptr = p.to(torch.float32)
    else:
        p_ptr = logits  # Dummy pointer (won't be read)

    BLOCK_SIZE = 1024
    # K_ITERS must be large enough to distinguish adjacent logit values.
    # With randn logits (range ~8), 18 iterations gives precision ~8/2^18 ≈ 3e-5
    K_ITERS = 18
    P_ITERS = 14

    # One program instance per row; each program scans its row in
    # BLOCK_SIZE-sized tiles.
    _topk_topp_kernel[(n,)](
        logits,
        k_ptr,
        p_ptr,
        logits_stride=logits.stride(0),
        vocab_size=vocab_size,
        mask_value=mask_value,
        BLOCK_SIZE=BLOCK_SIZE,
        K_ITERS=K_ITERS,
        P_ITERS=P_ITERS,
        TOPK_ENABLED=topk_enabled,
        TOPP_ENABLED=topp_enabled,
    )

    return logits