From a5fc250ea2b79e1b44e79c8419e8905578ccdc26 Mon Sep 17 00:00:00 2001 From: js_park Date: Tue, 23 Sep 2025 23:23:41 -0700 Subject: [PATCH 01/99] Attempt 1 Signed-off-by: js_park --- test.py | 23 +++ vllm/v1/sample/ops/topk_topp_sampler.py | 247 +++++++++++++++++++++--- 2 files changed, 243 insertions(+), 27 deletions(-) create mode 100644 test.py diff --git a/test.py b/test.py new file mode 100644 index 000000000000..4e4800402996 --- /dev/null +++ b/test.py @@ -0,0 +1,23 @@ +from vllm import LLM, SamplingParams + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +prompts = prompts * 64 +sampling_params = SamplingParams(temperature=0.8, top_p=0.99999, top_k=10) + +llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct") +# llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite") + +outputs = llm.generate(prompts, sampling_params) + +for i, output in enumerate(outputs): + if i > 4: + break + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 747e52f2e589..f98f80f8c630 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -1,10 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os +from datetime import timedelta +from tkinter import NO from typing import Optional import torch import torch.nn as nn +import triton +import triton.language as tl +import time from packaging import version from vllm import envs @@ -19,6 +25,19 @@ is_flashinfer_available = True except ImportError: is_flashinfer_available = False + +def g_str(s): + return "\033[32m" + s + "\033[0m" +def r_str(s): + return "\033[31m" + s + "\033[0m" +def b_str(s): + return "\033[34m" + s + "\033[0m" +def y_str(s): + return "\033[33m" 
+ s + "\033[0m" +def c_str(s): + return "\033[36m" + s + "\033[0m" +def m_str(s): + return "\033[35m" + s + "\033[0m" class TopKTopPSampler(nn.Module): @@ -120,46 +139,220 @@ def forward_cuda( return flashinfer_sample(logits.contiguous(), k, p, generators), None -def apply_top_k_top_p( +def original_apply_top_k_top_p( logits: torch.Tensor, k: Optional[torch.Tensor], p: Optional[torch.Tensor], ) -> torch.Tensor: """Apply top-k and top-p masks to the logits. - - If a top-p is used, this function will sort the logits tensor, - which can be slow for large batches. - - The logits tensor may be updated in-place. """ + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) + if p is None: if k is None: return logits # Avoid sorting vocab for top-k only case. - return apply_top_k_only(logits, k) + logits = apply_top_k_only(logits, k) + else: + if k is not None: + # Apply top-k. + top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B + # Get all the top_k values. + top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) + top_k_mask = logits_sort < top_k_mask + logits_sort.masked_fill_(top_k_mask, -float("inf")) + + if p is not None: + # Apply top-p. + probs_sort = logits_sort.softmax(dim=-1) + probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort) + top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) + # at least one + top_p_mask[:, -1] = False + logits_sort.masked_fill_(top_p_mask, -float("inf")) + + # Re-sort the probabilities. 
+ logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) + return logits - logits_sort, logits_idx = logits.sort(dim=-1, descending=False) +@triton.jit +def _topk_topp_kernel(LOGITS, PROBS, K, P, B, + N: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + NUM_TILES: tl.constexpr, + NUM_PIVOTS: tl.constexpr): + pid = tl.program_id(0) + num_programs = tl.num_programs(0) + for row_id in tl.range(pid, B, num_programs): + k = tl.load(K + row_id) + p = tl.load(P + row_id) + if not (k == N and p == 1.0): # All tokens are valid + max_logit = -float('inf') + min_logit = float('inf') + + max_prob = 0.0 + min_prob = 1.0 + + # First pass: compute max and min logits (for numerical stability) + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < N + logits_blk = tl.load(LOGITS + offs_n, mask=mask_n, other=0.0) + + max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + + # Second pass: compute probabilities using softmax + # (This requires the max for numerical stability) + exp_logits_sum = 0.0 + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < N + logits_blk = tl.load(LOGITS + offs_n, mask=mask_n, other=-float('inf')) + + logits_tile_stable = logits_blk - max_logit # Numerical stability + exp_logits = tl.exp(logits_tile_stable) + exp_logits_sum += tl.sum(exp_logits) + tl.store(PROBS + offs_n, exp_logits) + + # Third pass: compute probabilities and update max and min probabilities + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < N + probs_blk = tl.load(PROBS + offs_n, mask=mask_n, other=0.0) + probs_blk = probs_blk / exp_logits_sum + max_prob = tl.maximum(max_prob, tl.max(probs_blk)) + min_prob = tl.minimum(min_prob, tl.min(probs_blk)) + tl.store(PROBS + offs_n, probs_blk) + + # Fourth passes: Search for pivots + num_iters = 0 + k_pivot = -float('inf') + 
k_pivots = tl.full((NUM_PIVOTS,), -float('inf'), dtype=tl.float32) + k_pivots_num = tl.full((NUM_PIVOTS,), 0, dtype=tl.uint32) + + p_pivot = 0.0 + p_pivots = tl.full((NUM_PIVOTS,), -float('inf'), dtype=tl.float32) + p_pivots_sum = tl.full((NUM_PIVOTS,), 0.0, dtype=tl.float32) + + while (k_pivot == -float('inf') or p_pivot == 0.0) and num_iters < 32: + k_pivots = (max_logit - min_logit) * tl.arange(1, NUM_PIVOTS + 1) / NUM_PIVOTS + min_logit + p_pivots = (max_prob - min_prob) * tl.arange(1, NUM_PIVOTS + 1) / NUM_PIVOTS + min_prob + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < N + logits_blk = tl.load(LOGITS + offs_n, mask=mask_n, other=-float('inf')) + probs_blk = tl.load(PROBS + offs_n, mask=mask_n, other=0.0) + + logits_expanded = logits_blk[None, :] # shape: 1 x BLOCK_SIZE + k_pivots_expanded = k_pivots[:, None] # shape: NUM_PIVOTS x 1 + larger_mask = logits_expanded > k_pivots_expanded # shape: NUM_PIVOTS x BLOCK_SIZE + k_pivots_num += tl.sum(larger_mask, axis=1) # shape: NUM_PIVOTS + + probs_expanded = probs_blk[None, :] # shape: 1 x BLOCK_SIZE + p_pivots_expanded = p_pivots[:, None] # shape: NUM_PIVOTS x 1 + larger_mask = probs_expanded > p_pivots_expanded # shape: NUM_PIVOTS x BLOCK_SIZE + larger_probs = tl.where(larger_mask, probs_expanded, 0.0) # shape: NUM_PIVOTS x BLOCK_SIZE + p_pivots_sum += tl.sum(larger_probs, axis=1) # shape: NUM_PIVOTS + + exact_match_k = k_pivots_num == k + if tl.sum(exact_match_k) > 0: + matches = tl.where(exact_match_k, k_pivots, float('inf')) + k_pivot = tl.min(matches) + else: + smaller_mask = k_pivots_num < k + if tl.sum(smaller_mask) > 0: + small_indices = tl.where(smaller_mask, k_pivots, float('inf')) + max_logit = tl.min(small_indices) + larger_mask = k_pivots_num > k + if tl.sum(larger_mask) > 0: + large_indices = tl.where(larger_mask, k_pivots, -float('inf')) + min_logit = tl.max(large_indices) + + exact_match_p = tl.abs(p_pivots_sum - p) < 1e-6 + if 
tl.sum(exact_match_p) > 0: + match_indices = tl.where(exact_match_p, p_pivots, float('inf')) + p_pivot = tl.min(match_indices) + else: + smaller_mask = p_pivots_sum < p + if tl.sum(smaller_mask) > 0: + small_indices = tl.where(smaller_mask, p_pivots, float('inf')) + max_prob = tl.min(small_indices) + larger_mask = p_pivots_sum > p + if tl.sum(larger_mask) > 0: + large_indices = tl.where(larger_mask, p_pivots, -float('inf')) + min_prob = tl.max(large_indices) + # For the case where sum of existing probabilities does not hit p + if min_prob == max_prob: + p_pivot = min_prob + + num_iters += 1 + + # Fifth pass: Apply top-k and top-p masks + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < N + logits_blk = tl.load(LOGITS + offs_n, mask=mask_n, other=-float('inf')) + probs_blk = tl.load(PROBS + offs_n, mask=mask_n, other=0.0) + logits_blk = tl.where(logits_blk > k_pivot, logits_blk, -float('inf')) + logits_blk = tl.where(probs_blk > p_pivot, logits_blk, -float('inf')) + tl.store(LOGITS + offs_n, logits_blk, mask=mask_n) + + +def triton_apply_top_k_top_p( + logits: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], +) -> torch.Tensor: + batch_size, vocab_size = logits.shape + BLOCK_SIZE = 4096 + NUM_PROGRAMS = 128 + NUM_PIVOTS = 4 # Multi pivot search for smaller number of scans + NUM_TILES = (vocab_size + BLOCK_SIZE - 1) // BLOCK_SIZE + probs = torch.zeros_like(logits) + _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, + vocab_size, BLOCK_SIZE, NUM_TILES, NUM_PIVOTS) + return logits - if k is not None: - # Apply top-k. - top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B - # Get all the top_k values. - top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) - top_k_mask = logits_sort < top_k_mask - logits_sort.masked_fill_(top_k_mask, -float("inf")) - - if p is not None: - # Apply top-p. 
- probs_sort = logits_sort.softmax(dim=-1) - probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort) - top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) - # at least one - top_p_mask[:, -1] = False - logits_sort.masked_fill_(top_p_mask, -float("inf")) - - # Re-sort the probabilities. - logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) +@torch.compile +def compiled_apply_top_k_top_p( + logits: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], +) -> torch.Tensor: + return original_apply_top_k_top_p(logits, k, p) + +def apply_top_k_top_p( + logits: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], +) -> torch.Tensor: + """Apply top-k and top-p masks to the logits. + + If a top-p is used, this function will sort the logits tensor, + which can be slow for large batches. + + The logits tensor may be updated in-place. + """ + torch.cuda.synchronize() + start_time = time.time() + batch_size, vocab_size = logits.shape + print(g_str("apply_top_k_top_p") + f" logits.shape: {batch_size} x {vocab_size}, p is None: {p is None}, k is None: {k is None}") + input_logits = logits.clone() + + # logits = original_apply_top_k_top_p(logits, k, p) + # logits = compiled_apply_top_k_top_p(logits, k, p) + logits = triton_apply_top_k_top_p(logits, k, p) + + torch.cuda.synchronize() + time_taken = time.time() - start_time + print(y_str(f"apply_top_k_top_p done in {time_taken} seconds")) + start_time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime(start_time)) + out_dir = "./sampler_input_output" + os.makedirs(out_dir, exist_ok=True) + out_path = f"{out_dir}/llama8b_{start_time_str}.pt" + torch.save({"input_logits": input_logits, "p": p, "k": k, "output_logits": logits}, out_path) return logits From c95041b31fea603307ab1304d1320a33c24f8729 Mon Sep 17 00:00:00 2001 From: js_park Date: Wed, 24 Sep 2025 00:49:27 -0700 Subject: [PATCH 02/99] Top k works? 
Signed-off-by: js_park --- test.py | 2 +- vllm/v1/sample/ops/topk_topp_sampler.py | 157 +++++++++++++++++++----- 2 files changed, 129 insertions(+), 30 deletions(-) diff --git a/test.py b/test.py index 4e4800402996..5d4f4c99db00 100644 --- a/test.py +++ b/test.py @@ -7,7 +7,7 @@ "The future of AI is", ] prompts = prompts * 64 -sampling_params = SamplingParams(temperature=0.8, top_p=0.99999, top_k=10) +sampling_params = SamplingParams(temperature=0.8, top_k=10) llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct") # llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite") diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index f98f80f8c630..3f76345da60c 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -194,11 +194,14 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, max_prob = 0.0 min_prob = 1.0 + LOGITS_ROW = LOGITS + row_id * N + PROBS_ROW = PROBS + row_id * N + # First pass: compute max and min logits (for numerical stability) for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < N - logits_blk = tl.load(LOGITS + offs_n, mask=mask_n, other=0.0) + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=0.0) max_logit = tl.maximum(max_logit, tl.max(logits_blk)) min_logit = tl.minimum(min_logit, tl.min(logits_blk)) @@ -209,41 +212,42 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < N - logits_blk = tl.load(LOGITS + offs_n, mask=mask_n, other=-float('inf')) + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) logits_tile_stable = logits_blk - max_logit # Numerical stability exp_logits = tl.exp(logits_tile_stable) exp_logits_sum += tl.sum(exp_logits) - tl.store(PROBS + offs_n, exp_logits) + tl.store(PROBS_ROW + offs_n, exp_logits, mask=mask_n) # Third pass: compute probabilities and update max and min 
probabilities for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < N - probs_blk = tl.load(PROBS + offs_n, mask=mask_n, other=0.0) + probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) probs_blk = probs_blk / exp_logits_sum max_prob = tl.maximum(max_prob, tl.max(probs_blk)) min_prob = tl.minimum(min_prob, tl.min(probs_blk)) - tl.store(PROBS + offs_n, probs_blk) + tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) # Fourth passes: Search for pivots num_iters = 0 k_pivot = -float('inf') k_pivots = tl.full((NUM_PIVOTS,), -float('inf'), dtype=tl.float32) - k_pivots_num = tl.full((NUM_PIVOTS,), 0, dtype=tl.uint32) p_pivot = 0.0 p_pivots = tl.full((NUM_PIVOTS,), -float('inf'), dtype=tl.float32) - p_pivots_sum = tl.full((NUM_PIVOTS,), 0.0, dtype=tl.float32) while (k_pivot == -float('inf') or p_pivot == 0.0) and num_iters < 32: k_pivots = (max_logit - min_logit) * tl.arange(1, NUM_PIVOTS + 1) / NUM_PIVOTS + min_logit p_pivots = (max_prob - min_prob) * tl.arange(1, NUM_PIVOTS + 1) / NUM_PIVOTS + min_prob + + k_pivots_num = tl.full((NUM_PIVOTS,), 0, dtype=tl.uint32) + p_pivots_sum = tl.full((NUM_PIVOTS,), 0.0, dtype=tl.float32) for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < N - logits_blk = tl.load(LOGITS + offs_n, mask=mask_n, other=-float('inf')) - probs_blk = tl.load(PROBS + offs_n, mask=mask_n, other=0.0) + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) + probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) logits_expanded = logits_blk[None, :] # shape: 1 x BLOCK_SIZE k_pivots_expanded = k_pivots[:, None] # shape: NUM_PIVOTS x 1 @@ -263,26 +267,26 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, else: smaller_mask = k_pivots_num < k if tl.sum(smaller_mask) > 0: - small_indices = tl.where(smaller_mask, k_pivots, float('inf')) - max_logit = tl.min(small_indices) + matches = tl.where(smaller_mask, k_pivots, float('inf')) + 
max_logit = tl.min(matches) larger_mask = k_pivots_num > k if tl.sum(larger_mask) > 0: - large_indices = tl.where(larger_mask, k_pivots, -float('inf')) - min_logit = tl.max(large_indices) + matches = tl.where(larger_mask, k_pivots, -float('inf')) + min_logit = tl.max(matches) exact_match_p = tl.abs(p_pivots_sum - p) < 1e-6 if tl.sum(exact_match_p) > 0: - match_indices = tl.where(exact_match_p, p_pivots, float('inf')) - p_pivot = tl.min(match_indices) + matches = tl.where(exact_match_p, p_pivots, float('inf')) + p_pivot = tl.min(matches) else: smaller_mask = p_pivots_sum < p if tl.sum(smaller_mask) > 0: - small_indices = tl.where(smaller_mask, p_pivots, float('inf')) - max_prob = tl.min(small_indices) + matches = tl.where(smaller_mask, p_pivots, float('inf')) + max_prob = tl.min(matches) larger_mask = p_pivots_sum > p if tl.sum(larger_mask) > 0: - large_indices = tl.where(larger_mask, p_pivots, -float('inf')) - min_prob = tl.max(large_indices) + matches = tl.where(larger_mask, p_pivots, -float('inf')) + min_prob = tl.max(matches) # For the case where sum of existing probabilities does not hit p if min_prob == max_prob: p_pivot = min_prob @@ -293,12 +297,79 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < N - logits_blk = tl.load(LOGITS + offs_n, mask=mask_n, other=-float('inf')) - probs_blk = tl.load(PROBS + offs_n, mask=mask_n, other=0.0) + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n) + probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n) logits_blk = tl.where(logits_blk > k_pivot, logits_blk, -float('inf')) logits_blk = tl.where(probs_blk > p_pivot, logits_blk, -float('inf')) - tl.store(LOGITS + offs_n, logits_blk, mask=mask_n) + tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) +@triton.jit +def _topk_kernel(LOGITS, PROBS, K, B, + N: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + NUM_TILES: tl.constexpr, + NUM_PIVOTS: tl.constexpr): + pid = 
tl.program_id(0) + num_programs = tl.num_programs(0) + for row_id in tl.range(pid, B, num_programs): + k = tl.load(K + row_id) + if not (k == N): # All tokens are valid + max_logit = -float('inf') + min_logit = float('inf') + + LOGITS_ROW = LOGITS + row_id * N + + # First pass: compute max and min logits (for numerical stability) + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < N + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=0.0) + + max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + + # Fourth passes: Search for pivots + num_iters = 0 + k_pivot = -float('inf') + k_pivots = tl.full((NUM_PIVOTS,), -float('inf'), dtype=tl.float32) + + while (k_pivot == -float('inf')) and num_iters < 32: + k_pivots = (max_logit - min_logit) * tl.arange(1, NUM_PIVOTS + 1) / NUM_PIVOTS + min_logit + + k_pivots_num = tl.full((NUM_PIVOTS,), 0, dtype=tl.uint32) + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < N + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) + + logits_expanded = logits_blk[None, :] # shape: 1 x BLOCK_SIZE + k_pivots_expanded = k_pivots[:, None] # shape: NUM_PIVOTS x 1 + larger_mask = logits_expanded > k_pivots_expanded # shape: NUM_PIVOTS x BLOCK_SIZE + k_pivots_num += tl.sum(larger_mask, axis=1) # shape: NUM_PIVOTS + + exact_match_k = k_pivots_num == k + if tl.sum(exact_match_k) > 0: + matches = tl.where(exact_match_k, k_pivots, float('inf')) + k_pivot = tl.min(matches) + else: + smaller_mask = k_pivots_num < k + if tl.sum(smaller_mask) > 0: + matches = tl.where(smaller_mask, k_pivots, float('inf')) + max_logit = tl.min(matches) + larger_mask = k_pivots_num > k + if tl.sum(larger_mask) > 0: + matches = tl.where(larger_mask, k_pivots, -float('inf')) + min_logit = tl.max(matches) + + num_iters += 1 + + # Fifth pass: Apply top-k mask + for i in range(0, 
NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < N + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n) + logits_blk = tl.where(logits_blk > k_pivot, logits_blk, -float('inf')) + tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) def triton_apply_top_k_top_p( logits: torch.Tensor, @@ -308,12 +379,16 @@ def triton_apply_top_k_top_p( batch_size, vocab_size = logits.shape BLOCK_SIZE = 4096 NUM_PROGRAMS = 128 - NUM_PIVOTS = 4 # Multi pivot search for smaller number of scans + NUM_PIVOTS = 16 # Multi pivot search for smaller number of scans NUM_TILES = (vocab_size + BLOCK_SIZE - 1) // BLOCK_SIZE probs = torch.zeros_like(logits) - _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, + if p is None: + _topk_kernel[(NUM_PROGRAMS,)](logits, probs, k, batch_size, vocab_size, BLOCK_SIZE, NUM_TILES, NUM_PIVOTS) - return logits + else: + _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, + vocab_size, BLOCK_SIZE, NUM_TILES, NUM_PIVOTS) + return logits, probs @torch.compile def compiled_apply_top_k_top_p( @@ -335,19 +410,43 @@ def apply_top_k_top_p( The logits tensor may be updated in-place. 
""" - torch.cuda.synchronize() - start_time = time.time() + input_logits = logits.clone() + original_logits = original_apply_top_k_top_p(logits, k, p) + original_probs = torch.softmax(input_logits, dim=-1) + batch_size, vocab_size = logits.shape print(g_str("apply_top_k_top_p") + f" logits.shape: {batch_size} x {vocab_size}, p is None: {p is None}, k is None: {k is None}") - input_logits = logits.clone() + + torch.cuda.synchronize() + start_time = time.time() # logits = original_apply_top_k_top_p(logits, k, p) # logits = compiled_apply_top_k_top_p(logits, k, p) - logits = triton_apply_top_k_top_p(logits, k, p) + logits, probs = triton_apply_top_k_top_p(logits, k, p) torch.cuda.synchronize() time_taken = time.time() - start_time print(y_str(f"apply_top_k_top_p done in {time_taken} seconds")) + + # if not torch.allclose(probs, original_probs): + # print(r_str("Error: probs are not close")) + # print(f"probs: {probs}") + # print(f"original_probs: {original_probs}") + + logits[logits < -1e-6] = -1000 + original_logits[original_logits < -1e-6] = -1000 + if not torch.allclose(logits, original_logits): + print(r_str("Error: logits are not close")) + print(f"logits: {logits}") + print(f"original_logits: {original_logits}") + diff = (logits - original_logits).abs().flatten() + diff_nonzero = diff[diff > 1e-6] + print(f"diff_nonzero: {diff_nonzero}") + print(f"diff_nonzero.max(): {diff_nonzero.max()}") + print(f"diff_nonzero.min(): {diff_nonzero.min()}") + print(f"diff_nonzero.mean(): {diff_nonzero.mean()}") + print(f"diff_nonzero.std(): {diff_nonzero.std()}") + start_time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime(start_time)) out_dir = "./sampler_input_output" os.makedirs(out_dir, exist_ok=True) From fe60b2236c712b6b336b6bf9f1ed4e848aecce35 Mon Sep 17 00:00:00 2001 From: js_park Date: Wed, 24 Sep 2025 00:53:03 -0700 Subject: [PATCH 03/99] Top k works? 
Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 40 ++++++++++--------------- 1 file changed, 15 insertions(+), 25 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 3f76345da60c..fac7119dc3ec 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -307,8 +307,7 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, def _topk_kernel(LOGITS, PROBS, K, B, N: tl.constexpr, BLOCK_SIZE: tl.constexpr, - NUM_TILES: tl.constexpr, - NUM_PIVOTS: tl.constexpr): + NUM_TILES: tl.constexpr): pid = tl.program_id(0) num_programs = tl.num_programs(0) for row_id in tl.range(pid, B, num_programs): @@ -330,36 +329,27 @@ def _topk_kernel(LOGITS, PROBS, K, B, # Fourth passes: Search for pivots num_iters = 0 - k_pivot = -float('inf') - k_pivots = tl.full((NUM_PIVOTS,), -float('inf'), dtype=tl.float32) + pivot_found = False + k_pivot = 0.0 - while (k_pivot == -float('inf')) and num_iters < 32: - k_pivots = (max_logit - min_logit) * tl.arange(1, NUM_PIVOTS + 1) / NUM_PIVOTS + min_logit - - k_pivots_num = tl.full((NUM_PIVOTS,), 0, dtype=tl.uint32) + while not pivot_found and num_iters < 32: + k_pivot = (max_logit - min_logit) / 2.0 + k_pivots_num = 0.0 + for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < N logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) - logits_expanded = logits_blk[None, :] # shape: 1 x BLOCK_SIZE - k_pivots_expanded = k_pivots[:, None] # shape: NUM_PIVOTS x 1 - larger_mask = logits_expanded > k_pivots_expanded # shape: NUM_PIVOTS x BLOCK_SIZE - k_pivots_num += tl.sum(larger_mask, axis=1) # shape: NUM_PIVOTS + larger_mask = logits_blk > k_pivot + k_pivots_num += tl.sum(larger_mask) - exact_match_k = k_pivots_num == k - if tl.sum(exact_match_k) > 0: - matches = tl.where(exact_match_k, k_pivots, float('inf')) - k_pivot = tl.min(matches) + if k_pivots_num == k: + pivot_found = 
True + elif k_pivots_num < k: + min_logit = k_pivot else: - smaller_mask = k_pivots_num < k - if tl.sum(smaller_mask) > 0: - matches = tl.where(smaller_mask, k_pivots, float('inf')) - max_logit = tl.min(matches) - larger_mask = k_pivots_num > k - if tl.sum(larger_mask) > 0: - matches = tl.where(larger_mask, k_pivots, -float('inf')) - min_logit = tl.max(matches) + max_logit = k_pivot num_iters += 1 @@ -384,7 +374,7 @@ def triton_apply_top_k_top_p( probs = torch.zeros_like(logits) if p is None: _topk_kernel[(NUM_PROGRAMS,)](logits, probs, k, batch_size, - vocab_size, BLOCK_SIZE, NUM_TILES, NUM_PIVOTS) + vocab_size, BLOCK_SIZE, NUM_TILES) else: _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, vocab_size, BLOCK_SIZE, NUM_TILES, NUM_PIVOTS) From 74a18b5aaaa70db4375710d90cf687f8e5ea3d9f Mon Sep 17 00:00:00 2001 From: js_park Date: Wed, 24 Sep 2025 12:18:53 -0700 Subject: [PATCH 04/99] Tenary search Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 50 ++++++++++++++++--------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index fac7119dc3ec..0a5196f57be4 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -304,7 +304,7 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) @triton.jit -def _topk_kernel(LOGITS, PROBS, K, B, +def _topk_kernel(LOGITS, K, B, N: tl.constexpr, BLOCK_SIZE: tl.constexpr, NUM_TILES: tl.constexpr): @@ -315,6 +315,8 @@ def _topk_kernel(LOGITS, PROBS, K, B, if not (k == N): # All tokens are valid max_logit = -float('inf') min_logit = float('inf') + sum_logits = 0.0 + pow_sum_logits = 0.0 LOGITS_ROW = LOGITS + row_id * N @@ -327,29 +329,41 @@ def _topk_kernel(LOGITS, PROBS, K, B, max_logit = tl.maximum(max_logit, tl.max(logits_blk)) min_logit = tl.minimum(min_logit, tl.min(logits_blk)) - # Fourth passes: Search for 
pivots + # Fourth passes: Binary search for pivots num_iters = 0 - pivot_found = False - k_pivot = 0.0 + k_pivot = -float('inf') + k_pivot_0 = 0.0 + k_pivot_1 = 0.0 - while not pivot_found and num_iters < 32: - k_pivot = (max_logit - min_logit) / 2.0 - k_pivots_num = 0.0 + while k_pivot == -float('inf') and num_iters < 32: + k_pivot_0 = (max_logit - min_logit) * 1.0 / 3.0 + min_logit + k_pivot_1 = (max_logit - min_logit) * 2.0 / 3.0 + min_logit + k_pivots_num_0 = 0.0 + k_pivots_num_1 = 0.0 for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < N logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) - larger_mask = logits_blk > k_pivot - k_pivots_num += tl.sum(larger_mask) - - if k_pivots_num == k: - pivot_found = True - elif k_pivots_num < k: - min_logit = k_pivot - else: - max_logit = k_pivot + larger_mask = logits_blk > k_pivot_0 + k_pivots_num_0 += tl.sum(larger_mask) + larger_mask = logits_blk > k_pivot_1 + k_pivots_num_1 += tl.sum(larger_mask) + + if k_pivots_num_0 == k: + k_pivot = k_pivot_0 + elif k_pivots_num_1 == k: + k_pivot = k_pivot_1 + elif k_pivots_num_1 > k: + min_logit = k_pivot_1 + elif k_pivots_num_0 > k: + min_logit = k_pivot_0 + + if k_pivots_num_0 < k: + max_logit = k_pivot_0 + elif k_pivots_num_1 < k: + max_logit = k_pivot_1 num_iters += 1 @@ -369,11 +383,11 @@ def triton_apply_top_k_top_p( batch_size, vocab_size = logits.shape BLOCK_SIZE = 4096 NUM_PROGRAMS = 128 - NUM_PIVOTS = 16 # Multi pivot search for smaller number of scans NUM_TILES = (vocab_size + BLOCK_SIZE - 1) // BLOCK_SIZE + NUM_PIVOTS = 16 # Multi pivot search for smaller number of scans probs = torch.zeros_like(logits) if p is None: - _topk_kernel[(NUM_PROGRAMS,)](logits, probs, k, batch_size, + _topk_kernel[(NUM_PROGRAMS,)](logits, k, batch_size, vocab_size, BLOCK_SIZE, NUM_TILES) else: _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, From 7502c06488d5f281d54fb6cc918507bd51ca4b68 Mon Sep 17 
00:00:00 2001 From: js_park Date: Wed, 24 Sep 2025 18:32:08 -0700 Subject: [PATCH 05/99] Quadruple Search Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 53 +++++++++++++------------ 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 0a5196f57be4..6beccf34da83 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -315,8 +315,6 @@ def _topk_kernel(LOGITS, K, B, if not (k == N): # All tokens are valid max_logit = -float('inf') min_logit = float('inf') - sum_logits = 0.0 - pow_sum_logits = 0.0 LOGITS_ROW = LOGITS + row_id * N @@ -329,32 +327,35 @@ def _topk_kernel(LOGITS, K, B, max_logit = tl.maximum(max_logit, tl.max(logits_blk)) min_logit = tl.minimum(min_logit, tl.min(logits_blk)) - # Fourth passes: Binary search for pivots - num_iters = 0 + # Second passes: Binary search for pivots k_pivot = -float('inf') - k_pivot_0 = 0.0 - k_pivot_1 = 0.0 - - while k_pivot == -float('inf') and num_iters < 32: - k_pivot_0 = (max_logit - min_logit) * 1.0 / 3.0 + min_logit - k_pivot_1 = (max_logit - min_logit) * 2.0 / 3.0 + min_logit - k_pivots_num_0 = 0.0 - k_pivots_num_1 = 0.0 + num_iters = 0 + + while k_pivot == -float('inf') and num_iters < 18: + k_pivot_0 = (max_logit - min_logit) * 1.0 / 4.0 + min_logit + k_pivot_1 = (max_logit - min_logit) * 2.0 / 4.0 + min_logit + k_pivot_2 = (max_logit - min_logit) * 3.0 / 4.0 + min_logit + k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) + k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) + k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < N logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) - larger_mask = logits_blk > k_pivot_0 - k_pivots_num_0 += tl.sum(larger_mask) - larger_mask = logits_blk > k_pivot_1 - k_pivots_num_1 += tl.sum(larger_mask) + 
k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) + k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) + k_pivots_num_2 += tl.sum(logits_blk > k_pivot_2) if k_pivots_num_0 == k: - k_pivot = k_pivot_0 + k_pivot = k_pivot_0 elif k_pivots_num_1 == k: k_pivot = k_pivot_1 + elif k_pivots_num_2 == k: + k_pivot = k_pivot_2 + elif k_pivots_num_2 > k: + min_logit = k_pivot_2 elif k_pivots_num_1 > k: min_logit = k_pivot_1 elif k_pivots_num_0 > k: @@ -364,10 +365,14 @@ def _topk_kernel(LOGITS, K, B, max_logit = k_pivot_0 elif k_pivots_num_1 < k: max_logit = k_pivot_1 - + elif k_pivots_num_2 < k: + max_logit = k_pivot_2 + num_iters += 1 + if num_iters >= 18: + k_pivot = k_pivot_0 - # Fifth pass: Apply top-k mask + # Third pass: Apply top-k mask for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < N @@ -415,7 +420,7 @@ def apply_top_k_top_p( The logits tensor may be updated in-place. """ input_logits = logits.clone() - original_logits = original_apply_top_k_top_p(logits, k, p) + original_logits = original_apply_top_k_top_p(input_logits, k, p) original_probs = torch.softmax(input_logits, dim=-1) batch_size, vocab_size = logits.shape @@ -436,13 +441,11 @@ def apply_top_k_top_p( # print(r_str("Error: probs are not close")) # print(f"probs: {probs}") # print(f"original_probs: {original_probs}") - - logits[logits < -1e-6] = -1000 - original_logits[original_logits < -1e-6] = -1000 + + print(f"logits: {logits}") + print(f"original_logits: {original_logits}") if not torch.allclose(logits, original_logits): print(r_str("Error: logits are not close")) - print(f"logits: {logits}") - print(f"original_logits: {original_logits}") diff = (logits - original_logits).abs().flatten() diff_nonzero = diff[diff > 1e-6] print(f"diff_nonzero: {diff_nonzero}") From 360e2343614f83949f509df39d01260f47007939 Mon Sep 17 00:00:00 2001 From: js_park Date: Wed, 24 Sep 2025 21:33:48 -0700 Subject: [PATCH 06/99] Quadruple Search Signed-off-by: js_park --- 
vllm/v1/sample/ops/topk_topp_sampler.py | 34 +++++++++++++++++-------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 6beccf34da83..086d27e9c081 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -304,7 +304,7 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) @triton.jit -def _topk_kernel(LOGITS, K, B, +def _topk_kernel(LOGITS, PROBS, K, P, B, N: tl.constexpr, BLOCK_SIZE: tl.constexpr, NUM_TILES: tl.constexpr): @@ -317,7 +317,17 @@ def _topk_kernel(LOGITS, K, B, min_logit = float('inf') LOGITS_ROW = LOGITS + row_id * N + PROBS_ROW = PROBS + row_id * N + + # Zeroth pass: Compute avg and std from a sample block + offs = tl.arange(0, BLOCK_SIZE) + mask_n = offs < N + logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) + avg_logit = tl.mean(logits_blk) + std_logit = tl.std(logits_blk) + outlier_pivot = avg_logit + 3 * std_logit + num_outliers = tl.zeros((), dtype=tl.uint32) # First pass: compute max and min logits (for numerical stability) for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) @@ -326,8 +336,16 @@ def _topk_kernel(LOGITS, K, B, max_logit = tl.maximum(max_logit, tl.max(logits_blk)) min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + outlier_mask = logits_blk > outlier_pivot + num_blk_outliers = tl.sum(outlier_mask) + num_outliers += num_blk_outliers + outlier_idx = tl.where(outlier_mask, offs, -1) + gathered_outliers = tl.gather(logits_blk, outlier_idx) + tl.store(PROBS_ROW + num_outliers + tl.arange(0, num_blk_outliers), + gathered_outliers) + - # Second passes: Binary search for pivots + # Second passes: Quaternary search for pivots (nlog_4(n)) k_pivot = -float('inf') num_iters = 0 @@ -390,9 +408,9 @@ def triton_apply_top_k_top_p( NUM_PROGRAMS = 128 NUM_TILES = (vocab_size + BLOCK_SIZE - 1) // BLOCK_SIZE 
NUM_PIVOTS = 16 # Multi pivot search for smaller number of scans - probs = torch.zeros_like(logits) + probs = torch.full_like(logits, -float('inf')) if p is None: - _topk_kernel[(NUM_PROGRAMS,)](logits, k, batch_size, + _topk_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, vocab_size, BLOCK_SIZE, NUM_TILES) else: _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, @@ -442,17 +460,13 @@ def apply_top_k_top_p( # print(f"probs: {probs}") # print(f"original_probs: {original_probs}") - print(f"logits: {logits}") - print(f"original_logits: {original_logits}") if not torch.allclose(logits, original_logits): print(r_str("Error: logits are not close")) + print(f"logits: {logits}") + print(f"original_logits: {original_logits}") diff = (logits - original_logits).abs().flatten() diff_nonzero = diff[diff > 1e-6] print(f"diff_nonzero: {diff_nonzero}") - print(f"diff_nonzero.max(): {diff_nonzero.max()}") - print(f"diff_nonzero.min(): {diff_nonzero.min()}") - print(f"diff_nonzero.mean(): {diff_nonzero.mean()}") - print(f"diff_nonzero.std(): {diff_nonzero.std()}") start_time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime(start_time)) out_dir = "./sampler_input_output" From 11bd61fcdef3676aa9368c96511329eda89250d7 Mon Sep 17 00:00:00 2001 From: js_park Date: Wed, 24 Sep 2025 22:07:08 -0700 Subject: [PATCH 07/99] Added outliers Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 086d27e9c081..e74922208a9b 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -323,10 +323,11 @@ def _topk_kernel(LOGITS, PROBS, K, P, B, offs = tl.arange(0, BLOCK_SIZE) mask_n = offs < N logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) - avg_logit = tl.mean(logits_blk) - std_logit = tl.std(logits_blk) + avg_logit = 
tl.sum(logits_blk) / N + sq_avg_logit = tl.sum(logits_blk * logits_blk) / N + std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - outlier_pivot = avg_logit + 3 * std_logit + outlier_pivot = avg_logit + 2.5 * std_logit num_outliers = tl.zeros((), dtype=tl.uint32) # First pass: compute max and min logits (for numerical stability) for i in range(0, NUM_TILES): @@ -340,10 +341,14 @@ def _topk_kernel(LOGITS, PROBS, K, P, B, num_blk_outliers = tl.sum(outlier_mask) num_outliers += num_blk_outliers outlier_idx = tl.where(outlier_mask, offs, -1) - gathered_outliers = tl.gather(logits_blk, outlier_idx) - tl.store(PROBS_ROW + num_outliers + tl.arange(0, num_blk_outliers), - gathered_outliers) + gathered_outliers = tl.gather(logits_blk, outlier_idx, axis=0) + off_outliers = tl.arange(0, BLOCK_SIZE) + mask_outliers = off_outliers < num_blk_outliers + tl.store(PROBS_ROW + num_outliers + off_outliers, + gathered_outliers, mask=mask_outliers) + if num_outliers > k: + min_logit = outlier_pivot # Second passes: Quaternary search for pivots (nlog_4(n)) k_pivot = -float('inf') From a922b45e941e21b9e4079eff0e302d0c4d387025 Mon Sep 17 00:00:00 2001 From: js_park Date: Thu, 25 Sep 2025 14:23:34 -0700 Subject: [PATCH 08/99] Added gather Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 38 +++++++++++++++---------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index e74922208a9b..4c871c6ff0d1 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -327,28 +327,32 @@ def _topk_kernel(LOGITS, PROBS, K, P, B, sq_avg_logit = tl.sum(logits_blk * logits_blk) / N std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - outlier_pivot = avg_logit + 2.5 * std_logit + outlier_pivot = avg_logit + 3 * std_logit num_outliers = tl.zeros((), dtype=tl.uint32) - # First pass: compute max and min logits (for numerical stability) + # 
First pass: compute max and min logits and gather outliers for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < N - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=0.0) + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=avg_logit) max_logit = tl.maximum(max_logit, tl.max(logits_blk)) min_logit = tl.minimum(min_logit, tl.min(logits_blk)) - outlier_mask = logits_blk > outlier_pivot + outlier_mask = (logits_blk > outlier_pivot) & mask_n num_blk_outliers = tl.sum(outlier_mask) + cumulative_pos = tl.cast(tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) num_outliers += num_blk_outliers - outlier_idx = tl.where(outlier_mask, offs, -1) - gathered_outliers = tl.gather(logits_blk, outlier_idx, axis=0) - off_outliers = tl.arange(0, BLOCK_SIZE) - mask_outliers = off_outliers < num_blk_outliers - tl.store(PROBS_ROW + num_outliers + off_outliers, - gathered_outliers, mask=mask_outliers) + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(PROBS_ROW + write_pos, logits_blk, mask=mask_n) if num_outliers > k: - min_logit = outlier_pivot + # min_logit = outlier_pivot + search_addr = PROBS_ROW + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast((num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) + else: + search_addr = LOGITS_ROW + search_range = N + search_iters = NUM_TILES # Second passes: Quaternary search for pivots (nlog_4(n)) k_pivot = -float('inf') @@ -362,28 +366,29 @@ def _topk_kernel(LOGITS, PROBS, K, P, B, k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - for i in range(0, NUM_TILES): + for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < N - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) + mask_n = offs_n < search_range + logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=-float('inf')) k_pivots_num_0 += tl.sum(logits_blk 
> k_pivot_0) k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) k_pivots_num_2 += tl.sum(logits_blk > k_pivot_2) + # Check if any of the pivots are equal to k if k_pivots_num_0 == k: k_pivot = k_pivot_0 elif k_pivots_num_1 == k: k_pivot = k_pivot_1 elif k_pivots_num_2 == k: k_pivot = k_pivot_2 + # If none of the pivots are equal to k, we updatae the range elif k_pivots_num_2 > k: min_logit = k_pivot_2 elif k_pivots_num_1 > k: min_logit = k_pivot_1 elif k_pivots_num_0 > k: min_logit = k_pivot_0 - if k_pivots_num_0 < k: max_logit = k_pivot_0 elif k_pivots_num_1 < k: @@ -414,12 +419,15 @@ def triton_apply_top_k_top_p( NUM_TILES = (vocab_size + BLOCK_SIZE - 1) // BLOCK_SIZE NUM_PIVOTS = 16 # Multi pivot search for smaller number of scans probs = torch.full_like(logits, -float('inf')) + print(f"Input logits: {logits}") if p is None: _topk_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, vocab_size, BLOCK_SIZE, NUM_TILES) else: _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, vocab_size, BLOCK_SIZE, NUM_TILES, NUM_PIVOTS) + print(f"Output logits: {logits}") + print(f"Output probs: {probs}") return logits, probs @torch.compile From 6f39f20903b9227d5c6b607d231c700d34c6de23 Mon Sep 17 00:00:00 2001 From: js_park Date: Thu, 25 Sep 2025 14:25:55 -0700 Subject: [PATCH 09/99] Added gather Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 199 +++++++++--------------- 1 file changed, 75 insertions(+), 124 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 4c871c6ff0d1..a0de67c71323 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -178,131 +178,108 @@ def original_apply_top_k_top_p( @triton.jit def _topk_topp_kernel(LOGITS, PROBS, K, P, B, - N: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - NUM_TILES: tl.constexpr, - NUM_PIVOTS: tl.constexpr): + N: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + NUM_TILES: tl.constexpr): pid = 
tl.program_id(0) num_programs = tl.num_programs(0) for row_id in tl.range(pid, B, num_programs): k = tl.load(K + row_id) - p = tl.load(P + row_id) - if not (k == N and p == 1.0): # All tokens are valid + if not (k == N): # All tokens are valid max_logit = -float('inf') min_logit = float('inf') - max_prob = 0.0 - min_prob = 1.0 - LOGITS_ROW = LOGITS + row_id * N PROBS_ROW = PROBS + row_id * N - # First pass: compute max and min logits (for numerical stability) + # Zeroth pass: Compute avg and std from a sample block + offs = tl.arange(0, BLOCK_SIZE) + mask_n = offs < N + logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) + avg_logit = tl.sum(logits_blk) / N + sq_avg_logit = tl.sum(logits_blk * logits_blk) / N + std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) + + outlier_pivot = avg_logit + 2.7 * std_logit + num_outliers = tl.zeros((), dtype=tl.uint32) + # First pass: compute max and min logits and gather outliers for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < N - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=0.0) + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=avg_logit) max_logit = tl.maximum(max_logit, tl.max(logits_blk)) min_logit = tl.minimum(min_logit, tl.min(logits_blk)) - - # Second pass: compute probabilities using softmax - # (This requires the max for numerical stability) - exp_logits_sum = 0.0 - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < N - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) + outlier_mask = (logits_blk > outlier_pivot) & mask_n + num_blk_outliers = tl.sum(outlier_mask) + cumulative_pos = tl.cast(tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + num_outliers += num_blk_outliers + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(PROBS_ROW + write_pos, logits_blk, mask=mask_n) - logits_tile_stable = logits_blk - max_logit # 
Numerical stability - exp_logits = tl.exp(logits_tile_stable) - exp_logits_sum += tl.sum(exp_logits) - tl.store(PROBS_ROW + offs_n, exp_logits, mask=mask_n) - - # Third pass: compute probabilities and update max and min probabilities - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < N - probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) - probs_blk = probs_blk / exp_logits_sum - max_prob = tl.maximum(max_prob, tl.max(probs_blk)) - min_prob = tl.minimum(min_prob, tl.min(probs_blk)) - tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) + if num_outliers > k: + # min_logit = outlier_pivot + search_addr = PROBS_ROW + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast((num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) + else: + search_addr = LOGITS_ROW + search_range = N + search_iters = NUM_TILES - # Fourth passes: Search for pivots - num_iters = 0 + # Second passes: Quaternary search for pivots (nlog_4(n)) k_pivot = -float('inf') - k_pivots = tl.full((NUM_PIVOTS,), -float('inf'), dtype=tl.float32) - - p_pivot = 0.0 - p_pivots = tl.full((NUM_PIVOTS,), -float('inf'), dtype=tl.float32) - - while (k_pivot == -float('inf') or p_pivot == 0.0) and num_iters < 32: - k_pivots = (max_logit - min_logit) * tl.arange(1, NUM_PIVOTS + 1) / NUM_PIVOTS + min_logit - p_pivots = (max_prob - min_prob) * tl.arange(1, NUM_PIVOTS + 1) / NUM_PIVOTS + min_prob - - k_pivots_num = tl.full((NUM_PIVOTS,), 0, dtype=tl.uint32) - p_pivots_sum = tl.full((NUM_PIVOTS,), 0.0, dtype=tl.float32) - for i in range(0, NUM_TILES): + num_iters = 0 + + while k_pivot == -float('inf') and num_iters < 18: + k_pivot_0 = (max_logit - min_logit) * 1.0 / 4.0 + min_logit + k_pivot_1 = (max_logit - min_logit) * 2.0 / 4.0 + min_logit + k_pivot_2 = (max_logit - min_logit) * 3.0 / 4.0 + min_logit + k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) + k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) + k_pivots_num_2 = tl.zeros((), 
dtype=tl.uint32) + + for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < N - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) - probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) - - logits_expanded = logits_blk[None, :] # shape: 1 x BLOCK_SIZE - k_pivots_expanded = k_pivots[:, None] # shape: NUM_PIVOTS x 1 - larger_mask = logits_expanded > k_pivots_expanded # shape: NUM_PIVOTS x BLOCK_SIZE - k_pivots_num += tl.sum(larger_mask, axis=1) # shape: NUM_PIVOTS - - probs_expanded = probs_blk[None, :] # shape: 1 x BLOCK_SIZE - p_pivots_expanded = p_pivots[:, None] # shape: NUM_PIVOTS x 1 - larger_mask = probs_expanded > p_pivots_expanded # shape: NUM_PIVOTS x BLOCK_SIZE - larger_probs = tl.where(larger_mask, probs_expanded, 0.0) # shape: NUM_PIVOTS x BLOCK_SIZE - p_pivots_sum += tl.sum(larger_probs, axis=1) # shape: NUM_PIVOTS - - exact_match_k = k_pivots_num == k - if tl.sum(exact_match_k) > 0: - matches = tl.where(exact_match_k, k_pivots, float('inf')) - k_pivot = tl.min(matches) - else: - smaller_mask = k_pivots_num < k - if tl.sum(smaller_mask) > 0: - matches = tl.where(smaller_mask, k_pivots, float('inf')) - max_logit = tl.min(matches) - larger_mask = k_pivots_num > k - if tl.sum(larger_mask) > 0: - matches = tl.where(larger_mask, k_pivots, -float('inf')) - min_logit = tl.max(matches) - - exact_match_p = tl.abs(p_pivots_sum - p) < 1e-6 - if tl.sum(exact_match_p) > 0: - matches = tl.where(exact_match_p, p_pivots, float('inf')) - p_pivot = tl.min(matches) - else: - smaller_mask = p_pivots_sum < p - if tl.sum(smaller_mask) > 0: - matches = tl.where(smaller_mask, p_pivots, float('inf')) - max_prob = tl.min(matches) - larger_mask = p_pivots_sum > p - if tl.sum(larger_mask) > 0: - matches = tl.where(larger_mask, p_pivots, -float('inf')) - min_prob = tl.max(matches) - # For the case where sum of existing probabilities does not hit p - if min_prob == max_prob: - p_pivot = min_prob + 
mask_n = offs_n < search_range + logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=-float('inf')) + k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) + k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) + k_pivots_num_2 += tl.sum(logits_blk > k_pivot_2) + + # Check if any of the pivots are equal to k + if k_pivots_num_0 == k: + k_pivot = k_pivot_0 + elif k_pivots_num_1 == k: + k_pivot = k_pivot_1 + elif k_pivots_num_2 == k: + k_pivot = k_pivot_2 + # If none of the pivots are equal to k, we updatae the range + elif k_pivots_num_2 > k: + min_logit = k_pivot_2 + elif k_pivots_num_1 > k: + min_logit = k_pivot_1 + elif k_pivots_num_0 > k: + min_logit = k_pivot_0 + if k_pivots_num_0 < k: + max_logit = k_pivot_0 + elif k_pivots_num_1 < k: + max_logit = k_pivot_1 + elif k_pivots_num_2 < k: + max_logit = k_pivot_2 + num_iters += 1 + if num_iters >= 18: + k_pivot = k_pivot_0 - # Fifth pass: Apply top-k and top-p masks + # Third pass: Apply top-k mask for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < N logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n) - probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n) logits_blk = tl.where(logits_blk > k_pivot, logits_blk, -float('inf')) - logits_blk = tl.where(probs_blk > p_pivot, logits_blk, -float('inf')) tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) - @triton.jit def _topk_kernel(LOGITS, PROBS, K, P, B, N: tl.constexpr, @@ -319,16 +296,6 @@ def _topk_kernel(LOGITS, PROBS, K, P, B, LOGITS_ROW = LOGITS + row_id * N PROBS_ROW = PROBS + row_id * N - # Zeroth pass: Compute avg and std from a sample block - offs = tl.arange(0, BLOCK_SIZE) - mask_n = offs < N - logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) - avg_logit = tl.sum(logits_blk) / N - sq_avg_logit = tl.sum(logits_blk * logits_blk) / N - std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - - outlier_pivot = avg_logit + 3 * std_logit - num_outliers = tl.zeros((), dtype=tl.uint32) # First 
pass: compute max and min logits and gather outliers for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) @@ -337,22 +304,6 @@ def _topk_kernel(LOGITS, PROBS, K, P, B, max_logit = tl.maximum(max_logit, tl.max(logits_blk)) min_logit = tl.minimum(min_logit, tl.min(logits_blk)) - outlier_mask = (logits_blk > outlier_pivot) & mask_n - num_blk_outliers = tl.sum(outlier_mask) - cumulative_pos = tl.cast(tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) - num_outliers += num_blk_outliers - write_pos = tl.where(outlier_mask, cumulative_pos, -1) - tl.store(PROBS_ROW + write_pos, logits_blk, mask=mask_n) - - if num_outliers > k: - # min_logit = outlier_pivot - search_addr = PROBS_ROW - search_range = tl.cast(num_outliers, tl.int32) - search_iters = tl.cast((num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) - else: - search_addr = LOGITS_ROW - search_range = N - search_iters = NUM_TILES # Second passes: Quaternary search for pivots (nlog_4(n)) k_pivot = -float('inf') @@ -366,9 +317,9 @@ def _topk_kernel(LOGITS, PROBS, K, P, B, k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - for i in range(0, search_iters): + for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range + mask_n = offs_n < N logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=-float('inf')) k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) From 30033c220809b448c02602c63dcea3e6a119338b Mon Sep 17 00:00:00 2001 From: js_park Date: Thu, 25 Sep 2025 14:46:41 -0700 Subject: [PATCH 10/99] 0.00115 for topk Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 201 +++++++++++++----------- 1 file changed, 105 insertions(+), 96 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index a0de67c71323..7c9064d45995 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ 
b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -184,94 +184,96 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, pid = tl.program_id(0) num_programs = tl.num_programs(0) for row_id in tl.range(pid, B, num_programs): - k = tl.load(K + row_id) - if not (k == N): # All tokens are valid - max_logit = -float('inf') - min_logit = float('inf') - - LOGITS_ROW = LOGITS + row_id * N - PROBS_ROW = PROBS + row_id * N - - # Zeroth pass: Compute avg and std from a sample block - offs = tl.arange(0, BLOCK_SIZE) - mask_n = offs < N - logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) - avg_logit = tl.sum(logits_blk) / N - sq_avg_logit = tl.sum(logits_blk * logits_blk) / N - std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - - outlier_pivot = avg_logit + 2.7 * std_logit - num_outliers = tl.zeros((), dtype=tl.uint32) - # First pass: compute max and min logits and gather outliers - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < N - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=avg_logit) - - max_logit = tl.maximum(max_logit, tl.max(logits_blk)) - min_logit = tl.minimum(min_logit, tl.min(logits_blk)) - outlier_mask = (logits_blk > outlier_pivot) & mask_n - num_blk_outliers = tl.sum(outlier_mask) - cumulative_pos = tl.cast(tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) - num_outliers += num_blk_outliers - write_pos = tl.where(outlier_mask, cumulative_pos, -1) - tl.store(PROBS_ROW + write_pos, logits_blk, mask=mask_n) - - if num_outliers > k: - # min_logit = outlier_pivot - search_addr = PROBS_ROW - search_range = tl.cast(num_outliers, tl.int32) - search_iters = tl.cast((num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) - else: - search_addr = LOGITS_ROW - search_range = N - search_iters = NUM_TILES - - # Second passes: Quaternary search for pivots (nlog_4(n)) - k_pivot = -float('inf') - num_iters = 0 - - while k_pivot == -float('inf') and num_iters < 18: - k_pivot_0 = (max_logit - 
min_logit) * 1.0 / 4.0 + min_logit - k_pivot_1 = (max_logit - min_logit) * 2.0 / 4.0 + min_logit - k_pivot_2 = (max_logit - min_logit) * 3.0 / 4.0 + min_logit - k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) - k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) - k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - - for i in range(0, search_iters): + k_pivot = float('inf') + if K is not None: + k = tl.load(K + row_id) + if not (k == N): # All tokens are valid + max_logit = -float('inf') + min_logit = float('inf') + + LOGITS_ROW = LOGITS + row_id * N + PROBS_ROW = PROBS + row_id * N + + # Zeroth pass: Compute avg and std from a sample block + offs = tl.arange(0, BLOCK_SIZE) + mask_n = offs < N + logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) + avg_logit = tl.sum(logits_blk) / N + sq_avg_logit = tl.sum(logits_blk * logits_blk) / N + std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) + + outlier_pivot = avg_logit + 2.7 * std_logit + num_outliers = tl.zeros((), dtype=tl.uint32) + # First pass: compute max and min logits and gather outliers + for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=-float('inf')) - - k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) - k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) - k_pivots_num_2 += tl.sum(logits_blk > k_pivot_2) + mask_n = offs_n < N + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=avg_logit) + + max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + outlier_mask = (logits_blk > outlier_pivot) & mask_n + num_blk_outliers = tl.sum(outlier_mask) + cumulative_pos = tl.cast(tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + num_outliers += num_blk_outliers + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(PROBS_ROW + write_pos, logits_blk, mask=mask_n) + + if num_outliers > k: + # min_logit = 
outlier_pivot + search_addr = PROBS_ROW + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast((num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) + else: + search_addr = LOGITS_ROW + search_range = N + search_iters = NUM_TILES - # Check if any of the pivots are equal to k - if k_pivots_num_0 == k: - k_pivot = k_pivot_0 - elif k_pivots_num_1 == k: - k_pivot = k_pivot_1 - elif k_pivots_num_2 == k: - k_pivot = k_pivot_2 - # If none of the pivots are equal to k, we updatae the range - elif k_pivots_num_2 > k: - min_logit = k_pivot_2 - elif k_pivots_num_1 > k: - min_logit = k_pivot_1 - elif k_pivots_num_0 > k: - min_logit = k_pivot_0 - if k_pivots_num_0 < k: - max_logit = k_pivot_0 - elif k_pivots_num_1 < k: - max_logit = k_pivot_1 - elif k_pivots_num_2 < k: - max_logit = k_pivot_2 + # Second passes: Quaternary search for pivots (nlog_4(n)) - num_iters += 1 - if num_iters >= 18: - k_pivot = k_pivot_0 + num_iters = 0 + + while k_pivot == float('inf') and num_iters < 18: + k_pivot_0 = (max_logit - min_logit) * 1.0 / 4.0 + min_logit + k_pivot_1 = (max_logit - min_logit) * 2.0 / 4.0 + min_logit + k_pivot_2 = (max_logit - min_logit) * 3.0 / 4.0 + min_logit + k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) + k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) + k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) + + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=-float('inf')) + + k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) + k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) + k_pivots_num_2 += tl.sum(logits_blk > k_pivot_2) + + # Check if any of the pivots are equal to k + if k_pivots_num_0 == k: + k_pivot = k_pivot_0 + elif k_pivots_num_1 == k: + k_pivot = k_pivot_1 + elif k_pivots_num_2 == k: + k_pivot = k_pivot_2 + # If none of the pivots are equal to k, we updatae the range + elif k_pivots_num_2 > k: + min_logit = 
k_pivot_2 + elif k_pivots_num_1 > k: + min_logit = k_pivot_1 + elif k_pivots_num_0 > k: + min_logit = k_pivot_0 + if k_pivots_num_0 < k: + max_logit = k_pivot_0 + elif k_pivots_num_1 < k: + max_logit = k_pivot_1 + elif k_pivots_num_2 < k: + max_logit = k_pivot_2 + + num_iters += 1 + if num_iters >= 18: + k_pivot = k_pivot_0 # Third pass: Apply top-k mask for i in range(0, NUM_TILES): @@ -285,6 +287,8 @@ def _topk_kernel(LOGITS, PROBS, K, P, B, N: tl.constexpr, BLOCK_SIZE: tl.constexpr, NUM_TILES: tl.constexpr): + if K is None: + return pid = tl.program_id(0) num_programs = tl.num_programs(0) for row_id in tl.range(pid, B, num_programs): @@ -294,7 +298,13 @@ def _topk_kernel(LOGITS, PROBS, K, P, B, min_logit = float('inf') LOGITS_ROW = LOGITS + row_id * N - PROBS_ROW = PROBS + row_id * N + + # Zeroth pass: Compute avg logit from a sample block + # This may cause neumerical instability when N < BLOCK_SIZE + offs = tl.arange(0, BLOCK_SIZE) + mask_n = offs < N + logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) + avg_logit = tl.sum(logits_blk) / N # First pass: compute max and min logits and gather outliers for i in range(0, NUM_TILES): @@ -306,10 +316,10 @@ def _topk_kernel(LOGITS, PROBS, K, P, B, min_logit = tl.minimum(min_logit, tl.min(logits_blk)) # Second passes: Quaternary search for pivots (nlog_4(n)) - k_pivot = -float('inf') + k_pivot = float('inf') num_iters = 0 - while k_pivot == -float('inf') and num_iters < 18: + while k_pivot == float('inf') and num_iters < 18: k_pivot_0 = (max_logit - min_logit) * 1.0 / 4.0 + min_logit k_pivot_1 = (max_logit - min_logit) * 2.0 / 4.0 + min_logit k_pivot_2 = (max_logit - min_logit) * 3.0 / 4.0 + min_logit @@ -320,7 +330,7 @@ def _topk_kernel(LOGITS, PROBS, K, P, B, for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < N - logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=-float('inf')) + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, 
other=-float('inf')) k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) @@ -368,17 +378,16 @@ def triton_apply_top_k_top_p( BLOCK_SIZE = 4096 NUM_PROGRAMS = 128 NUM_TILES = (vocab_size + BLOCK_SIZE - 1) // BLOCK_SIZE - NUM_PIVOTS = 16 # Multi pivot search for smaller number of scans probs = torch.full_like(logits, -float('inf')) - print(f"Input logits: {logits}") + # print(f"Input logits: {logits}") if p is None: _topk_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, vocab_size, BLOCK_SIZE, NUM_TILES) - else: - _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, - vocab_size, BLOCK_SIZE, NUM_TILES, NUM_PIVOTS) - print(f"Output logits: {logits}") - print(f"Output probs: {probs}") + # else: + # _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, + # vocab_size, BLOCK_SIZE, NUM_TILES) + # print(f"Output logits: {logits}") + # print(f"Output probs: {probs}") return logits, probs @torch.compile From 2987617568d83f8f506de4347aa63ba8452b1435 Mon Sep 17 00:00:00 2001 From: js_park Date: Thu, 25 Sep 2025 14:47:20 -0700 Subject: [PATCH 11/99] 0.00115 for topk Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 7c9064d45995..0eaf3731f948 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -381,7 +381,7 @@ def triton_apply_top_k_top_p( probs = torch.full_like(logits, -float('inf')) # print(f"Input logits: {logits}") if p is None: - _topk_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, + _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, vocab_size, BLOCK_SIZE, NUM_TILES) # else: # _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, From ba5b98b5ff59ad0b46f5108b482492f9b1e29b2f Mon Sep 17 00:00:00 2001 From: js_park Date: Thu, 25 Sep 
2025 16:27:28 -0700 Subject: [PATCH 12/99] topk working, adding topp: Signed-off-by: js_park --- test.py | 2 +- vllm/v1/sample/ops/topk_topp_sampler.py | 325 ++++++++++++++---------- 2 files changed, 194 insertions(+), 133 deletions(-) diff --git a/test.py b/test.py index 5d4f4c99db00..b335f1a6c5fa 100644 --- a/test.py +++ b/test.py @@ -7,7 +7,7 @@ "The future of AI is", ] prompts = prompts * 64 -sampling_params = SamplingParams(temperature=0.8, top_k=10) +sampling_params = SamplingParams(temperature=0.8, top_k=10, top_p=0.95) llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct") # llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite") diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 0eaf3731f948..7f8ceb94874c 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -177,160 +177,192 @@ def original_apply_top_k_top_p( return logits @triton.jit -def _topk_topp_kernel(LOGITS, PROBS, K, P, B, - N: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - NUM_TILES: tl.constexpr): +def _topk_kernel(LOGITS, PROBS, K, B, + N: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + NUM_TILES: tl.constexpr): pid = tl.program_id(0) num_programs = tl.num_programs(0) for row_id in tl.range(pid, B, num_programs): - k_pivot = float('inf') - if K is not None: - k = tl.load(K + row_id) - if not (k == N): # All tokens are valid - max_logit = -float('inf') - min_logit = float('inf') - - LOGITS_ROW = LOGITS + row_id * N - PROBS_ROW = PROBS + row_id * N - - # Zeroth pass: Compute avg and std from a sample block - offs = tl.arange(0, BLOCK_SIZE) - mask_n = offs < N - logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) - avg_logit = tl.sum(logits_blk) / N - sq_avg_logit = tl.sum(logits_blk * logits_blk) / N - std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - - outlier_pivot = avg_logit + 2.7 * std_logit - num_outliers = tl.zeros((), dtype=tl.uint32) - # First pass: compute max and min logits and 
gather outliers - for i in range(0, NUM_TILES): + k_pivot = -float('inf') + p_pivot = -float('inf') + + LOGITS_ROW = LOGITS + row_id * N + PROBS_ROW = PROBS + row_id * N + + search_addr = LOGITS_ROW + search_range = N + search_iters = NUM_TILES + + max_logit = -float('inf') + + k = tl.load(K + row_id) + if not (k == N): # All tokens are valid + min_logit = float('inf') + + # Zeroth pass: Compute avg and std from a sample block + # May produce incorrect results if N < BLOCK_SIZE + offs = tl.arange(0, BLOCK_SIZE) + mask_n = offs < N + logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) + avg_logit = tl.sum(logits_blk) / N + sq_avg_logit = tl.sum(logits_blk * logits_blk) / N + std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) + + outlier_pivot = avg_logit + 2.8 * std_logit + num_outliers = tl.zeros((), dtype=tl.uint32) + # First pass: compute max and min logits and gather outliers + for i in range(0,search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=avg_logit) + + max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + outlier_mask = (logits_blk > outlier_pivot) & mask_n + num_blk_outliers = tl.sum(outlier_mask) + cumulative_pos = tl.cast(tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + num_outliers += num_blk_outliers + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(PROBS_ROW + write_pos, logits_blk, mask=outlier_mask) + + max_range = max_logit + min_range = min_logit + if num_outliers > k: + max_range = max_logit + min_range = outlier_pivot + search_addr = PROBS_ROW + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast((num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) + + # Second passes: Quaternary search for pivots (nlog_4(n)) + num_iters = 0 + while k_pivot == -float('inf') and num_iters < 18: + k_pivot_0 = (max_range - 
min_range) * 1.0 / 4.0 + min_range + k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range + k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) + k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) + k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) + + for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < N - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=avg_logit) - - max_logit = tl.maximum(max_logit, tl.max(logits_blk)) - min_logit = tl.minimum(min_logit, tl.min(logits_blk)) - outlier_mask = (logits_blk > outlier_pivot) & mask_n - num_blk_outliers = tl.sum(outlier_mask) - cumulative_pos = tl.cast(tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) - num_outliers += num_blk_outliers - write_pos = tl.where(outlier_mask, cumulative_pos, -1) - tl.store(PROBS_ROW + write_pos, logits_blk, mask=mask_n) - - if num_outliers > k: - # min_logit = outlier_pivot - search_addr = PROBS_ROW - search_range = tl.cast(num_outliers, tl.int32) - search_iters = tl.cast((num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) - else: - search_addr = LOGITS_ROW - search_range = N - search_iters = NUM_TILES + mask_n = offs_n < search_range + logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=-float('inf')) + + k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) + k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) + k_pivots_num_2 += tl.sum(logits_blk > k_pivot_2) - # Second passes: Quaternary search for pivots (nlog_4(n)) + # Check if any of the pivots are equal to k + if k_pivots_num_0 == k: + k_pivot = k_pivot_0 + elif k_pivots_num_1 == k: + k_pivot = k_pivot_1 + elif k_pivots_num_2 == k: + k_pivot = k_pivot_2 + # If none of the pivots are equal to k, we updatae the range + elif k_pivots_num_2 > k: + min_range = k_pivot_2 + elif k_pivots_num_1 > k: + min_range = k_pivot_1 + elif k_pivots_num_0 > k: + min_range = k_pivot_0 + if k_pivots_num_0 < k: + max_range 
= k_pivot_0 + elif k_pivots_num_1 < k: + max_range = k_pivot_1 + elif k_pivots_num_2 < k: + max_range = k_pivot_2 - num_iters = 0 - - while k_pivot == float('inf') and num_iters < 18: - k_pivot_0 = (max_logit - min_logit) * 1.0 / 4.0 + min_logit - k_pivot_1 = (max_logit - min_logit) * 2.0 / 4.0 + min_logit - k_pivot_2 = (max_logit - min_logit) * 3.0 / 4.0 + min_logit - k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) - k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) - k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=-float('inf')) - - k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) - k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) - k_pivots_num_2 += tl.sum(logits_blk > k_pivot_2) - - # Check if any of the pivots are equal to k - if k_pivots_num_0 == k: - k_pivot = k_pivot_0 - elif k_pivots_num_1 == k: - k_pivot = k_pivot_1 - elif k_pivots_num_2 == k: - k_pivot = k_pivot_2 - # If none of the pivots are equal to k, we updatae the range - elif k_pivots_num_2 > k: - min_logit = k_pivot_2 - elif k_pivots_num_1 > k: - min_logit = k_pivot_1 - elif k_pivots_num_0 > k: - min_logit = k_pivot_0 - if k_pivots_num_0 < k: - max_logit = k_pivot_0 - elif k_pivots_num_1 < k: - max_logit = k_pivot_1 - elif k_pivots_num_2 < k: - max_logit = k_pivot_2 - - num_iters += 1 - if num_iters >= 18: - k_pivot = k_pivot_0 - - # Third pass: Apply top-k mask + num_iters += 1 + if num_iters >= 18: + k_pivot = k_pivot_0 + + # Third pass: Apply top-k mask + if k_pivot != -float('inf') or p_pivot != -float('inf'): for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < N logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n) - logits_blk = tl.where(logits_blk > k_pivot, logits_blk, -float('inf')) + mask = (logits_blk > k_pivot) & (logits_blk > p_pivot) + logits_blk = 
tl.where(mask, logits_blk, -float('inf')) tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) + + @triton.jit -def _topk_kernel(LOGITS, PROBS, K, P, B, - N: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - NUM_TILES: tl.constexpr): - if K is None: - return +def _topk_topp_kernel(LOGITS, PROBS, K, P, B, + N: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + NUM_TILES: tl.constexpr): pid = tl.program_id(0) num_programs = tl.num_programs(0) for row_id in tl.range(pid, B, num_programs): + k_pivot = -float('inf') + p_pivot = -float('inf') + + LOGITS_ROW = LOGITS + row_id * N + PROBS_ROW = PROBS + row_id * N + + search_addr = LOGITS_ROW + search_range = N + search_iters = NUM_TILES + + max_logit = -float('inf') + k = tl.load(K + row_id) if not (k == N): # All tokens are valid - max_logit = -float('inf') min_logit = float('inf') - LOGITS_ROW = LOGITS + row_id * N - - # Zeroth pass: Compute avg logit from a sample block - # This may cause neumerical instability when N < BLOCK_SIZE + # Zeroth pass: Compute avg and std from a sample block + # May produce incorrect results if N < BLOCK_SIZE offs = tl.arange(0, BLOCK_SIZE) mask_n = offs < N logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) avg_logit = tl.sum(logits_blk) / N + sq_avg_logit = tl.sum(logits_blk * logits_blk) / N + std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) + outlier_pivot = avg_logit + 2.8 * std_logit + num_outliers = tl.zeros((), dtype=tl.uint32) # First pass: compute max and min logits and gather outliers - for i in range(0, NUM_TILES): + for i in range(0,search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < N - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=avg_logit) + mask_n = offs_n < search_range + logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=avg_logit) max_logit = tl.maximum(max_logit, tl.max(logits_blk)) min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + outlier_mask = (logits_blk > outlier_pivot) & mask_n + 
num_blk_outliers = tl.sum(outlier_mask) + cumulative_pos = tl.cast(tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + num_outliers += num_blk_outliers + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(PROBS_ROW + write_pos, logits_blk, mask=outlier_mask) + + max_range = max_logit + min_range = min_logit + if num_outliers > k: + max_range = max_logit + min_range = outlier_pivot + search_addr = PROBS_ROW + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast((num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) # Second passes: Quaternary search for pivots (nlog_4(n)) - k_pivot = float('inf') num_iters = 0 - - while k_pivot == float('inf') and num_iters < 18: - k_pivot_0 = (max_logit - min_logit) * 1.0 / 4.0 + min_logit - k_pivot_1 = (max_logit - min_logit) * 2.0 / 4.0 + min_logit - k_pivot_2 = (max_logit - min_logit) * 3.0 / 4.0 + min_logit + while k_pivot == -float('inf') and num_iters < 18: + k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - for i in range(0, NUM_TILES): + for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < N - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) + mask_n = offs_n < search_range + logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=-float('inf')) k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) @@ -345,28 +377,54 @@ def _topk_kernel(LOGITS, PROBS, K, P, B, k_pivot = k_pivot_2 # If none of the pivots are equal to k, we updatae the range elif k_pivots_num_2 > k: - min_logit = k_pivot_2 + min_range = k_pivot_2 elif k_pivots_num_1 > k: - min_logit = k_pivot_1 + min_range = k_pivot_1 elif 
k_pivots_num_0 > k: - min_logit = k_pivot_0 + min_range = k_pivot_0 if k_pivots_num_0 < k: - max_logit = k_pivot_0 + max_range = k_pivot_0 elif k_pivots_num_1 < k: - max_logit = k_pivot_1 + max_range = k_pivot_1 elif k_pivots_num_2 < k: - max_logit = k_pivot_2 + max_range = k_pivot_2 num_iters += 1 if num_iters >= 18: k_pivot = k_pivot_0 - # Third pass: Apply top-k mask + p = tl.load(P + row_id) + if p != 1.0: + max_probs = 0.0 + min_probs = 1.0 + sum_exp_logits = 0.0 + + # Third pass: Compute exp logits and sum + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + probs_blk = tl.load(search_addr + offs_n, mask=mask_n, other=0.0) + probs_blk = probs_blk - max_probs + probs_blk = tl.exp(probs_blk) + sum_exp_logits += tl.sum(probs_blk) + + # Fourth pass: Compute probs (softmax) + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + probs_blk = tl.load(search_addr + offs_n, mask=mask_n, other=0.0) + probs_blk = probs_blk / sum_exp_logits + min_probs = tl.minimum(min_probs, tl.min(probs_blk)) + max_probs = tl.maximum(max_probs, tl.max(probs_blk)) + + # Sixth pass: Apply top-k mask + if k_pivot != -float('inf') or p_pivot != -float('inf'): for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < N logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n) - logits_blk = tl.where(logits_blk > k_pivot, logits_blk, -float('inf')) + mask = (logits_blk > k_pivot) & (logits_blk > p_pivot) + logits_blk = tl.where(mask, logits_blk, -float('inf')) tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) def triton_apply_top_k_top_p( @@ -379,16 +437,19 @@ def triton_apply_top_k_top_p( NUM_PROGRAMS = 128 NUM_TILES = (vocab_size + BLOCK_SIZE - 1) // BLOCK_SIZE probs = torch.full_like(logits, -float('inf')) + print(b_str("Launch params:") + f"logits.shape: {logits.shape}, probs.shape: {probs.shape}, " + 
f"k.shape: {k.shape if k is not None else None}, p.shape: {p.shape if p is not None else None}, " + f"batch_size: {batch_size}, vocab_size: {vocab_size}, BLOCK_SIZE: {BLOCK_SIZE}, NUM_TILES: {NUM_TILES}") # print(f"Input logits: {logits}") - if p is None: - _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, + if p is None and k is not None: + _topk_kernel[(NUM_PROGRAMS,)](logits, probs, k, batch_size, vocab_size, BLOCK_SIZE, NUM_TILES) - # else: - # _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, - # vocab_size, BLOCK_SIZE, NUM_TILES) + else: + _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, + vocab_size, BLOCK_SIZE, NUM_TILES) # print(f"Output logits: {logits}") # print(f"Output probs: {probs}") - return logits, probs + return logits @torch.compile def compiled_apply_top_k_top_p( @@ -422,7 +483,7 @@ def apply_top_k_top_p( # logits = original_apply_top_k_top_p(logits, k, p) # logits = compiled_apply_top_k_top_p(logits, k, p) - logits, probs = triton_apply_top_k_top_p(logits, k, p) + logits = triton_apply_top_k_top_p(logits, k, p) torch.cuda.synchronize() time_taken = time.time() - start_time From 46bcc7df1a1f9f71646be8e503c11a148d6a3b5a Mon Sep 17 00:00:00 2001 From: js_park Date: Thu, 25 Sep 2025 16:47:11 -0700 Subject: [PATCH 13/99] Wrong results Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 72 +++++++++++++++++++++---- 1 file changed, 61 insertions(+), 11 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 7f8ceb94874c..a9863cbcc032 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -309,6 +309,7 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, search_iters = NUM_TILES max_logit = -float('inf') + avg_logit = -float('inf') k = tl.load(K + row_id) if not (k == N): # All tokens are valid @@ -403,19 +404,71 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, for i in range(0, 
search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - probs_blk = tl.load(search_addr + offs_n, mask=mask_n, other=0.0) + probs_blk = tl.load(search_addr + offs_n, mask=mask_n, other=-float('inf')) probs_blk = probs_blk - max_probs probs_blk = tl.exp(probs_blk) sum_exp_logits += tl.sum(probs_blk) + tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) # Fourth pass: Compute probs (softmax) for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - probs_blk = tl.load(search_addr + offs_n, mask=mask_n, other=0.0) + probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=avg_logit) probs_blk = probs_blk / sum_exp_logits min_probs = tl.minimum(min_probs, tl.min(probs_blk)) max_probs = tl.maximum(max_probs, tl.max(probs_blk)) + tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) + + max_range = max_probs + min_range = min_probs + + num_iters = 0 + while p_pivot == -float('inf') and num_iters < 18: + p_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + p_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + p_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range + p_pivots_sum_0 = 0.0 + p_pivots_sum_1 = 0.0 + p_pivots_sum_2 = 0.0 + + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=avg_logit) + + larger_mask_0 = probs_blk > p_pivot_0 + larger_mask_1 = probs_blk > p_pivot_1 + larger_mask_2 = probs_blk > p_pivot_2 + + p_pivots_sum_0 += tl.sum(probs_blk * larger_mask_0) + p_pivots_sum_1 += tl.sum(probs_blk * larger_mask_1) + p_pivots_sum_2 += tl.sum(probs_blk * larger_mask_2) + + # Check if any of the pivots are equal to k + if tl.abs(p_pivots_sum_0 - p) < 1e-6: + p_pivot = p_pivot_0 + elif tl.abs(p_pivots_sum_1 - p) < 1e-6: + p_pivot = p_pivot_1 + elif tl.abs(p_pivots_sum_2 - p) < 1e-6: + p_pivot = p_pivot_2 + # If none 
of the pivots are equal to k, we updatae the range + elif p_pivots_sum_2 > p: + min_range = p_pivot_2 + elif p_pivots_sum_1 > p: + min_range = p_pivot_1 + elif p_pivots_sum_0 > p: + min_range = p_pivot_0 + if p_pivots_sum_0 < p: + max_range = p_pivot_0 + elif p_pivots_sum_1 < p: + max_range = p_pivot_1 + elif p_pivots_sum_2 < p: + max_range = p_pivot_2 + + num_iters += 1 + if num_iters >= 18: + p_pivot = p_pivot_0 # Sixth pass: Apply top-k mask if k_pivot != -float('inf') or p_pivot != -float('inf'): @@ -449,7 +502,7 @@ def triton_apply_top_k_top_p( vocab_size, BLOCK_SIZE, NUM_TILES) # print(f"Output logits: {logits}") # print(f"Output probs: {probs}") - return logits + return logits, probs @torch.compile def compiled_apply_top_k_top_p( @@ -483,24 +536,21 @@ def apply_top_k_top_p( # logits = original_apply_top_k_top_p(logits, k, p) # logits = compiled_apply_top_k_top_p(logits, k, p) - logits = triton_apply_top_k_top_p(logits, k, p) + logits, probs = triton_apply_top_k_top_p(logits, k, p) torch.cuda.synchronize() time_taken = time.time() - start_time print(y_str(f"apply_top_k_top_p done in {time_taken} seconds")) - # if not torch.allclose(probs, original_probs): - # print(r_str("Error: probs are not close")) - # print(f"probs: {probs}") - # print(f"original_probs: {original_probs}") + if not torch.allclose(probs, original_probs): + print(r_str("Error: probs are not close")) + print(f"probs: {probs}") + print(f"original_probs: {original_probs}") if not torch.allclose(logits, original_logits): print(r_str("Error: logits are not close")) print(f"logits: {logits}") print(f"original_logits: {original_logits}") - diff = (logits - original_logits).abs().flatten() - diff_nonzero = diff[diff > 1e-6] - print(f"diff_nonzero: {diff_nonzero}") start_time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime(start_time)) out_dir = "./sampler_input_output" From 5de5ece6269ee79b510dc8e8019db6d61e018e85 Mon Sep 17 00:00:00 2001 From: js_park Date: Thu, 25 Sep 2025 16:48:16 -0700 
Subject: [PATCH 14/99] Wrong results Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index a9863cbcc032..83b7dc52ded6 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -435,7 +435,7 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=avg_logit) + probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=-float('inf')) larger_mask_0 = probs_blk > p_pivot_0 larger_mask_1 = probs_blk > p_pivot_1 From cbcf7f5245a11cad3e36d96100f6d3506c3efa1d Mon Sep 17 00:00:00 2001 From: js_park Date: Thu, 25 Sep 2025 22:44:11 -0700 Subject: [PATCH 15/99] Fixed? Signed-off-by: js_park --- test.py | 2 +- vllm/v1/sample/ops/topk_topp_sampler.py | 138 ++++++++++++++++-------- 2 files changed, 97 insertions(+), 43 deletions(-) diff --git a/test.py b/test.py index b335f1a6c5fa..ee5d7e91b7ac 100644 --- a/test.py +++ b/test.py @@ -7,7 +7,7 @@ "The future of AI is", ] prompts = prompts * 64 -sampling_params = SamplingParams(temperature=0.8, top_k=10, top_p=0.95) +sampling_params = SamplingParams(temperature=0.8, top_k=8, top_p=0.95) llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct") # llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite") diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 83b7dc52ded6..787ed190af90 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -163,6 +163,9 @@ def original_apply_top_k_top_p( top_k_mask = logits_sort < top_k_mask logits_sort.masked_fill_(top_k_mask, -float("inf")) + probs = logits_sort.softmax(dim=-1) + probs = probs.scatter(dim=-1, index=logits_idx, src=probs) + if p is 
not None: # Apply top-p. probs_sort = logits_sort.softmax(dim=-1) @@ -174,7 +177,7 @@ def original_apply_top_k_top_p( # Re-sort the probabilities. logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) - return logits + return logits, probs @triton.jit def _topk_kernel(LOGITS, PROBS, K, B, @@ -294,7 +297,8 @@ def _topk_kernel(LOGITS, PROBS, K, B, def _topk_topp_kernel(LOGITS, PROBS, K, P, B, N: tl.constexpr, BLOCK_SIZE: tl.constexpr, - NUM_TILES: tl.constexpr): + NUM_TILES: tl.constexpr, + DEBUG_PTR): pid = tl.program_id(0) num_programs = tl.num_programs(0) for row_id in tl.range(pid, B, num_programs): @@ -405,16 +409,18 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range probs_blk = tl.load(search_addr + offs_n, mask=mask_n, other=-float('inf')) - probs_blk = probs_blk - max_probs + probs_blk = tl.where(probs_blk > k_pivot, probs_blk, -float('inf')) + probs_blk = probs_blk - max_logit probs_blk = tl.exp(probs_blk) sum_exp_logits += tl.sum(probs_blk) tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) # Fourth pass: Compute probs (softmax) + exp_avg = tl.exp(avg_logit) for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=avg_logit) + probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=exp_avg) probs_blk = probs_blk / sum_exp_logits min_probs = tl.minimum(min_probs, tl.min(probs_blk)) max_probs = tl.maximum(max_probs, tl.max(probs_blk)) @@ -424,60 +430,103 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, min_range = min_probs num_iters = 0 - while p_pivot == -float('inf') and num_iters < 18: - p_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range - p_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range - p_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range + while p_pivot == -float('inf') and num_iters < 32: + p_pivot_0 = 
(max_range - min_range) * 1.0 / 2.0 + min_range p_pivots_sum_0 = 0.0 - p_pivots_sum_1 = 0.0 - p_pivots_sum_2 = 0.0 + + min_larger_0 = float('inf') + max_smaller_0 = -float('inf') + + # p_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + # p_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range + + # p_pivots_sum_1 = 0.0 + # p_pivots_sum_2 = 0.0 for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=-float('inf')) + probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) + + masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, float('inf')) + min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) + masked_smaller_0 = probs_blk * (probs_blk < p_pivot_0) + max_smaller_0 = tl.maximum(max_smaller_0, tl.max(masked_smaller_0)) + p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) - larger_mask_0 = probs_blk > p_pivot_0 - larger_mask_1 = probs_blk > p_pivot_1 - larger_mask_2 = probs_blk > p_pivot_2 + # larger_mask_1 = probs_blk > p_pivot_1 + # larger_mask_2 = probs_blk > p_pivot_2 - p_pivots_sum_0 += tl.sum(probs_blk * larger_mask_0) - p_pivots_sum_1 += tl.sum(probs_blk * larger_mask_1) - p_pivots_sum_2 += tl.sum(probs_blk * larger_mask_2) + # p_pivots_sum_1 += tl.sum(probs_blk * larger_mask_1) + # p_pivots_sum_2 += tl.sum(probs_blk * larger_mask_2) # Check if any of the pivots are equal to k if tl.abs(p_pivots_sum_0 - p) < 1e-6: p_pivot = p_pivot_0 - elif tl.abs(p_pivots_sum_1 - p) < 1e-6: - p_pivot = p_pivot_1 - elif tl.abs(p_pivots_sum_2 - p) < 1e-6: - p_pivot = p_pivot_2 - # If none of the pivots are equal to k, we updatae the range - elif p_pivots_sum_2 > p: - min_range = p_pivot_2 - elif p_pivots_sum_1 > p: - min_range = p_pivot_1 elif p_pivots_sum_0 > p: + if p_pivots_sum_0 - min_larger_0 < p: + p_pivot = p_pivot_0 min_range = p_pivot_0 - if p_pivots_sum_0 < p: + elif p_pivots_sum_0 < p: + if 
p_pivots_sum_0 + max_smaller_0 > p: + + p_pivot = max_smaller_0 max_range = p_pivot_0 - elif p_pivots_sum_1 < p: - max_range = p_pivot_1 - elif p_pivots_sum_2 < p: - max_range = p_pivot_2 + + # elif tl.abs(p_pivots_sum_1 - p) < 1e-6: + # p_pivot = p_pivot_1 + # elif tl.abs(p_pivots_sum_2 - p) < 1e-6: + # p_pivot = p_pivot_2 + # If none of the pivots are equal to k, we updatae the range + # elif p_pivots_sum_2 > p: + # min_range = p_pivot_2 + # elif p_pivots_sum_1 > p: + # min_range = p_pivot_1 + # elif p_pivots_sum_0 > p: + # min_range = p_pivot_0 + # if p_pivots_sum_0 < p: + # max_range = p_pivot_0 + # elif p_pivots_sum_1 < p: + # max_range = p_pivot_1 + # elif p_pivots_sum_2 < p: + # max_range = p_pivot_2 num_iters += 1 - if num_iters >= 18: + if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-6: p_pivot = p_pivot_0 - # Sixth pass: Apply top-k mask + + if row_id == 0 and num_iters == 2: + tl.store(DEBUG_PTR + 0 * N + 0, p_pivots_sum_0) + tl.store(DEBUG_PTR + 0 * N + 1, p_pivot_0) + tl.store(DEBUG_PTR + 0 * N + 2, min_probs) + tl.store(DEBUG_PTR + 0 * N + 3, max_probs) + tl.store(DEBUG_PTR + 0 * N + 4, min_range) + tl.store(DEBUG_PTR + 0 * N + 5, max_range) + tl.store(DEBUG_PTR + 0 * N + 6, num_iters) + tl.store(DEBUG_PTR + 0 * N + 7, sum_exp_logits) + tl.store(DEBUG_PTR + 0 * N + 8, p_pivot) + tl.store(DEBUG_PTR + 0 * N + 9, tl.log(p_pivot * sum_exp_logits)) + tl.store(DEBUG_PTR + 0 * N + 10, min_larger_0) + tl.store(DEBUG_PTR + 0 * N + 11, max_smaller_0) + # Subtract a small value to include the nearest smaller value + # If the nearest smaller value very small, it may cause numerical instability + p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit - 1e-6 + if row_id == 0: + tl.store(DEBUG_PTR + 12, p_pivot) + tl.store(DEBUG_PTR + 13, p) + + # Transform p_pivot into equivalent logit + # p_pivot = tl.log(p_pivot * sum_exp_logits) + + # Sixth pass: Apply mask if k_pivot != -float('inf') or p_pivot != -float('inf'): + pivot = tl.maximum(k_pivot, p_pivot) 
for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < N logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n) - mask = (logits_blk > k_pivot) & (logits_blk > p_pivot) - logits_blk = tl.where(mask, logits_blk, -float('inf')) + logits_blk = tl.where(logits_blk > pivot, logits_blk, -float('inf')) tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) def triton_apply_top_k_top_p( @@ -490,6 +539,7 @@ def triton_apply_top_k_top_p( NUM_PROGRAMS = 128 NUM_TILES = (vocab_size + BLOCK_SIZE - 1) // BLOCK_SIZE probs = torch.full_like(logits, -float('inf')) + debug = torch.full_like(logits, -float('inf')) print(b_str("Launch params:") + f"logits.shape: {logits.shape}, probs.shape: {probs.shape}, " f"k.shape: {k.shape if k is not None else None}, p.shape: {p.shape if p is not None else None}, " f"batch_size: {batch_size}, vocab_size: {vocab_size}, BLOCK_SIZE: {BLOCK_SIZE}, NUM_TILES: {NUM_TILES}") @@ -499,7 +549,8 @@ def triton_apply_top_k_top_p( vocab_size, BLOCK_SIZE, NUM_TILES) else: _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, - vocab_size, BLOCK_SIZE, NUM_TILES) + vocab_size, BLOCK_SIZE, NUM_TILES, debug) + print(f"debug: {debug[:14, :14]}") # print(f"Output logits: {logits}") # print(f"Output probs: {probs}") return logits, probs @@ -524,9 +575,12 @@ def apply_top_k_top_p( The logits tensor may be updated in-place. 
""" + logits = torch.full_like(logits, -10.0) + logits[:, :11] = torch.arange(1, 12, dtype=torch.float32, device=logits.device) input_logits = logits.clone() - original_logits = original_apply_top_k_top_p(input_logits, k, p) - original_probs = torch.softmax(input_logits, dim=-1) + print(f"input_logits: {input_logits[:12, :12]}") + original_logits, original_probs = original_apply_top_k_top_p(input_logits, k, p) + # original_probs = torch.softmax(input_logits, dim=-1) batch_size, vocab_size = logits.shape print(g_str("apply_top_k_top_p") + f" logits.shape: {batch_size} x {vocab_size}, p is None: {p is None}, k is None: {k is None}") @@ -544,13 +598,13 @@ def apply_top_k_top_p( if not torch.allclose(probs, original_probs): print(r_str("Error: probs are not close")) - print(f"probs: {probs}") - print(f"original_probs: {original_probs}") + print(f"probs: {probs[:12, :12]}") + print(f"original_probs: {original_probs[:12, :12]}") if not torch.allclose(logits, original_logits): print(r_str("Error: logits are not close")) - print(f"logits: {logits}") - print(f"original_logits: {original_logits}") + print(f"logits: {logits[:12, :12]}") + print(f"original_logits: {original_logits[:12, :12]}") start_time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime(start_time)) out_dir = "./sampler_input_output" From 643c21d49995541e885701c5801f2370bf7ced5d Mon Sep 17 00:00:00 2001 From: js_park Date: Thu, 25 Sep 2025 22:45:22 -0700 Subject: [PATCH 16/99] Fixed? 
Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 57 +------------------------ 1 file changed, 2 insertions(+), 55 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 787ed190af90..2856b30e487e 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -437,12 +437,6 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, min_larger_0 = float('inf') max_smaller_0 = -float('inf') - # p_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range - # p_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range - - # p_pivots_sum_1 = 0.0 - # p_pivots_sum_2 = 0.0 - for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range @@ -454,12 +448,6 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, max_smaller_0 = tl.maximum(max_smaller_0, tl.max(masked_smaller_0)) p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) - # larger_mask_1 = probs_blk > p_pivot_1 - # larger_mask_2 = probs_blk > p_pivot_2 - - # p_pivots_sum_1 += tl.sum(probs_blk * larger_mask_1) - # p_pivots_sum_2 += tl.sum(probs_blk * larger_mask_2) - # Check if any of the pivots are equal to k if tl.abs(p_pivots_sum_0 - p) < 1e-6: p_pivot = p_pivot_0 @@ -473,51 +461,14 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, p_pivot = max_smaller_0 max_range = p_pivot_0 - # elif tl.abs(p_pivots_sum_1 - p) < 1e-6: - # p_pivot = p_pivot_1 - # elif tl.abs(p_pivots_sum_2 - p) < 1e-6: - # p_pivot = p_pivot_2 - # If none of the pivots are equal to k, we updatae the range - # elif p_pivots_sum_2 > p: - # min_range = p_pivot_2 - # elif p_pivots_sum_1 > p: - # min_range = p_pivot_1 - # elif p_pivots_sum_0 > p: - # min_range = p_pivot_0 - # if p_pivots_sum_0 < p: - # max_range = p_pivot_0 - # elif p_pivots_sum_1 < p: - # max_range = p_pivot_1 - # elif p_pivots_sum_2 < p: - # max_range = p_pivot_2 - num_iters += 1 if num_iters >= 32 or tl.abs(min_range - 
max_range) < 1e-6: p_pivot = p_pivot_0 - - if row_id == 0 and num_iters == 2: - tl.store(DEBUG_PTR + 0 * N + 0, p_pivots_sum_0) - tl.store(DEBUG_PTR + 0 * N + 1, p_pivot_0) - tl.store(DEBUG_PTR + 0 * N + 2, min_probs) - tl.store(DEBUG_PTR + 0 * N + 3, max_probs) - tl.store(DEBUG_PTR + 0 * N + 4, min_range) - tl.store(DEBUG_PTR + 0 * N + 5, max_range) - tl.store(DEBUG_PTR + 0 * N + 6, num_iters) - tl.store(DEBUG_PTR + 0 * N + 7, sum_exp_logits) - tl.store(DEBUG_PTR + 0 * N + 8, p_pivot) - tl.store(DEBUG_PTR + 0 * N + 9, tl.log(p_pivot * sum_exp_logits)) - tl.store(DEBUG_PTR + 0 * N + 10, min_larger_0) - tl.store(DEBUG_PTR + 0 * N + 11, max_smaller_0) + # Transform p_pivot into equivalent logit # Subtract a small value to include the nearest smaller value # If the nearest smaller value very small, it may cause numerical instability p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit - 1e-6 - if row_id == 0: - tl.store(DEBUG_PTR + 12, p_pivot) - tl.store(DEBUG_PTR + 13, p) - - # Transform p_pivot into equivalent logit - # p_pivot = tl.log(p_pivot * sum_exp_logits) # Sixth pass: Apply mask if k_pivot != -float('inf') or p_pivot != -float('inf'): @@ -539,7 +490,6 @@ def triton_apply_top_k_top_p( NUM_PROGRAMS = 128 NUM_TILES = (vocab_size + BLOCK_SIZE - 1) // BLOCK_SIZE probs = torch.full_like(logits, -float('inf')) - debug = torch.full_like(logits, -float('inf')) print(b_str("Launch params:") + f"logits.shape: {logits.shape}, probs.shape: {probs.shape}, " f"k.shape: {k.shape if k is not None else None}, p.shape: {p.shape if p is not None else None}, " f"batch_size: {batch_size}, vocab_size: {vocab_size}, BLOCK_SIZE: {BLOCK_SIZE}, NUM_TILES: {NUM_TILES}") @@ -549,8 +499,7 @@ def triton_apply_top_k_top_p( vocab_size, BLOCK_SIZE, NUM_TILES) else: _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, - vocab_size, BLOCK_SIZE, NUM_TILES, debug) - print(f"debug: {debug[:14, :14]}") + vocab_size, BLOCK_SIZE, NUM_TILES) # print(f"Output logits: {logits}") # 
print(f"Output probs: {probs}") return logits, probs @@ -575,8 +524,6 @@ def apply_top_k_top_p( The logits tensor may be updated in-place. """ - logits = torch.full_like(logits, -10.0) - logits[:, :11] = torch.arange(1, 12, dtype=torch.float32, device=logits.device) input_logits = logits.clone() print(f"input_logits: {input_logits[:12, :12]}") original_logits, original_probs = original_apply_top_k_top_p(input_logits, k, p) From 2737c2d7a7337b5ab5d2f3c458c249f13f673ee0 Mon Sep 17 00:00:00 2001 From: js_park Date: Fri, 26 Sep 2025 00:23:16 -0700 Subject: [PATCH 17/99] Maybe? Signed-off-by: js_park --- test.py | 2 +- vllm/v1/sample/ops/topk_topp_sampler.py | 73 ++++++++++++++++--------- 2 files changed, 49 insertions(+), 26 deletions(-) diff --git a/test.py b/test.py index ee5d7e91b7ac..7957887fc5d4 100644 --- a/test.py +++ b/test.py @@ -7,7 +7,7 @@ "The future of AI is", ] prompts = prompts * 64 -sampling_params = SamplingParams(temperature=0.8, top_k=8, top_p=0.95) +sampling_params = SamplingParams(temperature=0.8, top_k=8, top_p=1.0) llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct") # llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite") diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 2856b30e487e..a0c24f903e49 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -163,8 +163,6 @@ def original_apply_top_k_top_p( top_k_mask = logits_sort < top_k_mask logits_sort.masked_fill_(top_k_mask, -float("inf")) - probs = logits_sort.softmax(dim=-1) - probs = probs.scatter(dim=-1, index=logits_idx, src=probs) if p is not None: # Apply top-p. @@ -177,7 +175,7 @@ def original_apply_top_k_top_p( # Re-sort the probabilities. 
logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) - return logits, probs + return logits @triton.jit def _topk_kernel(LOGITS, PROBS, K, B, @@ -191,7 +189,7 @@ def _topk_kernel(LOGITS, PROBS, K, B, p_pivot = -float('inf') LOGITS_ROW = LOGITS + row_id * N - PROBS_ROW = PROBS + row_id * N + PROBS_ROW = PROBS + pid * N search_addr = LOGITS_ROW search_range = N @@ -306,7 +304,7 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, p_pivot = -float('inf') LOGITS_ROW = LOGITS + row_id * N - PROBS_ROW = PROBS + row_id * N + PROBS_ROW = PROBS + pid * N search_addr = LOGITS_ROW search_range = N @@ -430,22 +428,26 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, min_range = min_probs num_iters = 0 + p_pivots_sum_0 = 0.0 while p_pivot == -float('inf') and num_iters < 32: p_pivot_0 = (max_range - min_range) * 1.0 / 2.0 + min_range p_pivots_sum_0 = 0.0 - min_larger_0 = float('inf') - max_smaller_0 = -float('inf') + min_larger_0 = 1.0 + max_smaller_0 = 0.0 + second_max_smaller_0 = 0.0 for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) - masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, float('inf')) + masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) masked_smaller_0 = probs_blk * (probs_blk < p_pivot_0) max_smaller_0 = tl.maximum(max_smaller_0, tl.max(masked_smaller_0)) + masked_second_smaller_0 = probs_blk * (probs_blk < max_smaller_0) + second_max_smaller_0 = tl.maximum(second_max_smaller_0, tl.max(masked_second_smaller_0)) p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) # Check if any of the pivots are equal to k @@ -457,18 +459,40 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, min_range = p_pivot_0 elif p_pivots_sum_0 < p: if p_pivots_sum_0 + max_smaller_0 > p: - - p_pivot = max_smaller_0 + p_pivot = second_max_smaller_0 
max_range = p_pivot_0 num_iters += 1 if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-6: p_pivot = p_pivot_0 - - # Transform p_pivot into equivalent logit + + if row_id == 0: + tl.store(DEBUG_PTR + num_iters * 17 + 0, p_pivots_sum_0) + tl.store(DEBUG_PTR + num_iters * 17 + 1, p_pivot_0) + tl.store(DEBUG_PTR + num_iters * 17 + 2, min_probs) + tl.store(DEBUG_PTR + num_iters * 17 + 3, max_probs) + tl.store(DEBUG_PTR + num_iters * 17 + 4, min_range) + tl.store(DEBUG_PTR + num_iters * 17 + 5, max_range) + tl.store(DEBUG_PTR + num_iters * 17 + 6, num_iters) + tl.store(DEBUG_PTR + num_iters * 17 + 7, sum_exp_logits) + tl.store(DEBUG_PTR + num_iters * 17 + 8, p_pivot) + tl.store(DEBUG_PTR + num_iters * 17 + 9, tl.log(p_pivot * sum_exp_logits)) + tl.store(DEBUG_PTR + num_iters * 17 + 10, min_larger_0) + tl.store(DEBUG_PTR + num_iters * 17 + 11, max_smaller_0) + tl.store(DEBUG_PTR + num_iters * 17 + 12, second_max_smaller_0) # Subtract a small value to include the nearest smaller value # If the nearest smaller value very small, it may cause numerical instability - p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit - 1e-6 + if row_id == 0: + tl.store(DEBUG_PTR + num_iters * 17 + 13, p_pivots_sum_0) + tl.store(DEBUG_PTR + num_iters * 17 + 14, p_pivot) + tl.store(DEBUG_PTR + num_iters * 17 + 15, num_iters) + p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit + if row_id == 0: + tl.store(DEBUG_PTR + num_iters * 17 + 16, p_pivot) + tl.store(DEBUG_PTR + num_iters * 17 + 17, p) + + # Transform p_pivot into equivalent logit + # p_pivot = tl.log(p_pivot * sum_exp_logits) # Sixth pass: Apply mask if k_pivot != -float('inf') or p_pivot != -float('inf'): @@ -489,7 +513,8 @@ def triton_apply_top_k_top_p( BLOCK_SIZE = 4096 NUM_PROGRAMS = 128 NUM_TILES = (vocab_size + BLOCK_SIZE - 1) // BLOCK_SIZE - probs = torch.full_like(logits, -float('inf')) + probs = torch.full((NUM_PROGRAMS, vocab_size), -float('inf'), device=logits.device) + debug = torch.full((32, 18), 
-float('inf'), device=logits.device) print(b_str("Launch params:") + f"logits.shape: {logits.shape}, probs.shape: {probs.shape}, " f"k.shape: {k.shape if k is not None else None}, p.shape: {p.shape if p is not None else None}, " f"batch_size: {batch_size}, vocab_size: {vocab_size}, BLOCK_SIZE: {BLOCK_SIZE}, NUM_TILES: {NUM_TILES}") @@ -499,10 +524,11 @@ def triton_apply_top_k_top_p( vocab_size, BLOCK_SIZE, NUM_TILES) else: _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, - vocab_size, BLOCK_SIZE, NUM_TILES) + vocab_size, BLOCK_SIZE, NUM_TILES, debug) + print(f"debug: {debug[:, :17]}") # print(f"Output logits: {logits}") # print(f"Output probs: {probs}") - return logits, probs + return logits @torch.compile def compiled_apply_top_k_top_p( @@ -524,9 +550,11 @@ def apply_top_k_top_p( The logits tensor may be updated in-place. """ + # logits = torch.full_like(logits, -10.0) + # logits[:, :11] = torch.arange(1, 12, dtype=torch.float32, device=logits.device) input_logits = logits.clone() print(f"input_logits: {input_logits[:12, :12]}") - original_logits, original_probs = original_apply_top_k_top_p(input_logits, k, p) + original_logits = original_apply_top_k_top_p(input_logits, k, p) # original_probs = torch.softmax(input_logits, dim=-1) batch_size, vocab_size = logits.shape @@ -537,21 +565,16 @@ def apply_top_k_top_p( # logits = original_apply_top_k_top_p(logits, k, p) # logits = compiled_apply_top_k_top_p(logits, k, p) - logits, probs = triton_apply_top_k_top_p(logits, k, p) + logits = triton_apply_top_k_top_p(logits, k, p) torch.cuda.synchronize() time_taken = time.time() - start_time print(y_str(f"apply_top_k_top_p done in {time_taken} seconds")) - if not torch.allclose(probs, original_probs): - print(r_str("Error: probs are not close")) - print(f"probs: {probs[:12, :12]}") - print(f"original_probs: {original_probs[:12, :12]}") - if not torch.allclose(logits, original_logits): print(r_str("Error: logits are not close")) - print(f"logits: 
{logits[:12, :12]}") - print(f"original_logits: {original_logits[:12, :12]}") + # print(f"logits: {logits[:12, :12]}") + # print(f"original_logits: {original_logits[:12, :12]}") start_time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime(start_time)) out_dir = "./sampler_input_output" From f24d2e17b9730bea4d31a62defcc80b9423d901b Mon Sep 17 00:00:00 2001 From: js_park Date: Fri, 26 Sep 2025 02:17:53 -0700 Subject: [PATCH 18/99] Duplicate logit issues. Signed-off-by: js_park --- test.py | 2 +- vllm/v1/sample/ops/topk_topp_sampler.py | 101 ++++++++++++------------ 2 files changed, 51 insertions(+), 52 deletions(-) diff --git a/test.py b/test.py index 7957887fc5d4..990e90313c16 100644 --- a/test.py +++ b/test.py @@ -7,7 +7,7 @@ "The future of AI is", ] prompts = prompts * 64 -sampling_params = SamplingParams(temperature=0.8, top_k=8, top_p=1.0) +sampling_params = SamplingParams(temperature=0.8, top_k=6, top_p=0.9) llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct") # llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite") diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index a0c24f903e49..4d577d6076d4 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -163,7 +163,6 @@ def original_apply_top_k_top_p( top_k_mask = logits_sort < top_k_mask logits_sort.masked_fill_(top_k_mask, -float("inf")) - if p is not None: # Apply top-p. 
probs_sort = logits_sort.softmax(dim=-1) @@ -398,6 +397,7 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, p = tl.load(P + row_id) if p != 1.0: + second_max_logit = -float('inf') max_probs = 0.0 min_probs = 1.0 sum_exp_logits = 0.0 @@ -413,15 +413,20 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, sum_exp_logits += tl.sum(probs_blk) tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) + second_max_mask = probs_blk * (probs_blk < max_probs) + second_max_logit = tl.maximum(second_max_logit, tl.max(second_max_mask)) + # Fourth pass: Compute probs (softmax) exp_avg = tl.exp(avg_logit) for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=exp_avg) + probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n) probs_blk = probs_blk / sum_exp_logits - min_probs = tl.minimum(min_probs, tl.min(probs_blk)) - max_probs = tl.maximum(max_probs, tl.max(probs_blk)) + min_blk = tl.where(mask_n, probs_blk, 1.0) + min_probs = tl.minimum(min_probs, tl.min(min_blk)) + max_blk = tl.where(mask_n, probs_blk, 0.0) + max_probs = tl.maximum(max_probs, tl.max(max_blk)) tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) max_range = max_probs @@ -434,8 +439,6 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, p_pivots_sum_0 = 0.0 min_larger_0 = 1.0 - max_smaller_0 = 0.0 - second_max_smaller_0 = 0.0 for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) @@ -444,59 +447,56 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) - masked_smaller_0 = probs_blk * (probs_blk < p_pivot_0) - max_smaller_0 = tl.maximum(max_smaller_0, tl.max(masked_smaller_0)) - masked_second_smaller_0 = probs_blk * (probs_blk < max_smaller_0) - second_max_smaller_0 = tl.maximum(second_max_smaller_0, tl.max(masked_second_smaller_0)) p_pivots_sum_0 += 
tl.sum(probs_blk * (probs_blk > p_pivot_0)) # Check if any of the pivots are equal to k - if tl.abs(p_pivots_sum_0 - p) < 1e-6: - p_pivot = p_pivot_0 - elif p_pivots_sum_0 > p: + if p_pivots_sum_0 >= p: if p_pivots_sum_0 - min_larger_0 < p: p_pivot = p_pivot_0 - min_range = p_pivot_0 - elif p_pivots_sum_0 < p: - if p_pivots_sum_0 + max_smaller_0 > p: - p_pivot = second_max_smaller_0 + elif tl.abs(p_pivots_sum_0 - min_larger_0) < 1e-6: + p_pivot = (p_pivot_0 + min_larger_0) / 2.0 + else: + min_range = p_pivot_0 + else: max_range = p_pivot_0 num_iters += 1 - if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-6: + if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-9: p_pivot = p_pivot_0 - if row_id == 0: - tl.store(DEBUG_PTR + num_iters * 17 + 0, p_pivots_sum_0) - tl.store(DEBUG_PTR + num_iters * 17 + 1, p_pivot_0) - tl.store(DEBUG_PTR + num_iters * 17 + 2, min_probs) - tl.store(DEBUG_PTR + num_iters * 17 + 3, max_probs) - tl.store(DEBUG_PTR + num_iters * 17 + 4, min_range) - tl.store(DEBUG_PTR + num_iters * 17 + 5, max_range) - tl.store(DEBUG_PTR + num_iters * 17 + 6, num_iters) - tl.store(DEBUG_PTR + num_iters * 17 + 7, sum_exp_logits) - tl.store(DEBUG_PTR + num_iters * 17 + 8, p_pivot) - tl.store(DEBUG_PTR + num_iters * 17 + 9, tl.log(p_pivot * sum_exp_logits)) - tl.store(DEBUG_PTR + num_iters * 17 + 10, min_larger_0) - tl.store(DEBUG_PTR + num_iters * 17 + 11, max_smaller_0) - tl.store(DEBUG_PTR + num_iters * 17 + 12, second_max_smaller_0) + if row_id == 1: + tl.store(DEBUG_PTR + num_iters * 18 + 0, p_pivots_sum_0) + tl.store(DEBUG_PTR + num_iters * 18 + 1, p_pivot_0) + tl.store(DEBUG_PTR + num_iters * 18 + 2, min_probs) + tl.store(DEBUG_PTR + num_iters * 18 + 3, max_probs) + tl.store(DEBUG_PTR + num_iters * 18 + 4, min_range) + tl.store(DEBUG_PTR + num_iters * 18 + 5, max_range) + tl.store(DEBUG_PTR + num_iters * 18 + 6, num_iters) + tl.store(DEBUG_PTR + num_iters * 18 + 7, sum_exp_logits) + tl.store(DEBUG_PTR + num_iters * 18 + 8, p_pivot) + 
tl.store(DEBUG_PTR + num_iters * 18 + 9, tl.log(p_pivot * sum_exp_logits)) + tl.store(DEBUG_PTR + num_iters * 18 + 10, min_larger_0) # Subtract a small value to include the nearest smaller value # If the nearest smaller value very small, it may cause numerical instability - if row_id == 0: - tl.store(DEBUG_PTR + num_iters * 17 + 13, p_pivots_sum_0) - tl.store(DEBUG_PTR + num_iters * 17 + 14, p_pivot) - tl.store(DEBUG_PTR + num_iters * 17 + 15, num_iters) + if row_id == 1: + tl.store(DEBUG_PTR + num_iters * 18 + 13, p_pivots_sum_0) + tl.store(DEBUG_PTR + num_iters * 18 + 14, p_pivot) + tl.store(DEBUG_PTR + num_iters * 18 + 15, num_iters) p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit - if row_id == 0: - tl.store(DEBUG_PTR + num_iters * 17 + 16, p_pivot) - tl.store(DEBUG_PTR + num_iters * 17 + 17, p) + if row_id == 1: + tl.store(DEBUG_PTR + num_iters * 18 + 16, p_pivot) + tl.store(DEBUG_PTR + num_iters * 18 + 17, p) + + # At least one value should be greater than p_pivot + if p_pivot >= max_logit: + p_pivot = second_max_logit # Transform p_pivot into equivalent logit # p_pivot = tl.log(p_pivot * sum_exp_logits) # Sixth pass: Apply mask - if k_pivot != -float('inf') or p_pivot != -float('inf'): - pivot = tl.maximum(k_pivot, p_pivot) + pivot = tl.maximum(k_pivot, p_pivot) + if pivot != -float('inf'): for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < N @@ -514,10 +514,10 @@ def triton_apply_top_k_top_p( NUM_PROGRAMS = 128 NUM_TILES = (vocab_size + BLOCK_SIZE - 1) // BLOCK_SIZE probs = torch.full((NUM_PROGRAMS, vocab_size), -float('inf'), device=logits.device) - debug = torch.full((32, 18), -float('inf'), device=logits.device) - print(b_str("Launch params:") + f"logits.shape: {logits.shape}, probs.shape: {probs.shape}, " - f"k.shape: {k.shape if k is not None else None}, p.shape: {p.shape if p is not None else None}, " - f"batch_size: {batch_size}, vocab_size: {vocab_size}, BLOCK_SIZE: {BLOCK_SIZE}, NUM_TILES: 
{NUM_TILES}") + debug = torch.full((20, 18), -float('inf'), device=logits.device) + # print(b_str("Launch params:") + f"logits.shape: {logits.shape}, probs.shape: {probs.shape}, " + # f"k.shape: {k.shape if k is not None else None}, p.shape: {p.shape if p is not None else None}, " + # f"batch_size: {batch_size}, vocab_size: {vocab_size}, BLOCK_SIZE: {BLOCK_SIZE}, NUM_TILES: {NUM_TILES}") # print(f"Input logits: {logits}") if p is None and k is not None: _topk_kernel[(NUM_PROGRAMS,)](logits, probs, k, batch_size, @@ -525,7 +525,7 @@ def triton_apply_top_k_top_p( else: _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, vocab_size, BLOCK_SIZE, NUM_TILES, debug) - print(f"debug: {debug[:, :17]}") + print(f"debug: {debug[:, :18]}") # print(f"Output logits: {logits}") # print(f"Output probs: {probs}") return logits @@ -550,10 +550,9 @@ def apply_top_k_top_p( The logits tensor may be updated in-place. """ - # logits = torch.full_like(logits, -10.0) - # logits[:, :11] = torch.arange(1, 12, dtype=torch.float32, device=logits.device) + input_logits = logits.clone() - print(f"input_logits: {input_logits[:12, :12]}") + print(f"input_logits: {torch.sort(input_logits, descending=True).values[:12, :12]}") original_logits = original_apply_top_k_top_p(input_logits, k, p) # original_probs = torch.softmax(input_logits, dim=-1) @@ -573,8 +572,8 @@ def apply_top_k_top_p( if not torch.allclose(logits, original_logits): print(r_str("Error: logits are not close")) - # print(f"logits: {logits[:12, :12]}") - # print(f"original_logits: {original_logits[:12, :12]}") + print(f"logits: {torch.sort(logits, descending=True).values[:12, :12]}") + print(f"original_logits: {torch.sort(original_logits, descending=True).values[:12, :12]}") start_time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime(start_time)) out_dir = "./sampler_input_output" From a58ca6cfa07ec71ad1b6ad6bfa1d4da6fe297c12 Mon Sep 17 00:00:00 2001 From: js_park Date: Fri, 26 Sep 2025 02:23:14 -0700 Subject: 
[PATCH 19/99] Duplicate logit issues. Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 4d577d6076d4..d771a851c395 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -453,8 +453,8 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, if p_pivots_sum_0 >= p: if p_pivots_sum_0 - min_larger_0 < p: p_pivot = p_pivot_0 - elif tl.abs(p_pivots_sum_0 - min_larger_0) < 1e-6: - p_pivot = (p_pivot_0 + min_larger_0) / 2.0 + if tl.abs(p_pivots_sum_0 - min_larger_0) < 1e-6: + p_pivot = (p_pivot_0 + min_larger_0) / 2.0 else: min_range = p_pivot_0 else: From b87c09548792efe02af108359a3f7217f73b505b Mon Sep 17 00:00:00 2001 From: js_park Date: Fri, 26 Sep 2025 15:14:59 -0700 Subject: [PATCH 20/99] Top-p duplicate handler implemented Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 90 +++++++++++++++++-------- 1 file changed, 61 insertions(+), 29 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index d771a851c395..67ef0b9ae230 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -291,7 +291,7 @@ def _topk_kernel(LOGITS, PROBS, K, B, @triton.jit -def _topk_topp_kernel(LOGITS, PROBS, K, P, B, +def _topk_topp_kernel(LOGITS, PROBS, NUM_SEARCH, K, P, B, N: tl.constexpr, BLOCK_SIZE: tl.constexpr, NUM_TILES: tl.constexpr, @@ -303,7 +303,7 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, p_pivot = -float('inf') LOGITS_ROW = LOGITS + row_id * N - PROBS_ROW = PROBS + pid * N + PROBS_ROW = PROBS + row_id * N search_addr = LOGITS_ROW search_range = N @@ -312,6 +312,10 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, max_logit = -float('inf') avg_logit = -float('inf') + # The Pytorch version removes the earlier duplicates if there are more than one 
duplicates + force_remove_logit = -float('inf') + num_force_remove = tl.zeros((), dtype=tl.uint32) + k = tl.load(K + row_id) if not (k == N): # All tokens are valid min_logit = float('inf') @@ -434,11 +438,14 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, num_iters = 0 p_pivots_sum_0 = 0.0 + min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) while p_pivot == -float('inf') and num_iters < 32: p_pivot_0 = (max_range - min_range) * 1.0 / 2.0 + min_range p_pivots_sum_0 = 0.0 min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) @@ -447,14 +454,14 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-6) + p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) # Check if any of the pivots are equal to k if p_pivots_sum_0 >= p: - if p_pivots_sum_0 - min_larger_0 < p: + if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: p_pivot = p_pivot_0 - if tl.abs(p_pivots_sum_0 - min_larger_0) < 1e-6: - p_pivot = (p_pivot_0 + min_larger_0) / 2.0 else: min_range = p_pivot_0 else: @@ -465,44 +472,67 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, p_pivot = p_pivot_0 if row_id == 1: - tl.store(DEBUG_PTR + num_iters * 18 + 0, p_pivots_sum_0) - tl.store(DEBUG_PTR + num_iters * 18 + 1, p_pivot_0) - tl.store(DEBUG_PTR + num_iters * 18 + 2, min_probs) - tl.store(DEBUG_PTR + num_iters * 18 + 3, max_probs) - tl.store(DEBUG_PTR + num_iters * 18 + 4, min_range) - tl.store(DEBUG_PTR + num_iters * 18 + 5, max_range) - tl.store(DEBUG_PTR + num_iters * 18 + 6, num_iters) - tl.store(DEBUG_PTR + num_iters * 18 + 7, sum_exp_logits) - tl.store(DEBUG_PTR + num_iters * 18 + 8, p_pivot) - tl.store(DEBUG_PTR + num_iters * 18 + 9, tl.log(p_pivot * sum_exp_logits)) - tl.store(DEBUG_PTR 
+ num_iters * 18 + 10, min_larger_0) + tl.store(DEBUG_PTR + num_iters * 21 + 0, p_pivots_sum_0) + tl.store(DEBUG_PTR + num_iters * 21 + 1, p_pivot_0) + tl.store(DEBUG_PTR + num_iters * 21 + 2, min_probs) + tl.store(DEBUG_PTR + num_iters * 21 + 3, max_probs) + tl.store(DEBUG_PTR + num_iters * 21 + 4, min_range) + tl.store(DEBUG_PTR + num_iters * 21 + 5, max_range) + tl.store(DEBUG_PTR + num_iters * 21 + 6, num_iters) + tl.store(DEBUG_PTR + num_iters * 21 + 7, sum_exp_logits) + tl.store(DEBUG_PTR + num_iters * 21 + 8, p_pivot) + tl.store(DEBUG_PTR + num_iters * 21 + 9, tl.log(p_pivot * sum_exp_logits)) + tl.store(DEBUG_PTR + num_iters * 21 + 10, min_larger_0) + tl.store(DEBUG_PTR + num_iters * 21 + 11, num_min_larger_0) # Subtract a small value to include the nearest smaller value # If the nearest smaller value very small, it may cause numerical instability - if row_id == 1: - tl.store(DEBUG_PTR + num_iters * 18 + 13, p_pivots_sum_0) - tl.store(DEBUG_PTR + num_iters * 18 + 14, p_pivot) - tl.store(DEBUG_PTR + num_iters * 18 + 15, num_iters) - p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit - if row_id == 1: - tl.store(DEBUG_PTR + num_iters * 18 + 16, p_pivot) - tl.store(DEBUG_PTR + num_iters * 18 + 17, p) + # At least one value should be greater than p_pivot if p_pivot >= max_logit: p_pivot = second_max_logit + elif num_min_larger_0 > 1: + # Force remove duplicates (p_pivot is made to include all duplicates if it falls on the duplicates) + num_force_remove = tl.cast((p_pivots_sum_0 - p) / min_larger_0, tl.uint32) + force_remove_logit = tl.log(min_larger_0 * sum_exp_logits) + max_logit + + if row_id == 1: + tl.store(DEBUG_PTR + num_iters * 21 + 12, p_pivots_sum_0) + tl.store(DEBUG_PTR + num_iters * 21 + 13, p_pivot) + tl.store(DEBUG_PTR + num_iters * 21 + 14, num_iters) + p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit + if row_id == 1: + tl.store(DEBUG_PTR + num_iters * 21 + 15, p_pivot) + tl.store(DEBUG_PTR + num_iters * 21 + 16, p) + tl.store(DEBUG_PTR 
+ num_iters * 21 + 17, force_remove_logit) + tl.store(DEBUG_PTR + num_iters * 21 + 18, num_force_remove) + tl.store(DEBUG_PTR + num_iters * 21 + 19, num_min_larger_0) + tl.store(DEBUG_PTR + num_iters * 21 + 20, min_larger_0) + # Transform p_pivot into equivalent logit # p_pivot = tl.log(p_pivot * sum_exp_logits) # Sixth pass: Apply mask pivot = tl.maximum(k_pivot, p_pivot) + current_num_force_remove = tl.zeros((), dtype=tl.uint32) if pivot != -float('inf'): for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < N - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n) + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) + + if force_remove_logit != -float('inf'): + force_remove_mask = tl.abs(logits_blk - force_remove_logit) < 1e-6 + force_remove_count = tl.cumsum(force_remove_mask) + current_num_force_remove + force_remove_count_mask = force_remove_count <= num_force_remove + force_remove_mask = force_remove_count_mask & force_remove_mask + logits_blk = tl.where(force_remove_mask, -float('inf'), logits_blk) + current_num_force_remove = tl.max(force_remove_count) + logits_blk = tl.where(logits_blk > pivot, logits_blk, -float('inf')) tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) + def triton_apply_top_k_top_p( logits: torch.Tensor, @@ -513,19 +543,21 @@ def triton_apply_top_k_top_p( BLOCK_SIZE = 4096 NUM_PROGRAMS = 128 NUM_TILES = (vocab_size + BLOCK_SIZE - 1) // BLOCK_SIZE - probs = torch.full((NUM_PROGRAMS, vocab_size), -float('inf'), device=logits.device) - debug = torch.full((20, 18), -float('inf'), device=logits.device) + debug = torch.full((20, 21), -float('inf'), device=logits.device) # print(b_str("Launch params:") + f"logits.shape: {logits.shape}, probs.shape: {probs.shape}, " # f"k.shape: {k.shape if k is not None else None}, p.shape: {p.shape if p is not None else None}, " # f"batch_size: {batch_size}, vocab_size: {vocab_size}, BLOCK_SIZE: {BLOCK_SIZE}, NUM_TILES: {NUM_TILES}") 
# print(f"Input logits: {logits}") if p is None and k is not None: + probs = torch.full((NUM_PROGRAMS, vocab_size), -float('inf'), device=logits.device) _topk_kernel[(NUM_PROGRAMS,)](logits, probs, k, batch_size, vocab_size, BLOCK_SIZE, NUM_TILES) else: - _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, + probs = torch.full_like(logits, -float('inf'), device=logits.device) + num_search = torch.full((logits.shape[0],), vocab_size, device=logits.device) + _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, num_search, k, p, batch_size, vocab_size, BLOCK_SIZE, NUM_TILES, debug) - print(f"debug: {debug[:, :18]}") + print(f"debug: {debug}") # print(f"Output logits: {logits}") # print(f"Output probs: {probs}") return logits From 6e3ca0a301b98397e371be8bcc23455c761cfb7e Mon Sep 17 00:00:00 2001 From: js_park Date: Fri, 26 Sep 2025 15:59:09 -0700 Subject: [PATCH 21/99] Top-p fixed Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 33 ++++++++++++++++++------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 67ef0b9ae230..786e51feae08 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -454,9 +454,15 @@ def _topk_topp_kernel(LOGITS, PROBS, NUM_SEARCH, K, P, B, masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) - num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-6) p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) + + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) + + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-6) # Check if any of the pivots are equal to k if p_pivots_sum_0 >= p: @@ -471,7 +477,7 @@ def _topk_topp_kernel(LOGITS, 
PROBS, NUM_SEARCH, K, P, B, if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-9: p_pivot = p_pivot_0 - if row_id == 1: + if row_id == 195: tl.store(DEBUG_PTR + num_iters * 21 + 0, p_pivots_sum_0) tl.store(DEBUG_PTR + num_iters * 21 + 1, p_pivot_0) tl.store(DEBUG_PTR + num_iters * 21 + 2, min_probs) @@ -480,8 +486,8 @@ def _topk_topp_kernel(LOGITS, PROBS, NUM_SEARCH, K, P, B, tl.store(DEBUG_PTR + num_iters * 21 + 5, max_range) tl.store(DEBUG_PTR + num_iters * 21 + 6, num_iters) tl.store(DEBUG_PTR + num_iters * 21 + 7, sum_exp_logits) - tl.store(DEBUG_PTR + num_iters * 21 + 8, p_pivot) - tl.store(DEBUG_PTR + num_iters * 21 + 9, tl.log(p_pivot * sum_exp_logits)) + tl.store(DEBUG_PTR + num_iters * 21 + 8, p_pivot_0) + tl.store(DEBUG_PTR + num_iters * 21 + 9, tl.log(p_pivot_0 * sum_exp_logits) + max_logit) tl.store(DEBUG_PTR + num_iters * 21 + 10, min_larger_0) tl.store(DEBUG_PTR + num_iters * 21 + 11, num_min_larger_0) # Subtract a small value to include the nearest smaller value @@ -496,12 +502,12 @@ def _topk_topp_kernel(LOGITS, PROBS, NUM_SEARCH, K, P, B, num_force_remove = tl.cast((p_pivots_sum_0 - p) / min_larger_0, tl.uint32) force_remove_logit = tl.log(min_larger_0 * sum_exp_logits) + max_logit - if row_id == 1: + if row_id == 195: tl.store(DEBUG_PTR + num_iters * 21 + 12, p_pivots_sum_0) tl.store(DEBUG_PTR + num_iters * 21 + 13, p_pivot) tl.store(DEBUG_PTR + num_iters * 21 + 14, num_iters) p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit - if row_id == 1: + if row_id == 195: tl.store(DEBUG_PTR + num_iters * 21 + 15, p_pivot) tl.store(DEBUG_PTR + num_iters * 21 + 16, p) tl.store(DEBUG_PTR + num_iters * 21 + 17, force_remove_logit) @@ -584,7 +590,9 @@ def apply_top_k_top_p( """ input_logits = logits.clone() - print(f"input_logits: {torch.sort(input_logits, descending=True).values[:12, :12]}") + if input_logits.shape[0] > 195: + print(f"input_logits: {torch.sort(input_logits[195], descending=True).values[:12]}") + original_logits = 
original_apply_top_k_top_p(input_logits, k, p) # original_probs = torch.softmax(input_logits, dim=-1) @@ -604,8 +612,15 @@ def apply_top_k_top_p( if not torch.allclose(logits, original_logits): print(r_str("Error: logits are not close")) - print(f"logits: {torch.sort(logits, descending=True).values[:12, :12]}") - print(f"original_logits: {torch.sort(original_logits, descending=True).values[:12, :12]}") + error_rows = torch.where(logits != original_logits)[0] + error_rows = torch.unique(error_rows) + num_error_rows = error_rows.shape[0] + print(f"num_error_rows: {num_error_rows} - {error_rows}") + row_to_show = 12 if num_error_rows > 12 else num_error_rows + print(f"logits: {torch.sort(logits[error_rows], descending=True).values[:row_to_show, :12]}") + print(f"original_logits: {torch.sort(original_logits[error_rows], descending=True).values[:row_to_show, :12]}") + + print("////////////////////////////////////////////////////////////") start_time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime(start_time)) out_dir = "./sampler_input_output" From 034e8024fc572941eb0b2dbc927bec9f33d110d6 Mon Sep 17 00:00:00 2001 From: js_park Date: Fri, 26 Sep 2025 16:44:19 -0700 Subject: [PATCH 22/99] Need to implement topp-only, topk and topk-topp works. 
Signed-off-by: js_park --- test.py | 2 +- vllm/v1/sample/ops/topk_topp_sampler.py | 300 ++++++++++++++++++------ 2 files changed, 225 insertions(+), 77 deletions(-) diff --git a/test.py b/test.py index 990e90313c16..00e4b0fbfd34 100644 --- a/test.py +++ b/test.py @@ -7,7 +7,7 @@ "The future of AI is", ] prompts = prompts * 64 -sampling_params = SamplingParams(temperature=0.8, top_k=6, top_p=0.9) +sampling_params = SamplingParams(temperature=0.8, top_k=31, top_p=0.97) llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct") # llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite") diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 786e51feae08..0aa30b9985c4 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -185,7 +185,6 @@ def _topk_kernel(LOGITS, PROBS, K, B, num_programs = tl.num_programs(0) for row_id in tl.range(pid, B, num_programs): k_pivot = -float('inf') - p_pivot = -float('inf') LOGITS_ROW = LOGITS + row_id * N PROBS_ROW = PROBS + pid * N @@ -209,7 +208,7 @@ def _topk_kernel(LOGITS, PROBS, K, B, sq_avg_logit = tl.sum(logits_blk * logits_blk) / N std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - outlier_pivot = avg_logit + 2.8 * std_logit + outlier_pivot = avg_logit + 2.5 * std_logit num_outliers = tl.zeros((), dtype=tl.uint32) # First pass: compute max and min logits and gather outliers for i in range(0,search_iters): @@ -280,22 +279,21 @@ def _topk_kernel(LOGITS, PROBS, K, B, k_pivot = k_pivot_0 # Third pass: Apply top-k mask - if k_pivot != -float('inf') or p_pivot != -float('inf'): + if k_pivot != -float('inf'): for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < N logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n) - mask = (logits_blk > k_pivot) & (logits_blk > p_pivot) + mask = (logits_blk > k_pivot) logits_blk = tl.where(mask, logits_blk, -float('inf')) tl.store(LOGITS_ROW + offs_n, logits_blk, 
mask=mask_n) @triton.jit -def _topk_topp_kernel(LOGITS, PROBS, NUM_SEARCH, K, P, B, +def _topp_kernel(LOGITS, PROBS, P, B, N: tl.constexpr, BLOCK_SIZE: tl.constexpr, - NUM_TILES: tl.constexpr, - DEBUG_PTR): + NUM_TILES: tl.constexpr): pid = tl.program_id(0) num_programs = tl.num_programs(0) for row_id in tl.range(pid, B, num_programs): @@ -310,42 +308,218 @@ def _topk_topp_kernel(LOGITS, PROBS, NUM_SEARCH, K, P, B, search_iters = NUM_TILES max_logit = -float('inf') + min_logit = float('inf') avg_logit = -float('inf') # The Pytorch version removes the earlier duplicates if there are more than one duplicates force_remove_logit = -float('inf') num_force_remove = tl.zeros((), dtype=tl.uint32) - k = tl.load(K + row_id) - if not (k == N): # All tokens are valid - min_logit = float('inf') + # Zeroth pass: Compute avg and std from a sample block + # May produce incorrect results if N < BLOCK_SIZE + offs = tl.arange(0, BLOCK_SIZE) + mask_n = offs < N + logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) + avg_logit = tl.sum(logits_blk) / N + sq_avg_logit = tl.sum(logits_blk * logits_blk) / N + std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) + + outlier_pivot = avg_logit + 2.5 * std_logit + num_outliers = tl.zeros((), dtype=tl.uint32) + # First pass: compute max and min logits and gather outliers + for i in range(0,search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=avg_logit) + + max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + outlier_mask = (logits_blk > outlier_pivot) & mask_n + num_blk_outliers = tl.sum(outlier_mask) + cumulative_pos = tl.cast(tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + num_outliers += num_blk_outliers + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(PROBS_ROW + write_pos, logits_blk, mask=outlier_mask) + + if num_outliers > k: 
+ max_range = max_logit + min_range = outlier_pivot + search_addr = PROBS_ROW + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast((num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) - # Zeroth pass: Compute avg and std from a sample block - # May produce incorrect results if N < BLOCK_SIZE - offs = tl.arange(0, BLOCK_SIZE) - mask_n = offs < N - logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) - avg_logit = tl.sum(logits_blk) / N - sq_avg_logit = tl.sum(logits_blk * logits_blk) / N - std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - outlier_pivot = avg_logit + 2.8 * std_logit - num_outliers = tl.zeros((), dtype=tl.uint32) - # First pass: compute max and min logits and gather outliers - for i in range(0,search_iters): + p = tl.load(P + row_id) + if p != 1.0: + second_max_logit = -float('inf') + max_probs = 0.0 + min_probs = 1.0 + sum_exp_logits = 0.0 + + # Third pass: Compute exp logits and sum + for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=avg_logit) + probs_blk = tl.load(search_addr + offs_n, mask=mask_n, other=-float('inf')) + probs_blk = tl.where(probs_blk > k_pivot, probs_blk, -float('inf')) + probs_blk = probs_blk - max_logit + probs_blk = tl.exp(probs_blk) + sum_exp_logits += tl.sum(probs_blk) + tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) - max_logit = tl.maximum(max_logit, tl.max(logits_blk)) - min_logit = tl.minimum(min_logit, tl.min(logits_blk)) - outlier_mask = (logits_blk > outlier_pivot) & mask_n - num_blk_outliers = tl.sum(outlier_mask) - cumulative_pos = tl.cast(tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) - num_outliers += num_blk_outliers - write_pos = tl.where(outlier_mask, cumulative_pos, -1) - tl.store(PROBS_ROW + write_pos, logits_blk, mask=outlier_mask) + second_max_mask = probs_blk * (probs_blk < max_probs) + second_max_logit = 
tl.maximum(second_max_logit, tl.max(second_max_mask)) + + # Fourth pass: Compute probs (softmax) + exp_avg = tl.exp(avg_logit) + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n) + probs_blk = probs_blk / sum_exp_logits + min_blk = tl.where(mask_n, probs_blk, 1.0) + min_probs = tl.minimum(min_probs, tl.min(min_blk)) + max_blk = tl.where(mask_n, probs_blk, 0.0) + max_probs = tl.maximum(max_probs, tl.max(max_blk)) + tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) + + max_range = max_probs + min_range = min_probs + + num_iters = 0 + p_pivots_sum_0 = 0.0 + min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + # Fifth passes: Search for p_pivot (2log_2(n)) + while p_pivot == -float('inf') and num_iters < 32: + p_pivot_0 = (max_range - min_range) * 1.0 / 2.0 + min_range + p_pivots_sum_0 = 0.0 + + min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) + + masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) + + p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) + + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-6) + + # Check if any of the pivots are equal to k + if p_pivots_sum_0 >= p: + if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: + p_pivot = p_pivot_0 + else: + min_range = p_pivot_0 + else: + max_range = p_pivot_0 + + num_iters += 1 + if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-9: + p_pivot = p_pivot_0 + + 
# At least one value should be greater than p_pivot + if p_pivot >= max_logit: + p_pivot = second_max_logit + elif num_min_larger_0 > 1: + # Force remove duplicates (p_pivot is made to include all duplicates if it falls on the duplicates) + num_force_remove = tl.cast((p_pivots_sum_0 - p) / min_larger_0, tl.uint32) + force_remove_logit = tl.log(min_larger_0 * sum_exp_logits) + max_logit + + p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit + + # Sixth pass: Apply mask + pivot = tl.maximum(k_pivot, p_pivot) + current_num_force_remove = tl.zeros((), dtype=tl.uint32) + if pivot != -float('inf'): + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < N + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) + + if force_remove_logit != -float('inf'): + force_remove_mask = tl.abs(logits_blk - force_remove_logit) < 1e-6 + force_remove_count = tl.cumsum(force_remove_mask) + current_num_force_remove + force_remove_count_mask = force_remove_count <= num_force_remove + force_remove_mask = force_remove_count_mask & force_remove_mask + logits_blk = tl.where(force_remove_mask, -float('inf'), logits_blk) + current_num_force_remove = tl.max(force_remove_count) + + logits_blk = tl.where(logits_blk > pivot, logits_blk, -float('inf')) + tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) + + +@triton.jit +def _topk_topp_kernel(LOGITS, PROBS, K, P, B, + N: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + NUM_TILES: tl.constexpr): + pid = tl.program_id(0) + num_programs = tl.num_programs(0) + for row_id in tl.range(pid, B, num_programs): + k_pivot = -float('inf') + p_pivot = -float('inf') + + LOGITS_ROW = LOGITS + row_id * N + PROBS_ROW = PROBS + row_id * N + + search_addr = LOGITS_ROW + search_range = N + search_iters = NUM_TILES + + max_logit = -float('inf') + min_logit = float('inf') + avg_logit = -float('inf') + + # The Pytorch version removes the earlier duplicates if there are more than one duplicates + 
force_remove_logit = -float('inf') + num_force_remove = tl.zeros((), dtype=tl.uint32) + + # Zeroth pass: Compute avg and std from a sample block + # May produce incorrect results if N < BLOCK_SIZE + offs = tl.arange(0, BLOCK_SIZE) + mask_n = offs < N + logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) + avg_logit = tl.sum(logits_blk) / N + sq_avg_logit = tl.sum(logits_blk * logits_blk) / N + std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) + + outlier_pivot = avg_logit + 2.5 * std_logit + num_outliers = tl.zeros((), dtype=tl.uint32) + # First pass: compute max and min logits and gather outliers + for i in range(0,search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=avg_logit) + + max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + outlier_mask = (logits_blk > outlier_pivot) & mask_n + num_blk_outliers = tl.sum(outlier_mask) + cumulative_pos = tl.cast(tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + num_outliers += num_blk_outliers + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(PROBS_ROW + write_pos, logits_blk, mask=outlier_mask) + + ############### START OF TOP-K CODE ############### + + ##### THERE IS NO DUPLICATE LOGIT MANAGEMENT FOR THIS TOP-K IMPLEMENTATION + ##### CURRENT IMPLEMENTATION INCLUDES ALL DUPLICATE LOGITS, WHICH MAY RETURN MORE THAN K LOGITS + ##### IF YOU NEED EXECATLY K LOGITS, PLEASE REFER TO THE TOP-P IMPLEMENTATION + ##### AND IMPLEMENT THE DUPLICATE LOGIT MANAGEMENT USING THE FORCE_REMOVE_LOGIT VARIABLE + k = tl.load(K + row_id) + if not (k == N): # All tokens are valid max_range = max_logit min_range = min_logit if num_outliers > k: @@ -397,7 +571,11 @@ def _topk_topp_kernel(LOGITS, PROBS, NUM_SEARCH, K, P, B, num_iters += 1 if num_iters >= 18: - k_pivot = k_pivot_0 + k_pivot = k_pivot_0 + + ############### END OF TOP-K 
CODE ############### + + ############### START OF TOP-P CODE ############### p = tl.load(P + row_id) if p != 1.0: @@ -440,6 +618,8 @@ def _topk_topp_kernel(LOGITS, PROBS, NUM_SEARCH, K, P, B, p_pivots_sum_0 = 0.0 min_larger_0 = 1.0 num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + # Fifth passes: Search for p_pivot (2log_2(n)) while p_pivot == -float('inf') and num_iters < 32: p_pivot_0 = (max_range - min_range) * 1.0 / 2.0 + min_range p_pivots_sum_0 = 0.0 @@ -477,23 +657,6 @@ def _topk_topp_kernel(LOGITS, PROBS, NUM_SEARCH, K, P, B, if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-9: p_pivot = p_pivot_0 - if row_id == 195: - tl.store(DEBUG_PTR + num_iters * 21 + 0, p_pivots_sum_0) - tl.store(DEBUG_PTR + num_iters * 21 + 1, p_pivot_0) - tl.store(DEBUG_PTR + num_iters * 21 + 2, min_probs) - tl.store(DEBUG_PTR + num_iters * 21 + 3, max_probs) - tl.store(DEBUG_PTR + num_iters * 21 + 4, min_range) - tl.store(DEBUG_PTR + num_iters * 21 + 5, max_range) - tl.store(DEBUG_PTR + num_iters * 21 + 6, num_iters) - tl.store(DEBUG_PTR + num_iters * 21 + 7, sum_exp_logits) - tl.store(DEBUG_PTR + num_iters * 21 + 8, p_pivot_0) - tl.store(DEBUG_PTR + num_iters * 21 + 9, tl.log(p_pivot_0 * sum_exp_logits) + max_logit) - tl.store(DEBUG_PTR + num_iters * 21 + 10, min_larger_0) - tl.store(DEBUG_PTR + num_iters * 21 + 11, num_min_larger_0) - # Subtract a small value to include the nearest smaller value - # If the nearest smaller value very small, it may cause numerical instability - - # At least one value should be greater than p_pivot if p_pivot >= max_logit: p_pivot = second_max_logit @@ -502,22 +665,9 @@ def _topk_topp_kernel(LOGITS, PROBS, NUM_SEARCH, K, P, B, num_force_remove = tl.cast((p_pivots_sum_0 - p) / min_larger_0, tl.uint32) force_remove_logit = tl.log(min_larger_0 * sum_exp_logits) + max_logit - if row_id == 195: - tl.store(DEBUG_PTR + num_iters * 21 + 12, p_pivots_sum_0) - tl.store(DEBUG_PTR + num_iters * 21 + 13, p_pivot) - tl.store(DEBUG_PTR + num_iters * 21 + 
14, num_iters) p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit - if row_id == 195: - tl.store(DEBUG_PTR + num_iters * 21 + 15, p_pivot) - tl.store(DEBUG_PTR + num_iters * 21 + 16, p) - tl.store(DEBUG_PTR + num_iters * 21 + 17, force_remove_logit) - tl.store(DEBUG_PTR + num_iters * 21 + 18, num_force_remove) - tl.store(DEBUG_PTR + num_iters * 21 + 19, num_min_larger_0) - tl.store(DEBUG_PTR + num_iters * 21 + 20, min_larger_0) - - - # Transform p_pivot into equivalent logit - # p_pivot = tl.log(p_pivot * sum_exp_logits) + + ############### END OF TOP-P CODE ############### # Sixth pass: Apply mask pivot = tl.maximum(k_pivot, p_pivot) @@ -554,16 +704,18 @@ def triton_apply_top_k_top_p( # f"k.shape: {k.shape if k is not None else None}, p.shape: {p.shape if p is not None else None}, " # f"batch_size: {batch_size}, vocab_size: {vocab_size}, BLOCK_SIZE: {BLOCK_SIZE}, NUM_TILES: {NUM_TILES}") # print(f"Input logits: {logits}") - if p is None and k is not None: - probs = torch.full((NUM_PROGRAMS, vocab_size), -float('inf'), device=logits.device) + + probs = torch.full_like(logits, -float('inf'), device=logits.device) + if k is not None and p is None: _topk_kernel[(NUM_PROGRAMS,)](logits, probs, k, batch_size, vocab_size, BLOCK_SIZE, NUM_TILES) - else: - probs = torch.full_like(logits, -float('inf'), device=logits.device) - num_search = torch.full((logits.shape[0],), vocab_size, device=logits.device) - _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, num_search, k, p, batch_size, - vocab_size, BLOCK_SIZE, NUM_TILES, debug) - print(f"debug: {debug}") + elif k is None and p is not None: + _topp_kernel[(NUM_PROGRAMS,)](logits, probs, p, batch_size, + vocab_size, BLOCK_SIZE, NUM_TILES) + elif k is not None and p is not None: + _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, + vocab_size, BLOCK_SIZE, NUM_TILES) + # print(f"debug: {debug}") # print(f"Output logits: {logits}") # print(f"Output probs: {probs}") return logits @@ -590,9 +742,6 @@ def 
apply_top_k_top_p( """ input_logits = logits.clone() - if input_logits.shape[0] > 195: - print(f"input_logits: {torch.sort(input_logits[195], descending=True).values[:12]}") - original_logits = original_apply_top_k_top_p(input_logits, k, p) # original_probs = torch.softmax(input_logits, dim=-1) @@ -620,7 +769,6 @@ def apply_top_k_top_p( print(f"logits: {torch.sort(logits[error_rows], descending=True).values[:row_to_show, :12]}") print(f"original_logits: {torch.sort(original_logits[error_rows], descending=True).values[:row_to_show, :12]}") - print("////////////////////////////////////////////////////////////") start_time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime(start_time)) out_dir = "./sampler_input_output" From 111458203f3ad53045a05ec612819d8e4bbaf7a0 Mon Sep 17 00:00:00 2001 From: js_park Date: Fri, 26 Sep 2025 23:26:39 -0700 Subject: [PATCH 23/99] Correctness tested for top-p. Duplication handling for top-k remaining. Signed-off-by: js_park --- test.py | 14 +- vllm/v1/sample/ops/topk_topp_sampler.py | 289 +++++++++++++----------- 2 files changed, 164 insertions(+), 139 deletions(-) diff --git a/test.py b/test.py index 00e4b0fbfd34..5165e302d19f 100644 --- a/test.py +++ b/test.py @@ -1,5 +1,12 @@ +import os from vllm import LLM, SamplingParams +# For V1: Turn off multiprocessing to make scheduling deterministic +os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" + +# Set a fixed seed for reproducibility +SEED = 42 + prompts = [ "Hello, my name is", "The president of the United States is", @@ -7,9 +14,12 @@ "The future of AI is", ] prompts = prompts * 64 -sampling_params = SamplingParams(temperature=0.8, top_k=31, top_p=0.97) -llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct") +# Add seed to sampling parameters for deterministic sampling +sampling_params = SamplingParams(temperature=0.8, top_p=0.7, seed=SEED) + +# Add seed to LLM initialization for global reproducibility +llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", seed=SEED) # 
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite") outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 0aa30b9985c4..0cce44be0b23 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -169,7 +169,7 @@ def original_apply_top_k_top_p( probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort) top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) # at least one - top_p_mask[:, -1] = False + top_p_mask[:12, -1] = False logits_sort.masked_fill_(top_p_mask, -float("inf")) # Re-sort the probabilities. @@ -178,25 +178,31 @@ def original_apply_top_k_top_p( @triton.jit def _topk_kernel(LOGITS, PROBS, K, B, - N: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - NUM_TILES: tl.constexpr): + SIGMA: tl.constexpr, + N: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + NUM_TILES: tl.constexpr): pid = tl.program_id(0) num_programs = tl.num_programs(0) for row_id in tl.range(pid, B, num_programs): - k_pivot = -float('inf') + k = tl.load(K + row_id) + if not (k == N): # All tokens are valid - LOGITS_ROW = LOGITS + row_id * N - PROBS_ROW = PROBS + pid * N + # THERE IS NO DUPLICATE LOGIT MANAGEMENT FOR THIS TOP-K IMPLEMENTATION + # CURRENT IMPLEMENTATION INCLUDES ALL DUPLICATE LOGITS, WHICH MAY RETURN MORE THAN K LOGITS + # IF YOU NEED EXACTLY K LOGITS, PLEASE REFER TO THE TOP-P IMPLEMENTATION + # AND IMPLEMENT THE DUPLICATE LOGIT MANAGEMENT USING THE FORCE_REMOVE_LOGIT VARIABLE - search_addr = LOGITS_ROW - search_range = N - search_iters = NUM_TILES + k_pivot = -float('inf') - max_logit = -float('inf') + LOGITS_ROW = LOGITS + row_id * N + PROBS_ROW = PROBS + pid * N - k = tl.load(K + row_id) - if not (k == N): # All tokens are valid + search_addr = LOGITS_ROW + search_range = N + search_iters = NUM_TILES + + max_logit = -float('inf') min_logit = float('inf') # Zeroth pass: Compute avg and std from a sample block @@ -208,7 +214,7 @@ def _topk_kernel(LOGITS, 
PROBS, K, B, sq_avg_logit = tl.sum(logits_blk * logits_blk) / N std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - outlier_pivot = avg_logit + 2.5 * std_logit + outlier_pivot = avg_logit + SIGMA * std_logit num_outliers = tl.zeros((), dtype=tl.uint32) # First pass: compute max and min logits and gather outliers for i in range(0,search_iters): @@ -278,113 +284,109 @@ def _topk_kernel(LOGITS, PROBS, K, B, if num_iters >= 18: k_pivot = k_pivot_0 - # Third pass: Apply top-k mask - if k_pivot != -float('inf'): - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < N - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n) - mask = (logits_blk > k_pivot) - logits_blk = tl.where(mask, logits_blk, -float('inf')) - tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) + # Third pass: Apply top-k mask + if k_pivot != -float('inf'): + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < N + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n) + mask = (logits_blk > k_pivot) + logits_blk = tl.where(mask, logits_blk, -float('inf')) + tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) @triton.jit -def _topp_kernel(LOGITS, PROBS, P, B, - N: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - NUM_TILES: tl.constexpr): +def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, + SIGMA: tl.constexpr, + N: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + NUM_TILES: tl.constexpr): pid = tl.program_id(0) num_programs = tl.num_programs(0) for row_id in tl.range(pid, B, num_programs): - k_pivot = -float('inf') - p_pivot = -float('inf') - - LOGITS_ROW = LOGITS + row_id * N - PROBS_ROW = PROBS + row_id * N - - search_addr = LOGITS_ROW - search_range = N - search_iters = NUM_TILES - - max_logit = -float('inf') - min_logit = float('inf') - avg_logit = -float('inf') + p = tl.load(P + row_id) + if p != 1.0: # All tokens are valid - # The Pytorch version removes the earlier duplicates if there are more 
than one duplicates - force_remove_logit = -float('inf') - num_force_remove = tl.zeros((), dtype=tl.uint32) + p_pivot = -float('inf') - # Zeroth pass: Compute avg and std from a sample block - # May produce incorrect results if N < BLOCK_SIZE - offs = tl.arange(0, BLOCK_SIZE) - mask_n = offs < N - logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) - avg_logit = tl.sum(logits_blk) / N - sq_avg_logit = tl.sum(logits_blk * logits_blk) / N - std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) + LOGITS_ROW = LOGITS + row_id * N + PROBS_ROW = PROBS + row_id * N + PROBS_2_ROW = PROBS_2 + row_id * N - outlier_pivot = avg_logit + 2.5 * std_logit - num_outliers = tl.zeros((), dtype=tl.uint32) - # First pass: compute max and min logits and gather outliers - for i in range(0,search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=avg_logit) + search_addr = PROBS_ROW + search_range = N + search_iters = NUM_TILES - max_logit = tl.maximum(max_logit, tl.max(logits_blk)) - min_logit = tl.minimum(min_logit, tl.min(logits_blk)) - outlier_mask = (logits_blk > outlier_pivot) & mask_n - num_blk_outliers = tl.sum(outlier_mask) - cumulative_pos = tl.cast(tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) - num_outliers += num_blk_outliers - write_pos = tl.where(outlier_mask, cumulative_pos, -1) - tl.store(PROBS_ROW + write_pos, logits_blk, mask=outlier_mask) + max_logit = -float('inf') + min_logit = float('inf') - if num_outliers > k: - max_range = max_logit - min_range = outlier_pivot - search_addr = PROBS_ROW - search_range = tl.cast(num_outliers, tl.int32) - search_iters = tl.cast((num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) + # The Pytorch version removes the earlier duplicates if there are more than one duplicates + force_remove_logit = -float('inf') + num_force_remove = tl.zeros((), dtype=tl.uint32) + + # Zeroth pass: Compute avg and std from 
a sample block + # May produce incorrect results if N < BLOCK_SIZE OR all logits are the same + offs = tl.arange(0, BLOCK_SIZE) + mask_n = offs < N + logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) + avg_logit = tl.sum(logits_blk) / N + sq_avg_logit = tl.sum(logits_blk * logits_blk) / N + std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) + outlier_pivot = avg_logit + SIGMA * std_logit + num_outliers = tl.zeros((), dtype=tl.uint32) + sum_outlier_probs = 0.0 - p = tl.load(P + row_id) - if p != 1.0: - second_max_logit = -float('inf') - max_probs = 0.0 - min_probs = 1.0 sum_exp_logits = 0.0 - - # Third pass: Compute exp logits and sum + + # First pass: compute max and min logits + for i in range(0,search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=avg_logit) + max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + + # Second pass: Calculate exp logits and sum for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - probs_blk = tl.load(search_addr + offs_n, mask=mask_n, other=-float('inf')) - probs_blk = tl.where(probs_blk > k_pivot, probs_blk, -float('inf')) + + probs_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) probs_blk = probs_blk - max_logit probs_blk = tl.exp(probs_blk) sum_exp_logits += tl.sum(probs_blk) tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) - second_max_mask = probs_blk * (probs_blk < max_probs) - second_max_logit = tl.maximum(second_max_logit, tl.max(second_max_mask)) + outlier_prob = tl.exp(outlier_pivot - max_logit) / sum_exp_logits - # Fourth pass: Compute probs (softmax) - exp_avg = tl.exp(avg_logit) + # Third pass: Calculate probs and get outliers for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - 
probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n) + + probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) probs_blk = probs_blk / sum_exp_logits - min_blk = tl.where(mask_n, probs_blk, 1.0) - min_probs = tl.minimum(min_probs, tl.min(min_blk)) - max_blk = tl.where(mask_n, probs_blk, 0.0) - max_probs = tl.maximum(max_probs, tl.max(max_blk)) tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) + + outlier_mask = (probs_blk > outlier_prob) & mask_n + sum_outlier_probs += tl.sum(outlier_mask * probs_blk) + num_blk_outliers = tl.sum(outlier_mask) + cumulative_pos = tl.cast(tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + num_outliers += num_blk_outliers + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(PROBS_2_ROW + write_pos, probs_blk, mask=outlier_mask) - max_range = max_probs - min_range = min_probs + max_range = tl.exp(max_logit - max_logit) / sum_exp_logits + min_range = tl.exp(min_logit - max_logit) / sum_exp_logits + + if sum_outlier_probs > p: + min_range = outlier_prob + search_addr = PROBS_2_ROW + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast((num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) + + second_max_logit = -float('inf') num_iters = 0 p_pivots_sum_0 = 0.0 @@ -402,7 +404,7 @@ def _topp_kernel(LOGITS, PROBS, P, B, for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) + probs_blk = tl.load(search_addr + offs_n, mask=mask_n, other=0.0) masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) @@ -412,9 +414,9 @@ def _topp_kernel(LOGITS, PROBS, P, B, for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) + probs_blk = tl.load(search_addr + offs_n, mask=mask_n, 
other=0.0) - num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-6) + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-7) # Check if any of the pivots are equal to k if p_pivots_sum_0 >= p: @@ -426,7 +428,7 @@ def _topp_kernel(LOGITS, PROBS, P, B, max_range = p_pivot_0 num_iters += 1 - if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-9: + if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-8: p_pivot = p_pivot_0 # At least one value should be greater than p_pivot @@ -438,30 +440,30 @@ def _topp_kernel(LOGITS, PROBS, P, B, force_remove_logit = tl.log(min_larger_0 * sum_exp_logits) + max_logit p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit - - # Sixth pass: Apply mask - pivot = tl.maximum(k_pivot, p_pivot) - current_num_force_remove = tl.zeros((), dtype=tl.uint32) - if pivot != -float('inf'): - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < N - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) - if force_remove_logit != -float('inf'): - force_remove_mask = tl.abs(logits_blk - force_remove_logit) < 1e-6 - force_remove_count = tl.cumsum(force_remove_mask) + current_num_force_remove - force_remove_count_mask = force_remove_count <= num_force_remove - force_remove_mask = force_remove_count_mask & force_remove_mask - logits_blk = tl.where(force_remove_mask, -float('inf'), logits_blk) - current_num_force_remove = tl.max(force_remove_count) - - logits_blk = tl.where(logits_blk > pivot, logits_blk, -float('inf')) - tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) + # Sixth pass: Apply mask + current_num_force_remove = tl.zeros((), dtype=tl.uint32) + if p_pivot != -float('inf'): + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < N + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) + + if force_remove_logit != -float('inf'): + force_remove_mask = 
tl.abs(logits_blk - force_remove_logit) < 1e-5 + force_remove_count = tl.cumsum(force_remove_mask) + current_num_force_remove + force_remove_count_mask = force_remove_count <= num_force_remove + force_remove_mask = force_remove_count_mask & force_remove_mask + logits_blk = tl.where(force_remove_mask, -float('inf'), logits_blk) + current_num_force_remove = tl.max(force_remove_count) + + logits_blk = tl.where(logits_blk > p_pivot, logits_blk, -float('inf')) + tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) @triton.jit def _topk_topp_kernel(LOGITS, PROBS, K, P, B, + SIGMA: tl.constexpr, N: tl.constexpr, BLOCK_SIZE: tl.constexpr, NUM_TILES: tl.constexpr): @@ -495,7 +497,7 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, sq_avg_logit = tl.sum(logits_blk * logits_blk) / N std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - outlier_pivot = avg_logit + 2.5 * std_logit + outlier_pivot = avg_logit + SIGMA * std_logit num_outliers = tl.zeros((), dtype=tl.uint32) # First pass: compute max and min logits and gather outliers for i in range(0,search_iters): @@ -513,13 +515,14 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, tl.store(PROBS_ROW + write_pos, logits_blk, mask=outlier_mask) ############### START OF TOP-K CODE ############### - - ##### THERE IS NO DUPLICATE LOGIT MANAGEMENT FOR THIS TOP-K IMPLEMENTATION - ##### CURRENT IMPLEMENTATION INCLUDES ALL DUPLICATE LOGITS, WHICH MAY RETURN MORE THAN K LOGITS - ##### IF YOU NEED EXECATLY K LOGITS, PLEASE REFER TO THE TOP-P IMPLEMENTATION - ##### AND IMPLEMENT THE DUPLICATE LOGIT MANAGEMENT USING THE FORCE_REMOVE_LOGIT VARIABLE k = tl.load(K + row_id) if not (k == N): # All tokens are valid + + # THERE IS NO DUPLICATE LOGIT MANAGEMENT FOR THIS TOP-K IMPLEMENTATION + # CURRENT IMPLEMENTATION INCLUDES ALL DUPLICATE LOGITS, WHICH MAY RETURN MORE THAN K LOGITS + # IF YOU NEED EXACTLY K LOGITS, PLEASE REFER TO THE TOP-P IMPLEMENTATION + # AND IMPLEMENT THE DUPLICATE LOGIT MANAGEMENT USING THE FORCE_REMOVE_LOGIT VARIABLE 
+ max_range = max_logit min_range = min_logit if num_outliers > k: @@ -578,7 +581,8 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, ############### START OF TOP-P CODE ############### p = tl.load(P + row_id) - if p != 1.0: + if p != 1.0: # All tokens are valid + second_max_logit = -float('inf') max_probs = 0.0 min_probs = 1.0 @@ -642,7 +646,7 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, mask_n = offs_n < search_range probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) - num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-6) + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-7) # Check if any of the pivots are equal to k if p_pivots_sum_0 >= p: @@ -654,7 +658,7 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, max_range = p_pivot_0 num_iters += 1 - if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-9: + if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-8: p_pivot = p_pivot_0 # At least one value should be greater than p_pivot @@ -679,7 +683,8 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) if force_remove_logit != -float('inf'): - force_remove_mask = tl.abs(logits_blk - force_remove_logit) < 1e-6 + # Force remove duplicates + force_remove_mask = tl.abs(logits_blk - force_remove_logit) < 1e-5 force_remove_count = tl.cumsum(force_remove_mask) + current_num_force_remove force_remove_count_mask = force_remove_count <= num_force_remove force_remove_mask = force_remove_count_mask & force_remove_mask @@ -699,25 +704,30 @@ def triton_apply_top_k_top_p( BLOCK_SIZE = 4096 NUM_PROGRAMS = 128 NUM_TILES = (vocab_size + BLOCK_SIZE - 1) // BLOCK_SIZE - debug = torch.full((20, 21), -float('inf'), device=logits.device) + SIGMA = 2.5 + debug = torch.full((4, 20), -float('inf'), device=logits.device) # print(b_str("Launch params:") + f"logits.shape: {logits.shape}, probs.shape: {probs.shape}, " # f"k.shape: {k.shape if k is not None else None}, 
p.shape: {p.shape if p is not None else None}, " # f"batch_size: {batch_size}, vocab_size: {vocab_size}, BLOCK_SIZE: {BLOCK_SIZE}, NUM_TILES: {NUM_TILES}") # print(f"Input logits: {logits}") - probs = torch.full_like(logits, -float('inf'), device=logits.device) + if k is not None and p is None: + probs = torch.full((NUM_PROGRAMS, vocab_size), -float('inf'), device=logits.device) _topk_kernel[(NUM_PROGRAMS,)](logits, probs, k, batch_size, - vocab_size, BLOCK_SIZE, NUM_TILES) + SIGMA, vocab_size, BLOCK_SIZE, NUM_TILES) elif k is None and p is not None: - _topp_kernel[(NUM_PROGRAMS,)](logits, probs, p, batch_size, - vocab_size, BLOCK_SIZE, NUM_TILES) + probs = torch.full_like(logits, -float('inf'), device=logits.device) + probs_2 = torch.full_like(logits, -float('inf'), device=logits.device) + _topp_kernel[(NUM_PROGRAMS,)](logits, probs, probs_2, p, batch_size, + SIGMA, vocab_size, BLOCK_SIZE, NUM_TILES) elif k is not None and p is not None: + probs = torch.full_like(logits, -float('inf'), device=logits.device) _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, - vocab_size, BLOCK_SIZE, NUM_TILES) + SIGMA, vocab_size, BLOCK_SIZE, NUM_TILES) # print(f"debug: {debug}") # print(f"Output logits: {logits}") - # print(f"Output probs: {probs}") + # print(f"Output probs: {probs[:, :12]}") return logits @torch.compile @@ -741,9 +751,14 @@ def apply_top_k_top_p( The logits tensor may be updated in-place. 
""" + # logits = torch.full_like(logits, -1.0, device=logits.device) + # logits[:12 ,:10] = torch.tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1], device=logits.device) + # print(f"logits: {logits[:12, :12]}") + input_logits = logits.clone() original_logits = original_apply_top_k_top_p(input_logits, k, p) - # original_probs = torch.softmax(input_logits, dim=-1) + original_probs = torch.softmax(input_logits, dim=-1) + # print(f"original_probs: {original_probs[:12, :12]}") batch_size, vocab_size = logits.shape print(g_str("apply_top_k_top_p") + f" logits.shape: {batch_size} x {vocab_size}, p is None: {p is None}, k is None: {k is None}") @@ -765,11 +780,11 @@ def apply_top_k_top_p( error_rows = torch.unique(error_rows) num_error_rows = error_rows.shape[0] print(f"num_error_rows: {num_error_rows} - {error_rows}") - row_to_show = 12 if num_error_rows > 12 else num_error_rows - print(f"logits: {torch.sort(logits[error_rows], descending=True).values[:row_to_show, :12]}") - print(f"original_logits: {torch.sort(original_logits[error_rows], descending=True).values[:row_to_show, :12]}") - + row_to_show = 4 if num_error_rows > 4 else num_error_rows + print(f"logits: {torch.sort(logits[error_rows], descending=True).values[:row_to_show, :50]}") + print(f"original_logits: {torch.sort(original_logits[error_rows], descending=True).values[:row_to_show, :50]}") + print ("/////////////////////////////////////////") start_time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime(start_time)) out_dir = "./sampler_input_output" os.makedirs(out_dir, exist_ok=True) From 56a615f0041370fa6fb7d4d24edb692bd84b9f76 Mon Sep 17 00:00:00 2001 From: js_park Date: Fri, 26 Sep 2025 23:52:42 -0700 Subject: [PATCH 24/99] Deeseep tests Signed-off-by: js_park --- test.py | 4 +-- vllm/v1/sample/ops/topk_topp_sampler.py | 36 +++++++++---------------- 2 files changed, 15 insertions(+), 25 deletions(-) diff --git a/test.py b/test.py index 5165e302d19f..1ef0b2fab45f 100644 --- a/test.py +++ b/test.py @@ -19,8 
+19,8 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.7, seed=SEED) # Add seed to LLM initialization for global reproducibility -llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", seed=SEED) -# llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite") +# llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", seed=SEED) +llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite") outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 0cce44be0b23..98194fce8e7c 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -189,7 +189,9 @@ def _topk_kernel(LOGITS, PROBS, K, B, if not (k == N): # All tokens are valid # THERE IS NO DUPLICATE LOGIT MANAGEMENT FOR THIS TOP-K IMPLEMENTATION - # CURRENT IMPLEMENTATION INCLUDES ALL DUPLICATE LOGITS, WHICH MAY RETURN MORE THAN K LOGITS + # CURRENT IMPLEMENTATION INCLUDES ALL DUPLICATE LOGITS, WHICH MAY RETURN MORE THAN K LOGITS, + # FOLLOWING THE CURRENT PYTHON BASED IMPLEMENTATION in apply_top_k_only(), WHICH ALSO + # INCLUDES ALL DUPLICATE LOGITS. # IF YOU NEED EXACTLY K LOGITS, PLEASE REFER TO THE TOP-P IMPLEMENTATION # AND IMPLEMENT THE DUPLICATE LOGIT MANAGEMENT USING THE FORCE_REMOVE_LOGIT VARIABLE @@ -519,7 +521,9 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, if not (k == N): # All tokens are valid # THERE IS NO DUPLICATE LOGIT MANAGEMENT FOR THIS TOP-K IMPLEMENTATION - # CURRENT IMPLEMENTATION INCLUDES ALL DUPLICATE LOGITS, WHICH MAY RETURN MORE THAN K LOGITS + # CURRENT IMPLEMENTATION INCLUDES ALL DUPLICATE LOGITS, WHICH MAY RETURN MORE THAN K LOGITS, + # FOLLOWING THE CURRENT PYTHON BASED IMPLEMENTATION in apply_top_k_only(), WHICH ALSO + # INCLUDES ALL DUPLICATE LOGITS. 
# IF YOU NEED EXACTLY K LOGITS, PLEASE REFER TO THE TOP-P IMPLEMENTATION # AND IMPLEMENT THE DUPLICATE LOGIT MANAGEMENT USING THE FORCE_REMOVE_LOGIT VARIABLE @@ -700,18 +704,13 @@ def triton_apply_top_k_top_p( k: Optional[torch.Tensor], p: Optional[torch.Tensor], ) -> torch.Tensor: + batch_size, vocab_size = logits.shape BLOCK_SIZE = 4096 NUM_PROGRAMS = 128 NUM_TILES = (vocab_size + BLOCK_SIZE - 1) // BLOCK_SIZE SIGMA = 2.5 - debug = torch.full((4, 20), -float('inf'), device=logits.device) - # print(b_str("Launch params:") + f"logits.shape: {logits.shape}, probs.shape: {probs.shape}, " - # f"k.shape: {k.shape if k is not None else None}, p.shape: {p.shape if p is not None else None}, " - # f"batch_size: {batch_size}, vocab_size: {vocab_size}, BLOCK_SIZE: {BLOCK_SIZE}, NUM_TILES: {NUM_TILES}") - # print(f"Input logits: {logits}") - if k is not None and p is None: probs = torch.full((NUM_PROGRAMS, vocab_size), -float('inf'), device=logits.device) _topk_kernel[(NUM_PROGRAMS,)](logits, probs, k, batch_size, @@ -725,9 +724,7 @@ def triton_apply_top_k_top_p( probs = torch.full_like(logits, -float('inf'), device=logits.device) _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, SIGMA, vocab_size, BLOCK_SIZE, NUM_TILES) - # print(f"debug: {debug}") - # print(f"Output logits: {logits}") - # print(f"Output probs: {probs[:, :12]}") + return logits @torch.compile @@ -751,14 +748,8 @@ def apply_top_k_top_p( The logits tensor may be updated in-place. 
""" - # logits = torch.full_like(logits, -1.0, device=logits.device) - # logits[:12 ,:10] = torch.tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1], device=logits.device) - # print(f"logits: {logits[:12, :12]}") - input_logits = logits.clone() original_logits = original_apply_top_k_top_p(input_logits, k, p) - original_probs = torch.softmax(input_logits, dim=-1) - # print(f"original_probs: {original_probs[:12, :12]}") batch_size, vocab_size = logits.shape print(g_str("apply_top_k_top_p") + f" logits.shape: {batch_size} x {vocab_size}, p is None: {p is None}, k is None: {k is None}") @@ -784,12 +775,11 @@ def apply_top_k_top_p( print(f"logits: {torch.sort(logits[error_rows], descending=True).values[:row_to_show, :50]}") print(f"original_logits: {torch.sort(original_logits[error_rows], descending=True).values[:row_to_show, :50]}") - print ("/////////////////////////////////////////") - start_time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime(start_time)) - out_dir = "./sampler_input_output" - os.makedirs(out_dir, exist_ok=True) - out_path = f"{out_dir}/llama8b_{start_time_str}.pt" - torch.save({"input_logits": input_logits, "p": p, "k": k, "output_logits": logits}, out_path) + # start_time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime(start_time)) + # out_dir = "./sampler_input_output" + # os.makedirs(out_dir, exist_ok=True) + # out_path = f"{out_dir}/llama8b_{start_time_str}.pt" + # torch.save({"input_logits": input_logits, "p": p, "k": k, "output_logits": logits}, out_path) return logits From 6bea89cd0a4196162c0dd5954a3bae9bb7a0fd6f Mon Sep 17 00:00:00 2001 From: js_park Date: Sat, 27 Sep 2025 00:57:28 -0700 Subject: [PATCH 25/99] Added env var VLLM_USE_TRITON_SAMPLER and automated test Signed-off-by: js_park --- test.py | 1 + test_triton_topk_topp.py | 123 +++++++++++++++++ vllm/envs.py | 6 + vllm/v1/sample/ops/topk_topp_sampler.py | 172 +++++++++--------------- 4 files changed, 194 insertions(+), 108 deletions(-) create mode 100644 
test_triton_topk_topp.py diff --git a/test.py b/test.py index 1ef0b2fab45f..02ab987f53a0 100644 --- a/test.py +++ b/test.py @@ -3,6 +3,7 @@ # For V1: Turn off multiprocessing to make scheduling deterministic os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" +os.environ["VLLM_USE_TRITON_SAMPLER"] = "1" # Set a fixed seed for reproducibility SEED = 42 diff --git a/test_triton_topk_topp.py b/test_triton_topk_topp.py new file mode 100644 index 000000000000..04a0cf6a4fb1 --- /dev/null +++ b/test_triton_topk_topp.py @@ -0,0 +1,123 @@ +import torch +import time +import re +from itertools import product +from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p, apply_top_k_top_p_triton +from datetime import datetime + +def g_str(s): + return "\033[32m" + s + "\033[0m" +def r_str(s): + return "\033[31m" + s + "\033[0m" +def y_str(s): + return "\033[33m" + s + "\033[0m" +def b_str(s): + return "\033[34m" + s + "\033[0m" + +def print_to_log(s, log_file): + print(s) + # Remove the color codes + s = re.sub(r"\033\[[0-9;]*m", "", s) + with open(log_file, "a") as f: + f.write(s + "\n") + +def test_accuracy(logits, k, p): + input_logits = logits.clone() + original_logits = apply_top_k_top_p(input_logits, k, p) + logits = apply_top_k_top_p_triton(logits, k, p) + + torch.cuda.synchronize() + is_correct = torch.allclose(logits, original_logits) + + if not is_correct: + print_to_log(r_str("Error: logits are not close"), log_file) + error_rows = torch.where(logits != original_logits)[0] + error_rows = torch.unique(error_rows) + num_error_rows = error_rows.shape[0] + print_to_log(f"num_error_rows: {num_error_rows} - {error_rows}", log_file) + row_to_show = 12 if num_error_rows > 12 else num_error_rows + print_to_log(f"logits: {torch.sort(logits[error_rows], descending=True).values[:row_to_show, :50]}", log_file) + print_to_log(f"original_logits: {torch.sort(original_logits[error_rows], descending=True).values[:row_to_show, :50]}", log_file) + + return is_correct + +def 
test_time(logits, k, p, num_runs=256): + # We must clone the logits for each run to avoid modifying the original logits + input_logits_torch = [logits.clone() for _ in range(num_runs)] + input_logits_triton = [logits.clone() for _ in range(num_runs)] + + torch.cuda.synchronize() + start_time = time.time() + for _ in range(num_runs): + input_logits_torch[_] = apply_top_k_top_p(input_logits_torch[_], k, p) + torch.cuda.synchronize() + torch_time_taken = (time.time() - start_time) / num_runs + + torch.cuda.synchronize() + start_time = time.time() + for _ in range(num_runs): + input_logits_triton[_] = apply_top_k_top_p_triton(input_logits_triton[_], k, p) + torch.cuda.synchronize() + triton_time_taken = (time.time() - start_time) / num_runs + + return torch_time_taken, triton_time_taken + +if __name__ == "__main__": + date_str = datetime.now().strftime("%Y%m%d_%H%M%S") + batch_size_list = [2**i for i in range(0, 11)] # 1 to 1024 + vocab_size_list = [2**i for i in range(8, 19)] # 256 to 262144 + p_list = [None, "RAND"] + [0.1 * i for i in range(1, 10)] + k_list = [None, "RAND"] + [i for i in range(1, 10)] + [i for i in range(20, 210, 30)] + log_file = f"triton_topk_topp_test_{date_str}.log" + csv_file = f"triton_topk_topp_test_{date_str}.csv" + + print_to_log(y_str(f"Testing TopKTopPSampler with Triton"), log_file) + print_to_log(y_str(f"batch_size_list:") + f"{batch_size_list}", log_file) + print_to_log(y_str(f"vocab_size_list:") + f"{vocab_size_list}", log_file) + print_to_log(y_str(f"p_list:") + f"{p_list}", log_file) + print_to_log(y_str(f"k_list:") + f"{k_list}", log_file) + print_to_log(y_str(f"log_file:") + f"{log_file}", log_file) + print_to_log(y_str(f"csv_file:") + f"{csv_file}", log_file) + + with open(csv_file, "w") as f: + f.write("dist_generator,batch_size,vocab_size,p,k,is_correct,torch_time_taken,triton_time_taken,speedup\n") + + for batch_size, vocab_size, p, k in product(batch_size_list, vocab_size_list, p_list, k_list): + if p == "RAND" and k == 
"RAND": + continue + + logits_rand = torch.rand(batch_size, vocab_size, device="cuda") + logits_randn = torch.randn(batch_size, vocab_size, device="cuda") + logits_list = [("RAND", logits_rand), ("RANDN", logits_randn)] + + if p == "RAND": + p_tensor = torch.rand((batch_size,), device="cuda") * 0.95 + 0.05 + elif p is not None: + p_tensor = torch.full((batch_size,), p, device="cuda") + else: + p_tensor = None + + if k == "RAND": + k_tensor = torch.randint(1, vocab_size, (batch_size,), device="cuda") + elif k is not None: + k_tensor = torch.full((batch_size,), k, device="cuda") + else: + k_tensor = None + + for dist_generator, logits in logits_list: + print_to_log(y_str(f"--------------------------------"), log_file) + print_to_log(g_str(f"Testing ") + f"{dist_generator}" + + y_str(f" with batch_size: ") + f"{batch_size}" + + y_str(f" vocab_size: ") + f"{vocab_size}" + + y_str(f" p: ") + f"{p}" + + y_str(f" k: ") + f"{k}", log_file) + is_correct = test_accuracy(logits, k_tensor, p_tensor) + if not is_correct: + print_to_log(f"Error: logits are not close for batch_size: {batch_size}, vocab_size: {vocab_size}, dist_generator: {dist_generator}, p: {p}, k: {k}", log_file) + torch_time_taken, triton_time_taken = test_time(logits, k_tensor, p_tensor) + print_to_log(b_str(f"torch_time_taken: ") + f"{torch_time_taken}", log_file) + print_to_log(b_str(f"triton_time_taken: ") + f"{triton_time_taken}", log_file) + print_to_log(g_str(f"Triton Speedup over Torch: ") + f"{torch_time_taken / triton_time_taken:.8f}x", log_file) + with open(csv_file, "a") as f: + f.write(f"{dist_generator},{batch_size},{vocab_size},{p},{k},{is_correct},{torch_time_taken},{triton_time_taken},{torch_time_taken / triton_time_taken:.8f}\n") + print_to_log(y_str(f"--------------------------------\n"), log_file) \ No newline at end of file diff --git a/vllm/envs.py b/vllm/envs.py index 3991a789d80f..b9e69150a293 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -45,6 +45,7 @@ VLLM_TRACE_FUNCTION: int = 0 
VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None + VLLM_USE_TRITON_SAMPLER: bool = False VLLM_PP_LAYER_PARTITION: Optional[str] = None VLLM_CPU_KVCACHE_SPACE: Optional[int] = 0 VLLM_CPU_OMP_THREADS_BIND: str = "" @@ -546,6 +547,10 @@ def get_vllm_port() -> Optional[int]: lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"])) if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None, + # If set, vllm will use triton sampler + "VLLM_USE_TRITON_SAMPLER": + lambda: bool(int(os.environ.get("VLLM_USE_TRITON_SAMPLER", "0"))), + # Pipeline stage partition strategy "VLLM_PP_LAYER_PARTITION": lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), @@ -1391,6 +1396,7 @@ def compute_hash() -> str: "VLLM_USE_AITER_UNIFIED_ATTENTION", "VLLM_ATTENTION_BACKEND", "VLLM_USE_FLASHINFER_SAMPLER", + "VLLM_USE_TRITON_SAMPLER", "VLLM_DISABLED_KERNELS", "VLLM_USE_DEEP_GEMM", "VLLM_USE_DEEP_GEMM_E8M0", diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 98194fce8e7c..0901225c8c4e 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -26,20 +26,6 @@ except ImportError: is_flashinfer_available = False -def g_str(s): - return "\033[32m" + s + "\033[0m" -def r_str(s): - return "\033[31m" + s + "\033[0m" -def b_str(s): - return "\033[34m" + s + "\033[0m" -def y_str(s): - return "\033[33m" + s + "\033[0m" -def c_str(s): - return "\033[36m" + s + "\033[0m" -def m_str(s): - return "\033[35m" + s + "\033[0m" - - class TopKTopPSampler(nn.Module): """ Module that performs optional top-k and top-p filtering followed by @@ -74,6 +60,10 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None: logger.info_once( "Using FlashInfer for top-p & top-k sampling.") self.forward = self.forward_cuda + elif envs.VLLM_USE_TRITON_SAMPLER is not False: + logger.info_once( + "Using Triton for top-p & top-k sampling.") + self.forward = self.forward_triton 
else: logger.warning_once( "FlashInfer is available, but it is not enabled. " @@ -82,15 +72,21 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None: "please set VLLM_USE_FLASHINFER_SAMPLER=1.") self.forward = self.forward_native else: - logger.warning_once( - "FlashInfer is not available. Falling back to the PyTorch-" - "native implementation of top-p & top-k sampling. For the " - "best performance, please install FlashInfer.") - self.forward = self.forward_native + if envs.VLLM_USE_TRITON_SAMPLER is not False: + logger.info_once( + "Using Triton for top-p & top-k sampling.") + self.forward = self.forward_triton + else: + logger.warning_once( + "FlashInfer is not available. Falling back to the PyTorch-" + "native implementation of top-p & top-k sampling. For the " + "best performance, please install FlashInfer.") + self.forward = self.forward_native else: self.forward = self.forward_native self.apply_top_k_top_p = apply_top_k_top_p + self.apply_top_k_top_p_triton = apply_top_k_top_p_triton def forward_native( self, @@ -113,6 +109,22 @@ def forward_native( probs = logits.softmax(dim=-1, dtype=torch.float32) return random_sample(probs, generators), logits_to_return + def forward_triton( + self, + logits: torch.Tensor, + generators: dict[int, torch.Generator], + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + logits = self.apply_top_k_top_p_triton(logits, k, p) + logits_to_return = None + if self.logprobs_mode == "processed_logits": + logits_to_return = logits + elif self.logprobs_mode == "processed_logprobs": + logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32) + probs = logits.softmax(dim=-1, dtype=torch.float32) + return random_sample(probs, generators), logits_to_return + def forward_cuda( self, logits: torch.Tensor, @@ -139,7 +151,7 @@ def forward_cuda( return flashinfer_sample(logits.contiguous(), k, p, generators), None -def original_apply_top_k_top_p( +def 
apply_top_k_top_p( logits: torch.Tensor, k: Optional[torch.Tensor], p: Optional[torch.Tensor], @@ -176,6 +188,34 @@ def original_apply_top_k_top_p( logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) return logits + +def apply_top_k_top_p_triton( + logits: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], +) -> torch.Tensor: + + batch_size, vocab_size = logits.shape + BLOCK_SIZE = 4096 + NUM_PROGRAMS = 128 + NUM_TILES = (vocab_size + BLOCK_SIZE - 1) // BLOCK_SIZE + SIGMA = 2.5 + + if k is not None and p is None: + probs = torch.full((NUM_PROGRAMS, vocab_size), -float('inf'), device=logits.device) + _topk_kernel[(NUM_PROGRAMS,)](logits, probs, k, batch_size, + SIGMA, vocab_size, BLOCK_SIZE, NUM_TILES) + elif k is None and p is not None: + probs = torch.full_like(logits, -float('inf'), device=logits.device) + probs_2 = torch.full_like(logits, -float('inf'), device=logits.device) + _topp_kernel[(NUM_PROGRAMS,)](logits, probs, probs_2, p, batch_size, + SIGMA, vocab_size, BLOCK_SIZE, NUM_TILES) + elif k is not None and p is not None: + probs = torch.full_like(logits, -float('inf'), device=logits.device) + _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, + SIGMA, vocab_size, BLOCK_SIZE, NUM_TILES) + return logits + @triton.jit def _topk_kernel(LOGITS, PROBS, K, B, SIGMA: tl.constexpr, @@ -193,7 +233,7 @@ def _topk_kernel(LOGITS, PROBS, K, B, # FOLLOWING THE CURRENT PYTHON BASED IMPLEMENTATION in apply_top_k_only(), WHICH ALSO # INCLUDES ALL DUPLICATE LOGITS. # IF YOU NEED EXACTLY K LOGITS, PLEASE REFER TO THE TOP-P IMPLEMENTATION - # AND IMPLEMENT THE DUPLICATE LOGIT MANAGEMENT USING THE FORCE_REMOVE_LOGIT VARIABLE + # AND IMPLEMENT THE DUPLICATE LOGIT MANAGEMENT USING THE FORCE_REMOVE_LOGIT VARIABLE. 
k_pivot = -float('inf') @@ -268,7 +308,7 @@ def _topk_kernel(LOGITS, PROBS, K, B, k_pivot = k_pivot_1 elif k_pivots_num_2 == k: k_pivot = k_pivot_2 - # If none of the pivots are equal to k, we updatae the range + # If none of the pivots are equal to k, we update the range elif k_pivots_num_2 > k: min_range = k_pivot_2 elif k_pivots_num_1 > k: @@ -461,7 +501,7 @@ def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, logits_blk = tl.where(logits_blk > p_pivot, logits_blk, -float('inf')) tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) - + @triton.jit def _topk_topp_kernel(LOGITS, PROBS, K, P, B, @@ -525,7 +565,7 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, # FOLLOWING THE CURRENT PYTHON BASED IMPLEMENTATION in apply_top_k_only(), WHICH ALSO # INCLUDES ALL DUPLICATE LOGITS. # IF YOU NEED EXACTLY K LOGITS, PLEASE REFER TO THE TOP-P IMPLEMENTATION - # AND IMPLEMENT THE DUPLICATE LOGIT MANAGEMENT USING THE FORCE_REMOVE_LOGIT VARIABLE + # AND IMPLEMENT THE DUPLICATE LOGIT MANAGEMENT USING THE FORCE_REMOVE_LOGIT VARIABLE. 
max_range = max_logit min_range = min_logit @@ -697,90 +737,6 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, logits_blk = tl.where(logits_blk > pivot, logits_blk, -float('inf')) tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) - - -def triton_apply_top_k_top_p( - logits: torch.Tensor, - k: Optional[torch.Tensor], - p: Optional[torch.Tensor], -) -> torch.Tensor: - - batch_size, vocab_size = logits.shape - BLOCK_SIZE = 4096 - NUM_PROGRAMS = 128 - NUM_TILES = (vocab_size + BLOCK_SIZE - 1) // BLOCK_SIZE - SIGMA = 2.5 - - if k is not None and p is None: - probs = torch.full((NUM_PROGRAMS, vocab_size), -float('inf'), device=logits.device) - _topk_kernel[(NUM_PROGRAMS,)](logits, probs, k, batch_size, - SIGMA, vocab_size, BLOCK_SIZE, NUM_TILES) - elif k is None and p is not None: - probs = torch.full_like(logits, -float('inf'), device=logits.device) - probs_2 = torch.full_like(logits, -float('inf'), device=logits.device) - _topp_kernel[(NUM_PROGRAMS,)](logits, probs, probs_2, p, batch_size, - SIGMA, vocab_size, BLOCK_SIZE, NUM_TILES) - elif k is not None and p is not None: - probs = torch.full_like(logits, -float('inf'), device=logits.device) - _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, - SIGMA, vocab_size, BLOCK_SIZE, NUM_TILES) - - return logits - -@torch.compile -def compiled_apply_top_k_top_p( - logits: torch.Tensor, - k: Optional[torch.Tensor], - p: Optional[torch.Tensor], -) -> torch.Tensor: - return original_apply_top_k_top_p(logits, k, p) - -def apply_top_k_top_p( - logits: torch.Tensor, - k: Optional[torch.Tensor], - p: Optional[torch.Tensor], -) -> torch.Tensor: - """Apply top-k and top-p masks to the logits. - - If a top-p is used, this function will sort the logits tensor, - which can be slow for large batches. - - The logits tensor may be updated in-place. 
- """ - - input_logits = logits.clone() - original_logits = original_apply_top_k_top_p(input_logits, k, p) - - batch_size, vocab_size = logits.shape - print(g_str("apply_top_k_top_p") + f" logits.shape: {batch_size} x {vocab_size}, p is None: {p is None}, k is None: {k is None}") - - torch.cuda.synchronize() - start_time = time.time() - - # logits = original_apply_top_k_top_p(logits, k, p) - # logits = compiled_apply_top_k_top_p(logits, k, p) - logits = triton_apply_top_k_top_p(logits, k, p) - - torch.cuda.synchronize() - time_taken = time.time() - start_time - print(y_str(f"apply_top_k_top_p done in {time_taken} seconds")) - - if not torch.allclose(logits, original_logits): - print(r_str("Error: logits are not close")) - error_rows = torch.where(logits != original_logits)[0] - error_rows = torch.unique(error_rows) - num_error_rows = error_rows.shape[0] - print(f"num_error_rows: {num_error_rows} - {error_rows}") - row_to_show = 4 if num_error_rows > 4 else num_error_rows - print(f"logits: {torch.sort(logits[error_rows], descending=True).values[:row_to_show, :50]}") - print(f"original_logits: {torch.sort(original_logits[error_rows], descending=True).values[:row_to_show, :50]}") - - # start_time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime(start_time)) - # out_dir = "./sampler_input_output" - # os.makedirs(out_dir, exist_ok=True) - # out_path = f"{out_dir}/llama8b_{start_time_str}.pt" - # torch.save({"input_logits": input_logits, "p": p, "k": k, "output_logits": logits}, out_path) - return logits def apply_top_k_only( From 5575c676da749f0f67fb0b8ed90e0a65508cd0fa Mon Sep 17 00:00:00 2001 From: js_park Date: Sat, 27 Sep 2025 01:21:12 -0700 Subject: [PATCH 26/99] Linter Signed-off-by: js_park --- test.py | 4 +- test_triton_topk_topp.py | 122 ++++++---- vllm/v1/sample/ops/topk_topp_sampler.py | 304 ++++++++++++++---------- 3 files changed, 266 insertions(+), 164 deletions(-) diff --git a/test.py b/test.py index 02ab987f53a0..84753533001b 100644 --- a/test.py 
+++ b/test.py @@ -1,4 +1,7 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os + from vllm import LLM, SamplingParams # For V1: Turn off multiprocessing to make scheduling deterministic @@ -31,4 +34,3 @@ prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - diff --git a/test_triton_topk_topp.py b/test_triton_topk_topp.py index 04a0cf6a4fb1..477c77751993 100644 --- a/test_triton_topk_topp.py +++ b/test_triton_topk_topp.py @@ -1,19 +1,32 @@ -import torch +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time -import re -from itertools import product -from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p, apply_top_k_top_p_triton from datetime import datetime +from itertools import product + +import regex as re +import torch + +from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p, + apply_top_k_top_p_triton) + def g_str(s): return "\033[32m" + s + "\033[0m" + + def r_str(s): return "\033[31m" + s + "\033[0m" + + def y_str(s): return "\033[33m" + s + "\033[0m" + + def b_str(s): return "\033[34m" + s + "\033[0m" + def print_to_log(s, log_file): print(s) # Remove the color codes @@ -21,11 +34,12 @@ def print_to_log(s, log_file): with open(log_file, "a") as f: f.write(s + "\n") + def test_accuracy(logits, k, p): input_logits = logits.clone() original_logits = apply_top_k_top_p(input_logits, k, p) logits = apply_top_k_top_p_triton(logits, k, p) - + torch.cuda.synchronize() is_correct = torch.allclose(logits, original_logits) @@ -34,18 +48,25 @@ def test_accuracy(logits, k, p): error_rows = torch.where(logits != original_logits)[0] error_rows = torch.unique(error_rows) num_error_rows = error_rows.shape[0] - print_to_log(f"num_error_rows: {num_error_rows} - {error_rows}", log_file) + print_to_log(f"num_error_rows: {num_error_rows} - 
{error_rows}", + log_file) row_to_show = 12 if num_error_rows > 12 else num_error_rows - print_to_log(f"logits: {torch.sort(logits[error_rows], descending=True).values[:row_to_show, :50]}", log_file) - print_to_log(f"original_logits: {torch.sort(original_logits[error_rows], descending=True).values[:row_to_show, :50]}", log_file) + logits_to_show = torch.sort(logits[error_rows], descending=True).values + logits_to_show = logits_to_show[:row_to_show, :50] + print_to_log(f"logits: {logits_to_show}", log_file) + original_logits_to_show = \ + torch.sort(original_logits[error_rows], descending=True).values + original_logits_to_show = original_logits_to_show[:row_to_show, :50] + print_to_log(f"original_logits: {original_logits_to_show}", log_file) return is_correct + def test_time(logits, k, p, num_runs=256): - # We must clone the logits for each run to avoid modifying the original logits + # We must clone the logits for each run to avoid modifying the original input_logits_torch = [logits.clone() for _ in range(num_runs)] input_logits_triton = [logits.clone() for _ in range(num_runs)] - + torch.cuda.synchronize() start_time = time.time() for _ in range(num_runs): @@ -56,68 +77,87 @@ def test_time(logits, k, p, num_runs=256): torch.cuda.synchronize() start_time = time.time() for _ in range(num_runs): - input_logits_triton[_] = apply_top_k_top_p_triton(input_logits_triton[_], k, p) + input_logits_triton[_] = apply_top_k_top_p_triton( + input_logits_triton[_], k, p) torch.cuda.synchronize() triton_time_taken = (time.time() - start_time) / num_runs return torch_time_taken, triton_time_taken + if __name__ == "__main__": date_str = datetime.now().strftime("%Y%m%d_%H%M%S") - batch_size_list = [2**i for i in range(0, 11)] # 1 to 1024 - vocab_size_list = [2**i for i in range(8, 19)] # 256 to 262144 + batch_size_list = [2**i for i in range(0, 11)] # 1 to 1024 + vocab_size_list = [2**i for i in range(8, 19)] # 256 to 262144 p_list = [None, "RAND"] + [0.1 * i for i in range(1, 10)] 
- k_list = [None, "RAND"] + [i for i in range(1, 10)] + [i for i in range(20, 210, 30)] + k_list = [None, "RAND"] + [i for i in range(1, 10) + ] + [i for i in range(20, 210, 30)] log_file = f"triton_topk_topp_test_{date_str}.log" csv_file = f"triton_topk_topp_test_{date_str}.csv" - print_to_log(y_str(f"Testing TopKTopPSampler with Triton"), log_file) - print_to_log(y_str(f"batch_size_list:") + f"{batch_size_list}", log_file) - print_to_log(y_str(f"vocab_size_list:") + f"{vocab_size_list}", log_file) - print_to_log(y_str(f"p_list:") + f"{p_list}", log_file) - print_to_log(y_str(f"k_list:") + f"{k_list}", log_file) - print_to_log(y_str(f"log_file:") + f"{log_file}", log_file) - print_to_log(y_str(f"csv_file:") + f"{csv_file}", log_file) + print_to_log(y_str("Testing TopKTopPSampler with Triton"), log_file) + print_to_log(y_str("batch_size_list:") + f"{batch_size_list}", log_file) + print_to_log(y_str("vocab_size_list:") + f"{vocab_size_list}", log_file) + print_to_log(y_str("p_list:") + f"{p_list}", log_file) + print_to_log(y_str("k_list:") + f"{k_list}", log_file) + print_to_log(y_str("log_file:") + f"{log_file}", log_file) + print_to_log(y_str("csv_file:") + f"{csv_file}", log_file) with open(csv_file, "w") as f: - f.write("dist_generator,batch_size,vocab_size,p,k,is_correct,torch_time_taken,triton_time_taken,speedup\n") + f.write("dist_generator,batch_size,vocab_size,p,k,is_correct," + "torch_time_taken,triton_time_taken,speedup\n") - for batch_size, vocab_size, p, k in product(batch_size_list, vocab_size_list, p_list, k_list): + for batch_size, vocab_size, p, k in product(batch_size_list, + vocab_size_list, p_list, + k_list): if p == "RAND" and k == "RAND": continue - + logits_rand = torch.rand(batch_size, vocab_size, device="cuda") logits_randn = torch.randn(batch_size, vocab_size, device="cuda") logits_list = [("RAND", logits_rand), ("RANDN", logits_randn)] if p == "RAND": - p_tensor = torch.rand((batch_size,), device="cuda") * 0.95 + 0.05 + p_tensor = 
torch.rand((batch_size, ), device="cuda") * 0.95 + 0.05 elif p is not None: - p_tensor = torch.full((batch_size,), p, device="cuda") + p_tensor = torch.full((batch_size, ), p, device="cuda") else: p_tensor = None if k == "RAND": - k_tensor = torch.randint(1, vocab_size, (batch_size,), device="cuda") + k_tensor = torch.randint(1, + vocab_size, (batch_size, ), + device="cuda") elif k is not None: - k_tensor = torch.full((batch_size,), k, device="cuda") + k_tensor = torch.full((batch_size, ), k, device="cuda") else: k_tensor = None for dist_generator, logits in logits_list: - print_to_log(y_str(f"--------------------------------"), log_file) - print_to_log(g_str(f"Testing ") + f"{dist_generator}" - + y_str(f" with batch_size: ") + f"{batch_size}" - + y_str(f" vocab_size: ") + f"{vocab_size}" - + y_str(f" p: ") + f"{p}" - + y_str(f" k: ") + f"{k}", log_file) + print_to_log(y_str("--------------------------------"), log_file) + print_to_log( + g_str("Testing ") + f"{dist_generator}" + + y_str(" with batch_size: ") + f"{batch_size}" + + y_str(" vocab_size: ") + f"{vocab_size}" + y_str(" p: ") + + f"{p}" + y_str(" k: ") + f"{k}", log_file) is_correct = test_accuracy(logits, k_tensor, p_tensor) if not is_correct: - print_to_log(f"Error: logits are not close for batch_size: {batch_size}, vocab_size: {vocab_size}, dist_generator: {dist_generator}, p: {p}, k: {k}", log_file) - torch_time_taken, triton_time_taken = test_time(logits, k_tensor, p_tensor) - print_to_log(b_str(f"torch_time_taken: ") + f"{torch_time_taken}", log_file) - print_to_log(b_str(f"triton_time_taken: ") + f"{triton_time_taken}", log_file) - print_to_log(g_str(f"Triton Speedup over Torch: ") + f"{torch_time_taken / triton_time_taken:.8f}x", log_file) + print_to_log( + f"Error: logits are not close for batch_size: {batch_size}," + f" vocab_size: {vocab_size}, dist_generator: " + f"{dist_generator}, p: {p}, k: {k}", log_file) + torch_time_taken, triton_time_taken = test_time( + logits, k_tensor, p_tensor) + 
print_to_log( + b_str("torch_time_taken: ") + f"{torch_time_taken}", log_file) + print_to_log( + b_str("triton_time_taken: ") + f"{triton_time_taken}", + log_file) + print_to_log( + g_str("Triton Speedup over Torch: ") + + f"{torch_time_taken / triton_time_taken:.8f}x", log_file) with open(csv_file, "a") as f: - f.write(f"{dist_generator},{batch_size},{vocab_size},{p},{k},{is_correct},{torch_time_taken},{triton_time_taken},{torch_time_taken / triton_time_taken:.8f}\n") - print_to_log(y_str(f"--------------------------------\n"), log_file) \ No newline at end of file + f.write(f"{dist_generator},{batch_size},{vocab_size},{p},{k}," + f"{is_correct},{torch_time_taken},{triton_time_taken}," + f"{torch_time_taken / triton_time_taken:.8f}\n") + print_to_log(y_str("--------------------------------\n"), log_file) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index c53b9d92ae22..e480a24894a7 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -1,16 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os -from datetime import timedelta -from tkinter import NO from typing import Optional import torch import torch.nn as nn import triton import triton.language as tl -import time from packaging import version from vllm import envs @@ -25,7 +21,8 @@ is_flashinfer_available = True except ImportError: is_flashinfer_available = False - + + class TopKTopPSampler(nn.Module): """ Module that performs optional top-k and top-p filtering followed by @@ -78,9 +75,10 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None: self.forward = self.forward_triton else: logger.warning_once( - "FlashInfer is not available. Falling back to the PyTorch-" - "native implementation of top-p & top-k sampling. For the " - "best performance, please install FlashInfer.") + "FlashInfer is not available. 
Falling back to the " + "PyTorch-native implementation of top-p & top-k " + "sampling. For the best performance, please install " + "FlashInfer.") self.forward = self.forward_native elif current_platform.is_cpu(): self.forward = self.forward_cpu @@ -200,7 +198,7 @@ def apply_top_k_top_p( """Apply top-k and top-p masks to the logits. """ logits_sort, logits_idx = logits.sort(dim=-1, descending=False) - + if p is None: if k is None: return logits @@ -241,40 +239,43 @@ def apply_top_k_top_p_triton( NUM_PROGRAMS = 128 NUM_TILES = (vocab_size + BLOCK_SIZE - 1) // BLOCK_SIZE SIGMA = 2.5 - + if k is not None and p is None: - probs = torch.full((NUM_PROGRAMS, vocab_size), -float('inf'), device=logits.device) - _topk_kernel[(NUM_PROGRAMS,)](logits, probs, k, batch_size, - SIGMA, vocab_size, BLOCK_SIZE, NUM_TILES) + probs = torch.full((NUM_PROGRAMS, vocab_size), + -float('inf'), + device=logits.device) + _topk_kernel[(NUM_PROGRAMS, )](logits, probs, k, batch_size, SIGMA, + vocab_size, BLOCK_SIZE, NUM_TILES) elif k is None and p is not None: probs = torch.full_like(logits, -float('inf'), device=logits.device) probs_2 = torch.full_like(logits, -float('inf'), device=logits.device) - _topp_kernel[(NUM_PROGRAMS,)](logits, probs, probs_2, p, batch_size, - SIGMA, vocab_size, BLOCK_SIZE, NUM_TILES) + _topp_kernel[(NUM_PROGRAMS, )](logits, probs, probs_2, p, batch_size, + SIGMA, vocab_size, BLOCK_SIZE, + NUM_TILES) elif k is not None and p is not None: probs = torch.full_like(logits, -float('inf'), device=logits.device) - _topk_topp_kernel[(NUM_PROGRAMS,)](logits, probs, k, p, batch_size, - SIGMA, vocab_size, BLOCK_SIZE, NUM_TILES) + _topk_topp_kernel[(NUM_PROGRAMS, )](logits, probs, k, p, batch_size, + SIGMA, vocab_size, BLOCK_SIZE, + NUM_TILES) return logits + @triton.jit -def _topk_kernel(LOGITS, PROBS, K, B, - SIGMA: tl.constexpr, - N: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - NUM_TILES: tl.constexpr): +def _topk_kernel(LOGITS, PROBS, K, B, SIGMA: tl.constexpr, N: tl.constexpr, 
+ BLOCK_SIZE: tl.constexpr, NUM_TILES: tl.constexpr): pid = tl.program_id(0) num_programs = tl.num_programs(0) for row_id in tl.range(pid, B, num_programs): k = tl.load(K + row_id) - if not (k == N): # All tokens are valid + if k != N: # All tokens are valid - # THERE IS NO DUPLICATE LOGIT MANAGEMENT FOR THIS TOP-K IMPLEMENTATION - # CURRENT IMPLEMENTATION INCLUDES ALL DUPLICATE LOGITS, WHICH MAY RETURN MORE THAN K LOGITS, - # FOLLOWING THE CURRENT PYTHON BASED IMPLEMENTATION in apply_top_k_only(), WHICH ALSO - # INCLUDES ALL DUPLICATE LOGITS. - # IF YOU NEED EXACTLY K LOGITS, PLEASE REFER TO THE TOP-P IMPLEMENTATION - # AND IMPLEMENT THE DUPLICATE LOGIT MANAGEMENT USING THE FORCE_REMOVE_LOGIT VARIABLE. + # THERE IS NO DUPLICATE LOGIT MANAGEMENT FOR THIS TOP-K + # CURRENT IMPLEMENTATION INCLUDES ALL DUPLICATE LOGITS, + # WHICH MAY RETURN MORE THAN K LOGITS, + # FOLLOWING THE IMPLEMENTATION in apply_top_k_only(). + # IF YOU NEED EXACTLY K LOGITS, PLEASE REFER TO THE TOP-P + # IMPLEMENTATION AND IMPLEMENT THE DUPLICATE LOGIT MANAGEMENT + # USING THE FORCE_REMOVE_LOGIT VARIABLE k_pivot = -float('inf') @@ -300,28 +301,32 @@ def _topk_kernel(LOGITS, PROBS, K, B, outlier_pivot = avg_logit + SIGMA * std_logit num_outliers = tl.zeros((), dtype=tl.uint32) # First pass: compute max and min logits and gather outliers - for i in range(0,search_iters): + for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=avg_logit) + logits_blk = tl.load(search_addr + offs_n, + mask=mask_n, + other=avg_logit) max_logit = tl.maximum(max_logit, tl.max(logits_blk)) min_logit = tl.minimum(min_logit, tl.min(logits_blk)) outlier_mask = (logits_blk > outlier_pivot) & mask_n num_blk_outliers = tl.sum(outlier_mask) - cumulative_pos = tl.cast(tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + cumulative_pos = tl.cast( + tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) 
num_outliers += num_blk_outliers write_pos = tl.where(outlier_mask, cumulative_pos, -1) tl.store(PROBS_ROW + write_pos, logits_blk, mask=outlier_mask) - + max_range = max_logit min_range = min_logit if num_outliers > k: max_range = max_logit - min_range = outlier_pivot + min_range = outlier_pivot search_addr = PROBS_ROW search_range = tl.cast(num_outliers, tl.int32) - search_iters = tl.cast((num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) # Second passes: Quaternary search for pivots (nlog_4(n)) num_iters = 0 @@ -336,7 +341,9 @@ def _topk_kernel(LOGITS, PROBS, K, B, for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=-float('inf')) + logits_blk = tl.load(search_addr + offs_n, + mask=mask_n, + other=-float('inf')) k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) @@ -344,7 +351,7 @@ def _topk_kernel(LOGITS, PROBS, K, B, # Check if any of the pivots are equal to k if k_pivots_num_0 == k: - k_pivot = k_pivot_0 + k_pivot = k_pivot_0 elif k_pivots_num_1 == k: k_pivot = k_pivot_1 elif k_pivots_num_2 == k: @@ -362,7 +369,7 @@ def _topk_kernel(LOGITS, PROBS, K, B, max_range = k_pivot_1 elif k_pivots_num_2 < k: max_range = k_pivot_2 - + num_iters += 1 if num_iters >= 18: k_pivot = k_pivot_0 @@ -379,16 +386,14 @@ def _topk_kernel(LOGITS, PROBS, K, B, @triton.jit -def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, - SIGMA: tl.constexpr, - N: tl.constexpr, - BLOCK_SIZE: tl.constexpr, +def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, SIGMA: tl.constexpr, + N: tl.constexpr, BLOCK_SIZE: tl.constexpr, NUM_TILES: tl.constexpr): pid = tl.program_id(0) num_programs = tl.num_programs(0) for row_id in tl.range(pid, B, num_programs): p = tl.load(P + row_id) - if p != 1.0: # All tokens are valid + if p != 1.0: # All tokens are 
valid p_pivot = -float('inf') @@ -402,13 +407,15 @@ def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, max_logit = -float('inf') min_logit = float('inf') - - # The Pytorch version removes the earlier duplicates if there are more than one duplicates + + # The Pytorch version removes the earlier duplicates + # if there are more than one duplicates force_remove_logit = -float('inf') num_force_remove = tl.zeros((), dtype=tl.uint32) # Zeroth pass: Compute avg and std from a sample block - # May produce incorrect results if N < BLOCK_SIZE OR all logits are the same + # May produce incorrect results if N < BLOCK_SIZE + # OR all logits are the same offs = tl.arange(0, BLOCK_SIZE) mask_n = offs < N logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) @@ -421,21 +428,25 @@ def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, sum_outlier_probs = 0.0 sum_exp_logits = 0.0 - + # First pass: compute max and min logits - for i in range(0,search_iters): + for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=avg_logit) + logits_blk = tl.load(LOGITS_ROW + offs_n, + mask=mask_n, + other=avg_logit) max_logit = tl.maximum(max_logit, tl.max(logits_blk)) min_logit = tl.minimum(min_logit, tl.min(logits_blk)) - + # Second pass: Calculate exp logits and sum for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - probs_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) + probs_blk = tl.load(LOGITS_ROW + offs_n, + mask=mask_n, + other=-float('inf')) probs_blk = probs_blk - max_logit probs_blk = tl.exp(probs_blk) sum_exp_logits += tl.sum(probs_blk) @@ -447,15 +458,16 @@ def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - + probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, 
other=0.0) probs_blk = probs_blk / sum_exp_logits tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) - + outlier_mask = (probs_blk > outlier_prob) & mask_n sum_outlier_probs += tl.sum(outlier_mask * probs_blk) num_blk_outliers = tl.sum(outlier_mask) - cumulative_pos = tl.cast(tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + cumulative_pos = tl.cast( + tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) num_outliers += num_blk_outliers write_pos = tl.where(outlier_mask, cumulative_pos, -1) tl.store(PROBS_2_ROW + write_pos, probs_blk, mask=outlier_mask) @@ -464,10 +476,11 @@ def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, min_range = tl.exp(min_logit - max_logit) / sum_exp_logits if sum_outlier_probs > p: - min_range = outlier_prob + min_range = outlier_prob search_addr = PROBS_2_ROW search_range = tl.cast(num_outliers, tl.int32) - search_iters = tl.cast((num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) second_max_logit = -float('inf') @@ -487,19 +500,27 @@ def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - probs_blk = tl.load(search_addr + offs_n, mask=mask_n, other=0.0) + probs_blk = tl.load(search_addr + offs_n, + mask=mask_n, + other=0.0) + + masked_larger_0 = tl.where(probs_blk > p_pivot_0, + probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, + tl.min(masked_larger_0)) - masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) - min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) + p_pivots_sum_0 += tl.sum(probs_blk * + (probs_blk > p_pivot_0)) - p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) - for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - probs_blk = tl.load(search_addr + offs_n, mask=mask_n, other=0.0) + probs_blk = tl.load(search_addr 
+ offs_n, + mask=mask_n, + other=0.0) - num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-7) + num_min_larger_0 += tl.sum( + tl.abs(probs_blk - min_larger_0) < 1e-7) # Check if any of the pivots are equal to k if p_pivots_sum_0 >= p: @@ -509,18 +530,21 @@ def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, min_range = p_pivot_0 else: max_range = p_pivot_0 - + num_iters += 1 if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-8: p_pivot = p_pivot_0 - + # At least one value should be greater than p_pivot if p_pivot >= max_logit: p_pivot = second_max_logit elif num_min_larger_0 > 1: - # Force remove duplicates (p_pivot is made to include all duplicates if it falls on the duplicates) - num_force_remove = tl.cast((p_pivots_sum_0 - p) / min_larger_0, tl.uint32) - force_remove_logit = tl.log(min_larger_0 * sum_exp_logits) + max_logit + # Force remove duplicates (p_pivot is made to include all + # duplicates if it falls on the duplicates) + num_force_remove = tl.cast((p_pivots_sum_0 - p) / min_larger_0, + tl.uint32) + force_remove_logit = tl.log( + min_larger_0 * sum_exp_logits) + max_logit p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit @@ -530,25 +554,31 @@ def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < N - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) + logits_blk = tl.load(LOGITS_ROW + offs_n, + mask=mask_n, + other=-float('inf')) if force_remove_logit != -float('inf'): - force_remove_mask = tl.abs(logits_blk - force_remove_logit) < 1e-5 - force_remove_count = tl.cumsum(force_remove_mask) + current_num_force_remove - force_remove_count_mask = force_remove_count <= num_force_remove - force_remove_mask = force_remove_count_mask & force_remove_mask - logits_blk = tl.where(force_remove_mask, -float('inf'), logits_blk) + force_remove_mask = tl.abs(logits_blk - + force_remove_logit) < 1e-5 + force_remove_count = tl.cumsum( + 
force_remove_mask) + current_num_force_remove + force_remove_count_mask = \ + force_remove_count <= num_force_remove + force_remove_mask = \ + force_remove_count_mask & force_remove_mask + logits_blk = tl.where(force_remove_mask, -float('inf'), + logits_blk) current_num_force_remove = tl.max(force_remove_count) - logits_blk = tl.where(logits_blk > p_pivot, logits_blk, -float('inf')) + logits_blk = tl.where(logits_blk > p_pivot, logits_blk, + -float('inf')) tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) @triton.jit -def _topk_topp_kernel(LOGITS, PROBS, K, P, B, - SIGMA: tl.constexpr, - N: tl.constexpr, - BLOCK_SIZE: tl.constexpr, +def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, + N: tl.constexpr, BLOCK_SIZE: tl.constexpr, NUM_TILES: tl.constexpr): pid = tl.program_id(0) num_programs = tl.num_programs(0) @@ -567,7 +597,8 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, min_logit = float('inf') avg_logit = -float('inf') - # The Pytorch version removes the earlier duplicates if there are more than one duplicates + # The Pytorch version removes the earlier duplicates + # if there are more than one duplicates force_remove_logit = -float('inf') num_force_remove = tl.zeros((), dtype=tl.uint32) @@ -583,39 +614,44 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, outlier_pivot = avg_logit + SIGMA * std_logit num_outliers = tl.zeros((), dtype=tl.uint32) # First pass: compute max and min logits and gather outliers - for i in range(0,search_iters): + for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=avg_logit) + logits_blk = tl.load(search_addr + offs_n, + mask=mask_n, + other=avg_logit) max_logit = tl.maximum(max_logit, tl.max(logits_blk)) min_logit = tl.minimum(min_logit, tl.min(logits_blk)) outlier_mask = (logits_blk > outlier_pivot) & mask_n num_blk_outliers = tl.sum(outlier_mask) - cumulative_pos = 
tl.cast(tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + cumulative_pos = tl.cast( + tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) num_outliers += num_blk_outliers write_pos = tl.where(outlier_mask, cumulative_pos, -1) tl.store(PROBS_ROW + write_pos, logits_blk, mask=outlier_mask) - + ############### START OF TOP-K CODE ############### k = tl.load(K + row_id) - if not (k == N): # All tokens are valid + if k != N: # All tokens are valid - # THERE IS NO DUPLICATE LOGIT MANAGEMENT FOR THIS TOP-K IMPLEMENTATION - # CURRENT IMPLEMENTATION INCLUDES ALL DUPLICATE LOGITS, WHICH MAY RETURN MORE THAN K LOGITS, - # FOLLOWING THE CURRENT PYTHON BASED IMPLEMENTATION in apply_top_k_only(), WHICH ALSO - # INCLUDES ALL DUPLICATE LOGITS. - # IF YOU NEED EXACTLY K LOGITS, PLEASE REFER TO THE TOP-P IMPLEMENTATION - # AND IMPLEMENT THE DUPLICATE LOGIT MANAGEMENT USING THE FORCE_REMOVE_LOGIT VARIABLE. + # THERE IS NO DUPLICATE LOGIT MANAGEMENT FOR THIS TOP-K + # CURRENT IMPLEMENTATION INCLUDES ALL DUPLICATE LOGITS, + # WHICH MAY RETURN MORE THAN K LOGITS, + # FOLLOWING THE IMPLEMENTATION in apply_top_k_only(). + # IF YOU NEED EXACTLY K LOGITS, PLEASE REFER TO THE TOP-P + # IMPLEMENTATION AND IMPLEMENT THE DUPLICATE LOGIT MANAGEMENT + # USING THE FORCE_REMOVE_LOGIT VARIABLE. 
max_range = max_logit min_range = min_logit if num_outliers > k: max_range = max_logit - min_range = outlier_pivot + min_range = outlier_pivot search_addr = PROBS_ROW search_range = tl.cast(num_outliers, tl.int32) - search_iters = tl.cast((num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) # Second passes: Quaternary search for pivots (nlog_4(n)) num_iters = 0 @@ -630,7 +666,9 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=-float('inf')) + logits_blk = tl.load(search_addr + offs_n, + mask=mask_n, + other=-float('inf')) k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) @@ -638,7 +676,7 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, # Check if any of the pivots are equal to k if k_pivots_num_0 == k: - k_pivot = k_pivot_0 + k_pivot = k_pivot_0 elif k_pivots_num_1 == k: k_pivot = k_pivot_1 elif k_pivots_num_2 == k: @@ -656,17 +694,17 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, max_range = k_pivot_1 elif k_pivots_num_2 < k: max_range = k_pivot_2 - + num_iters += 1 if num_iters >= 18: - k_pivot = k_pivot_0 + k_pivot = k_pivot_0 ############### END OF TOP-K CODE ############### ############### START OF TOP-P CODE ############### p = tl.load(P + row_id) - if p != 1.0: # All tokens are valid + if p != 1.0: # All tokens are valid second_max_logit = -float('inf') max_probs = 0.0 @@ -677,18 +715,21 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - probs_blk = tl.load(search_addr + offs_n, mask=mask_n, other=-float('inf')) - probs_blk = tl.where(probs_blk > k_pivot, probs_blk, -float('inf')) + probs_blk = tl.load(search_addr + offs_n, + 
mask=mask_n, + other=-float('inf')) + probs_blk = tl.where(probs_blk > k_pivot, probs_blk, + -float('inf')) probs_blk = probs_blk - max_logit probs_blk = tl.exp(probs_blk) sum_exp_logits += tl.sum(probs_blk) tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) second_max_mask = probs_blk * (probs_blk < max_probs) - second_max_logit = tl.maximum(second_max_logit, tl.max(second_max_mask)) + second_max_logit = tl.maximum(second_max_logit, + tl.max(second_max_mask)) # Fourth pass: Compute probs (softmax) - exp_avg = tl.exp(avg_logit) for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range @@ -719,19 +760,27 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) + probs_blk = tl.load(PROBS_ROW + offs_n, + mask=mask_n, + other=0.0) + + masked_larger_0 = tl.where(probs_blk > p_pivot_0, + probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, + tl.min(masked_larger_0)) - masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) - min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) + p_pivots_sum_0 += tl.sum(probs_blk * + (probs_blk > p_pivot_0)) - p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) - for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) + probs_blk = tl.load(PROBS_ROW + offs_n, + mask=mask_n, + other=0.0) - num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-7) + num_min_larger_0 += tl.sum( + tl.abs(probs_blk - min_larger_0) < 1e-7) # Check if any of the pivots are equal to k if p_pivots_sum_0 >= p: @@ -741,21 +790,24 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, min_range = p_pivot_0 else: max_range = p_pivot_0 - + num_iters += 1 if num_iters >= 32 or 
tl.abs(min_range - max_range) < 1e-8: p_pivot = p_pivot_0 - + # At least one value should be greater than p_pivot if p_pivot >= max_logit: p_pivot = second_max_logit elif num_min_larger_0 > 1: - # Force remove duplicates (p_pivot is made to include all duplicates if it falls on the duplicates) - num_force_remove = tl.cast((p_pivots_sum_0 - p) / min_larger_0, tl.uint32) - force_remove_logit = tl.log(min_larger_0 * sum_exp_logits) + max_logit + # Force remove duplicates (p_pivot is made to include all + # duplicates if it falls on the duplicates) + num_force_remove = tl.cast((p_pivots_sum_0 - p) / min_larger_0, + tl.uint32) + force_remove_logit = tl.log( + min_larger_0 * sum_exp_logits) + max_logit p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit - + ############### END OF TOP-P CODE ############### # Sixth pass: Apply mask @@ -765,18 +817,26 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < N - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) + logits_blk = tl.load(LOGITS_ROW + offs_n, + mask=mask_n, + other=-float('inf')) if force_remove_logit != -float('inf'): # Force remove duplicates - force_remove_mask = tl.abs(logits_blk - force_remove_logit) < 1e-5 - force_remove_count = tl.cumsum(force_remove_mask) + current_num_force_remove - force_remove_count_mask = force_remove_count <= num_force_remove - force_remove_mask = force_remove_count_mask & force_remove_mask - logits_blk = tl.where(force_remove_mask, -float('inf'), logits_blk) + force_remove_mask = tl.abs(logits_blk - + force_remove_logit) < 1e-5 + force_remove_count = tl.cumsum( + force_remove_mask) + current_num_force_remove + force_remove_count_mask = \ + force_remove_count <= num_force_remove + force_remove_mask = \ + force_remove_count_mask & force_remove_mask + logits_blk = tl.where(force_remove_mask, -float('inf'), + logits_blk) current_num_force_remove = 
tl.max(force_remove_count) - logits_blk = tl.where(logits_blk > pivot, logits_blk, -float('inf')) + logits_blk = tl.where(logits_blk > pivot, logits_blk, + -float('inf')) tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) From 3342235a0178691b1c65e7254828e105e1a19463 Mon Sep 17 00:00:00 2001 From: js_park Date: Sat, 27 Sep 2025 01:59:51 -0700 Subject: [PATCH 27/99] Tests Signed-off-by: js_park --- test_triton_topk_topp.py | 22 ++++++++++++++-------- vllm/v1/sample/ops/topk_topp_sampler.py | 2 +- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/test_triton_topk_topp.py b/test_triton_topk_topp.py index 477c77751993..1957d38d584f 100644 --- a/test_triton_topk_topp.py +++ b/test_triton_topk_topp.py @@ -45,18 +45,23 @@ def test_accuracy(logits, k, p): if not is_correct: print_to_log(r_str("Error: logits are not close"), log_file) - error_rows = torch.where(logits != original_logits)[0] + error_mask = torch.abs(logits - original_logits) > 1e-5 + error_rows = torch.where(error_mask)[0] error_rows = torch.unique(error_rows) num_error_rows = error_rows.shape[0] + error_cols = torch.where(error_mask)[1] + error_cols = torch.unique(error_cols) + num_error_cols = error_cols.shape[0] print_to_log(f"num_error_rows: {num_error_rows} - {error_rows}", log_file) - row_to_show = 12 if num_error_rows > 12 else num_error_rows + print_to_log(f"num_error_cols: {num_error_cols}", log_file) + row_to_show = 5 if num_error_rows > 5 else num_error_rows logits_to_show = torch.sort(logits[error_rows], descending=True).values - logits_to_show = logits_to_show[:row_to_show, :50] + logits_to_show = logits_to_show[:row_to_show, :20] print_to_log(f"logits: {logits_to_show}", log_file) original_logits_to_show = \ torch.sort(original_logits[error_rows], descending=True).values - original_logits_to_show = original_logits_to_show[:row_to_show, :50] + original_logits_to_show = original_logits_to_show[:row_to_show, :20] print_to_log(f"original_logits: {original_logits_to_show}", 
log_file) return is_correct @@ -88,10 +93,11 @@ def test_time(logits, k, p, num_runs=256): if __name__ == "__main__": date_str = datetime.now().strftime("%Y%m%d_%H%M%S") batch_size_list = [2**i for i in range(0, 11)] # 1 to 1024 - vocab_size_list = [2**i for i in range(8, 19)] # 256 to 262144 - p_list = [None, "RAND"] + [0.1 * i for i in range(1, 10)] - k_list = [None, "RAND"] + [i for i in range(1, 10) - ] + [i for i in range(20, 210, 30)] + vocab_size_list = [2**i for i in range(10, 19, 2)] # 1024 to 131072 + p_list = [None, "RAND"] + [0.2 * i + for i in range(1, 6)] + [0.9, 0.95, 0.99] + k_list = [None, "RAND"] + [i for i in range(1, 10, 2) + ] + [i for i in range(20, 210, 40)] log_file = f"triton_topk_topp_test_{date_str}.log" csv_file = f"triton_topk_topp_test_{date_str}.csv" diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index e480a24894a7..5434025789f9 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -235,7 +235,7 @@ def apply_top_k_top_p_triton( ) -> torch.Tensor: batch_size, vocab_size = logits.shape - BLOCK_SIZE = 4096 + BLOCK_SIZE = 8192 NUM_PROGRAMS = 128 NUM_TILES = (vocab_size + BLOCK_SIZE - 1) // BLOCK_SIZE SIGMA = 2.5 From 9bb0fbbc95f92b4b6cf168f8d1b0ae7cb35613d7 Mon Sep 17 00:00:00 2001 From: js_park Date: Sat, 27 Sep 2025 16:47:27 -0700 Subject: [PATCH 28/99] Added Triton autotune Signed-off-by: js_park --- test.py | 4 +-- test_triton_topk_topp.py | 38 +++++++++++++-------- vllm/v1/sample/ops/topk_topp_sampler.py | 45 +++++++++++++++++-------- 3 files changed, 57 insertions(+), 30 deletions(-) diff --git a/test.py b/test.py index 84753533001b..464a933c347a 100644 --- a/test.py +++ b/test.py @@ -23,8 +23,8 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.7, seed=SEED) # Add seed to LLM initialization for global reproducibility -# llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", seed=SEED) -llm = 
LLM(model="deepseek-ai/DeepSeek-V2-Lite") +llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", seed=SEED) +# llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite") outputs = llm.generate(prompts, sampling_params) diff --git a/test_triton_topk_topp.py b/test_triton_topk_topp.py index 1957d38d584f..491c8dd9b175 100644 --- a/test_triton_topk_topp.py +++ b/test_triton_topk_topp.py @@ -67,10 +67,15 @@ def test_accuracy(logits, k, p): return is_correct -def test_time(logits, k, p, num_runs=256): +def test_time(logits, k, p, num_runs=30, num_warmup=5): # We must clone the logits for each run to avoid modifying the original + warmup_tensor = logits.clone() + for _ in range(num_warmup): + apply_top_k_top_p(warmup_tensor, k, p) + apply_top_k_top_p_triton(warmup_tensor, k, p) + torch.cuda.synchronize() + input_logits_torch = [logits.clone() for _ in range(num_runs)] - input_logits_triton = [logits.clone() for _ in range(num_runs)] torch.cuda.synchronize() start_time = time.time() @@ -79,6 +84,8 @@ def test_time(logits, k, p, num_runs=256): torch.cuda.synchronize() torch_time_taken = (time.time() - start_time) / num_runs + input_logits_triton = [logits.clone() for _ in range(num_runs)] + torch.cuda.synchronize() start_time = time.time() for _ in range(num_runs): @@ -92,12 +99,18 @@ def test_time(logits, k, p, num_runs=256): if __name__ == "__main__": date_str = datetime.now().strftime("%Y%m%d_%H%M%S") + # batch_size_list = [2**i for i in range(0, 11)] # 1 to 1024 + # vocab_size_list = [2**i for i in range(12, 18)] + [102400, 128256] + # p_list = [None, "RAND"] + [0.2 * i + # for i in range(1, 6)] + [0.9, 0.95, 0.99] + # k_list = [None, "RAND"] + [i for i in range(1, 11, 3) + # ] + [i for i in range(20, 230, 50)] + batch_size_list = [2**i for i in range(0, 11)] # 1 to 1024 - vocab_size_list = [2**i for i in range(10, 19, 2)] # 1024 to 131072 - p_list = [None, "RAND"] + [0.2 * i - for i in range(1, 6)] + [0.9, 0.95, 0.99] - k_list = [None, "RAND"] + [i for i in range(1, 10, 2) - ] 
+ [i for i in range(20, 210, 40)] + vocab_size_list = [4096, 32768, 131072, 102400, 128256] + p_list = [None, "RAND", 0.4, 0.7, 0.9, 0.95, 0.99] + k_list = [None, "RAND", 5, 10, 50, 100, 200] + log_file = f"triton_topk_topp_test_{date_str}.log" csv_file = f"triton_topk_topp_test_{date_str}.csv" @@ -119,21 +132,18 @@ def test_time(logits, k, p, num_runs=256): if p == "RAND" and k == "RAND": continue - logits_rand = torch.rand(batch_size, vocab_size, device="cuda") - logits_randn = torch.randn(batch_size, vocab_size, device="cuda") - logits_list = [("RAND", logits_rand), ("RANDN", logits_randn)] + logits_randn = torch.randn(batch_size, vocab_size, device="cuda") * 10 + logits_list = [("RANDN", logits_randn)] if p == "RAND": - p_tensor = torch.rand((batch_size, ), device="cuda") * 0.95 + 0.05 + p_tensor = torch.rand((batch_size, ), device="cuda") * 0.9 + 0.05 elif p is not None: p_tensor = torch.full((batch_size, ), p, device="cuda") else: p_tensor = None if k == "RAND": - k_tensor = torch.randint(1, - vocab_size, (batch_size, ), - device="cuda") + k_tensor = torch.randint(1, 300, (batch_size, ), device="cuda") elif k is not None: k_tensor = torch.full((batch_size, ), k, device="cuda") else: diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 5434025789f9..6983cbc97960 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -235,36 +235,45 @@ def apply_top_k_top_p_triton( ) -> torch.Tensor: batch_size, vocab_size = logits.shape - BLOCK_SIZE = 8192 - NUM_PROGRAMS = 128 - NUM_TILES = (vocab_size + BLOCK_SIZE - 1) // BLOCK_SIZE - SIGMA = 2.5 + device_prop = torch.cuda.get_device_properties(logits.device) + NUM_PROGRAMS = device_prop.multi_processor_count + SIGMA = 2.15 # Top 0.03 outliers if k is not None and p is None: probs = torch.full((NUM_PROGRAMS, vocab_size), -float('inf'), device=logits.device) _topk_kernel[(NUM_PROGRAMS, )](logits, probs, k, batch_size, SIGMA, - 
vocab_size, BLOCK_SIZE, NUM_TILES) + vocab_size) elif k is None and p is not None: probs = torch.full_like(logits, -float('inf'), device=logits.device) probs_2 = torch.full_like(logits, -float('inf'), device=logits.device) _topp_kernel[(NUM_PROGRAMS, )](logits, probs, probs_2, p, batch_size, - SIGMA, vocab_size, BLOCK_SIZE, - NUM_TILES) + SIGMA, vocab_size) elif k is not None and p is not None: probs = torch.full_like(logits, -float('inf'), device=logits.device) _topk_topp_kernel[(NUM_PROGRAMS, )](logits, probs, k, p, batch_size, - SIGMA, vocab_size, BLOCK_SIZE, - NUM_TILES) + SIGMA, vocab_size) return logits +def triton_get_configs(): + return [ + triton.Config({'BLOCK_SIZE': B}, num_stages=s, num_warps=w) + for B in [4096, 8192, 16384] for s in [1, 2, 3, 4] for w in [4, 8, 16] + ] + + +@triton.autotune( + configs=triton_get_configs(), + key=['N'], +) @triton.jit def _topk_kernel(LOGITS, PROBS, K, B, SIGMA: tl.constexpr, N: tl.constexpr, - BLOCK_SIZE: tl.constexpr, NUM_TILES: tl.constexpr): + BLOCK_SIZE: tl.constexpr): pid = tl.program_id(0) num_programs = tl.num_programs(0) + NUM_TILES: tl.constexpr = (N + BLOCK_SIZE - 1) // BLOCK_SIZE for row_id in tl.range(pid, B, num_programs): k = tl.load(K + row_id) if k != N: # All tokens are valid @@ -385,10 +394,14 @@ def _topk_kernel(LOGITS, PROBS, K, B, SIGMA: tl.constexpr, N: tl.constexpr, tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) +@triton.autotune( + configs=triton_get_configs(), + key=['N'], +) @triton.jit def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, SIGMA: tl.constexpr, - N: tl.constexpr, BLOCK_SIZE: tl.constexpr, - NUM_TILES: tl.constexpr): + N: tl.constexpr, BLOCK_SIZE: tl.constexpr): + NUM_TILES: tl.constexpr = (N + BLOCK_SIZE - 1) // BLOCK_SIZE pid = tl.program_id(0) num_programs = tl.num_programs(0) for row_id in tl.range(pid, B, num_programs): @@ -576,10 +589,14 @@ def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, SIGMA: tl.constexpr, tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) 
+@triton.autotune( + configs=triton_get_configs(), + key=['N'], +) @triton.jit def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, - N: tl.constexpr, BLOCK_SIZE: tl.constexpr, - NUM_TILES: tl.constexpr): + N: tl.constexpr, BLOCK_SIZE: tl.constexpr): + NUM_TILES: tl.constexpr = (N + BLOCK_SIZE - 1) // BLOCK_SIZE pid = tl.program_id(0) num_programs = tl.num_programs(0) for row_id in tl.range(pid, B, num_programs): From 340b6b468e90fcd6132f98958b1ab6379535956c Mon Sep 17 00:00:00 2001 From: js_park Date: Sat, 27 Sep 2025 18:42:45 -0700 Subject: [PATCH 29/99] Reduce diff and do fallback when batch size small. Signed-off-by: js_park --- test.py | 36 ----- test_triton_topk_topp.py | 179 ------------------------ vllm/v1/sample/ops/topk_topp_sampler.py | 64 +++++---- 3 files changed, 38 insertions(+), 241 deletions(-) delete mode 100644 test.py delete mode 100644 test_triton_topk_topp.py diff --git a/test.py b/test.py deleted file mode 100644 index 464a933c347a..000000000000 --- a/test.py +++ /dev/null @@ -1,36 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os - -from vllm import LLM, SamplingParams - -# For V1: Turn off multiprocessing to make scheduling deterministic -os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" -os.environ["VLLM_USE_TRITON_SAMPLER"] = "1" - -# Set a fixed seed for reproducibility -SEED = 42 - -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] -prompts = prompts * 64 - -# Add seed to sampling parameters for deterministic sampling -sampling_params = SamplingParams(temperature=0.8, top_p=0.7, seed=SEED) - -# Add seed to LLM initialization for global reproducibility -llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", seed=SEED) -# llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite") - -outputs = llm.generate(prompts, sampling_params) - -for i, output in enumerate(outputs): - 
if i > 4: - break - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/test_triton_topk_topp.py b/test_triton_topk_topp.py deleted file mode 100644 index 491c8dd9b175..000000000000 --- a/test_triton_topk_topp.py +++ /dev/null @@ -1,179 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import time -from datetime import datetime -from itertools import product - -import regex as re -import torch - -from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p, - apply_top_k_top_p_triton) - - -def g_str(s): - return "\033[32m" + s + "\033[0m" - - -def r_str(s): - return "\033[31m" + s + "\033[0m" - - -def y_str(s): - return "\033[33m" + s + "\033[0m" - - -def b_str(s): - return "\033[34m" + s + "\033[0m" - - -def print_to_log(s, log_file): - print(s) - # Remove the color codes - s = re.sub(r"\033\[[0-9;]*m", "", s) - with open(log_file, "a") as f: - f.write(s + "\n") - - -def test_accuracy(logits, k, p): - input_logits = logits.clone() - original_logits = apply_top_k_top_p(input_logits, k, p) - logits = apply_top_k_top_p_triton(logits, k, p) - - torch.cuda.synchronize() - is_correct = torch.allclose(logits, original_logits) - - if not is_correct: - print_to_log(r_str("Error: logits are not close"), log_file) - error_mask = torch.abs(logits - original_logits) > 1e-5 - error_rows = torch.where(error_mask)[0] - error_rows = torch.unique(error_rows) - num_error_rows = error_rows.shape[0] - error_cols = torch.where(error_mask)[1] - error_cols = torch.unique(error_cols) - num_error_cols = error_cols.shape[0] - print_to_log(f"num_error_rows: {num_error_rows} - {error_rows}", - log_file) - print_to_log(f"num_error_cols: {num_error_cols}", log_file) - row_to_show = 5 if num_error_rows > 5 else num_error_rows - logits_to_show = torch.sort(logits[error_rows], descending=True).values - logits_to_show = 
logits_to_show[:row_to_show, :20] - print_to_log(f"logits: {logits_to_show}", log_file) - original_logits_to_show = \ - torch.sort(original_logits[error_rows], descending=True).values - original_logits_to_show = original_logits_to_show[:row_to_show, :20] - print_to_log(f"original_logits: {original_logits_to_show}", log_file) - - return is_correct - - -def test_time(logits, k, p, num_runs=30, num_warmup=5): - # We must clone the logits for each run to avoid modifying the original - warmup_tensor = logits.clone() - for _ in range(num_warmup): - apply_top_k_top_p(warmup_tensor, k, p) - apply_top_k_top_p_triton(warmup_tensor, k, p) - torch.cuda.synchronize() - - input_logits_torch = [logits.clone() for _ in range(num_runs)] - - torch.cuda.synchronize() - start_time = time.time() - for _ in range(num_runs): - input_logits_torch[_] = apply_top_k_top_p(input_logits_torch[_], k, p) - torch.cuda.synchronize() - torch_time_taken = (time.time() - start_time) / num_runs - - input_logits_triton = [logits.clone() for _ in range(num_runs)] - - torch.cuda.synchronize() - start_time = time.time() - for _ in range(num_runs): - input_logits_triton[_] = apply_top_k_top_p_triton( - input_logits_triton[_], k, p) - torch.cuda.synchronize() - triton_time_taken = (time.time() - start_time) / num_runs - - return torch_time_taken, triton_time_taken - - -if __name__ == "__main__": - date_str = datetime.now().strftime("%Y%m%d_%H%M%S") - # batch_size_list = [2**i for i in range(0, 11)] # 1 to 1024 - # vocab_size_list = [2**i for i in range(12, 18)] + [102400, 128256] - # p_list = [None, "RAND"] + [0.2 * i - # for i in range(1, 6)] + [0.9, 0.95, 0.99] - # k_list = [None, "RAND"] + [i for i in range(1, 11, 3) - # ] + [i for i in range(20, 230, 50)] - - batch_size_list = [2**i for i in range(0, 11)] # 1 to 1024 - vocab_size_list = [4096, 32768, 131072, 102400, 128256] - p_list = [None, "RAND", 0.4, 0.7, 0.9, 0.95, 0.99] - k_list = [None, "RAND", 5, 10, 50, 100, 200] - - log_file = 
f"triton_topk_topp_test_{date_str}.log" - csv_file = f"triton_topk_topp_test_{date_str}.csv" - - print_to_log(y_str("Testing TopKTopPSampler with Triton"), log_file) - print_to_log(y_str("batch_size_list:") + f"{batch_size_list}", log_file) - print_to_log(y_str("vocab_size_list:") + f"{vocab_size_list}", log_file) - print_to_log(y_str("p_list:") + f"{p_list}", log_file) - print_to_log(y_str("k_list:") + f"{k_list}", log_file) - print_to_log(y_str("log_file:") + f"{log_file}", log_file) - print_to_log(y_str("csv_file:") + f"{csv_file}", log_file) - - with open(csv_file, "w") as f: - f.write("dist_generator,batch_size,vocab_size,p,k,is_correct," - "torch_time_taken,triton_time_taken,speedup\n") - - for batch_size, vocab_size, p, k in product(batch_size_list, - vocab_size_list, p_list, - k_list): - if p == "RAND" and k == "RAND": - continue - - logits_randn = torch.randn(batch_size, vocab_size, device="cuda") * 10 - logits_list = [("RANDN", logits_randn)] - - if p == "RAND": - p_tensor = torch.rand((batch_size, ), device="cuda") * 0.9 + 0.05 - elif p is not None: - p_tensor = torch.full((batch_size, ), p, device="cuda") - else: - p_tensor = None - - if k == "RAND": - k_tensor = torch.randint(1, 300, (batch_size, ), device="cuda") - elif k is not None: - k_tensor = torch.full((batch_size, ), k, device="cuda") - else: - k_tensor = None - - for dist_generator, logits in logits_list: - print_to_log(y_str("--------------------------------"), log_file) - print_to_log( - g_str("Testing ") + f"{dist_generator}" + - y_str(" with batch_size: ") + f"{batch_size}" + - y_str(" vocab_size: ") + f"{vocab_size}" + y_str(" p: ") + - f"{p}" + y_str(" k: ") + f"{k}", log_file) - is_correct = test_accuracy(logits, k_tensor, p_tensor) - if not is_correct: - print_to_log( - f"Error: logits are not close for batch_size: {batch_size}," - f" vocab_size: {vocab_size}, dist_generator: " - f"{dist_generator}, p: {p}, k: {k}", log_file) - torch_time_taken, triton_time_taken = test_time( - logits, 
k_tensor, p_tensor) - print_to_log( - b_str("torch_time_taken: ") + f"{torch_time_taken}", log_file) - print_to_log( - b_str("triton_time_taken: ") + f"{triton_time_taken}", - log_file) - print_to_log( - g_str("Triton Speedup over Torch: ") + - f"{torch_time_taken / triton_time_taken:.8f}x", log_file) - with open(csv_file, "a") as f: - f.write(f"{dist_generator},{batch_size},{vocab_size},{p},{k}," - f"{is_correct},{torch_time_taken},{triton_time_taken}," - f"{torch_time_taken / triton_time_taken:.8f}\n") - print_to_log(y_str("--------------------------------\n"), log_file) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 6983cbc97960..cdfae8f1319e 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -116,6 +116,12 @@ def forward_triton( k: Optional[torch.Tensor], p: Optional[torch.Tensor], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + + if logits.shape[0] < 32: + # Curreny Triton implementation is not optimized for small batch + # sizes, as it launches a single program for a single batch. + return self.forward_native(logits, generators, k, p) + logits = self.apply_top_k_top_p_triton(logits, k, p) logits_to_return = None if self.logprobs_mode == "processed_logits": @@ -138,9 +144,9 @@ def forward_cuda( # CPU-GPU synchronization while `flashinfer_sample` does. if (k is None and p is None) or generators: if generators: - logger.debug_once("FlashInfer 0.2.3+ does not support " - "per-request generators. Falling back to " - "PyTorch-native implementation.") + logger.warning_once("FlashInfer 0.2.3+ does not support " + "per-request generators. Falling back to " + "PyTorch-native implementation.") return self.forward_native(logits, generators, k, p) assert self.logprobs_mode not in ( "processed_logits", "processed_logprobs" @@ -196,35 +202,40 @@ def apply_top_k_top_p( p: Optional[torch.Tensor], ) -> torch.Tensor: """Apply top-k and top-p masks to the logits. 
- """ - logits_sort, logits_idx = logits.sort(dim=-1, descending=False) + If a top-p is used, this function will sort the logits tensor, + which can be slow for large batches. + + The logits tensor may be updated in-place. + """ if p is None: if k is None: return logits # Avoid sorting vocab for top-k only case. - logits = apply_top_k_only(logits, k) - else: - if k is not None: - # Apply top-k. - top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B - # Get all the top_k values. - top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) - top_k_mask = logits_sort < top_k_mask - logits_sort.masked_fill_(top_k_mask, -float("inf")) - - if p is not None: - # Apply top-p. - probs_sort = logits_sort.softmax(dim=-1) - probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort) - top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) - # at least one - top_p_mask[:12, -1] = False - logits_sort.masked_fill_(top_p_mask, -float("inf")) - - # Re-sort the probabilities. - logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) + return apply_top_k_only(logits, k) + + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) + + if k is not None: + # Apply top-k. + top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B + # Get all the top_k values. + top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) + top_k_mask = logits_sort < top_k_mask + logits_sort.masked_fill_(top_k_mask, -float("inf")) + + if p is not None: + # Apply top-p. + probs_sort = logits_sort.softmax(dim=-1) + probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort) + top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) + # at least one + top_p_mask[:, -1] = False + logits_sort.masked_fill_(top_p_mask, -float("inf")) + + # Re-sort the probabilities. 
+ logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) return logits @@ -235,6 +246,7 @@ def apply_top_k_top_p_triton( ) -> torch.Tensor: batch_size, vocab_size = logits.shape + device_prop = torch.cuda.get_device_properties(logits.device) NUM_PROGRAMS = device_prop.multi_processor_count SIGMA = 2.15 # Top 0.03 outliers From cf768c22894392618ef5a3f3c1748d113f9b837a Mon Sep 17 00:00:00 2001 From: js_park Date: Sat, 27 Sep 2025 22:19:56 -0700 Subject: [PATCH 30/99] Test script fix Signed-off-by: js_park --- test_triton_topk_topp.py | 182 ++++++++++++++++++++++++ vllm/v1/sample/ops/topk_topp_sampler.py | 20 +-- 2 files changed, 192 insertions(+), 10 deletions(-) create mode 100644 test_triton_topk_topp.py diff --git a/test_triton_topk_topp.py b/test_triton_topk_topp.py new file mode 100644 index 000000000000..f425f87834ed --- /dev/null +++ b/test_triton_topk_topp.py @@ -0,0 +1,182 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from datetime import datetime +from itertools import product + +import regex as re +import torch + +from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p, + apply_top_k_top_p_triton) + + +def g_str(s): + return "\033[32m" + s + "\033[0m" + + +def r_str(s): + return "\033[31m" + s + "\033[0m" + + +def y_str(s): + return "\033[33m" + s + "\033[0m" + + +def b_str(s): + return "\033[34m" + s + "\033[0m" + + +def print_to_log(s, log_file): + print(s) + # Remove the color codes + s = re.sub(r"\033\[[0-9;]*m", "", s) + with open(log_file, "a") as f: + f.write(s + "\n") + + +def test_accuracy(logits, k, p): + input_logits_torch = logits.clone().detach() + input_logits_triton = logits.clone().detach() + original_logits = apply_top_k_top_p(input_logits_torch, k, p) + triton_logits = apply_top_k_top_p_triton(input_logits_triton, k, p) + + torch.cuda.synchronize() + is_correct = torch.allclose(triton_logits, original_logits) + + if not is_correct: + 
print_to_log(r_str("Error: logits are not close"), log_file) + error_mask = torch.abs(triton_logits - original_logits) > 1e-5 + error_rows = torch.where(error_mask)[0] + error_rows = torch.unique(error_rows) + num_error_rows = error_rows.shape[0] + error_cols = torch.where(error_mask)[1] + error_cols = torch.unique(error_cols) + num_error_cols = error_cols.shape[0] + print_to_log(f"num_error_rows: {num_error_rows} - {error_rows}", + log_file) + print_to_log(f"num_error_cols: {num_error_cols}", log_file) + row_to_show = 5 if num_error_rows > 5 else num_error_rows + logits_to_show = torch.sort(triton_logits[error_rows], + descending=True).values + logits_to_show = logits_to_show[:row_to_show, :20] + print_to_log(f"logits: {logits_to_show}", log_file) + original_logits_to_show = \ + torch.sort(original_logits[error_rows], descending=True).values + original_logits_to_show = original_logits_to_show[:row_to_show, :20] + print_to_log(f"original_logits: {original_logits_to_show}", log_file) + + return is_correct + + +def test_time(logits, k, p, num_runs=30, num_warmup=5): + # We must clone the logits for each run to avoid modifying the original + warmup_tensor = logits.clone().detach() + for _ in range(num_warmup): + apply_top_k_top_p(warmup_tensor, k, p) + apply_top_k_top_p_triton(warmup_tensor, k, p) + torch.cuda.synchronize() + + input_logits_torch = [logits.clone().detach() for _ in range(num_runs)] + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start.record() + for _ in range(num_runs): + input_logits_torch[_] = apply_top_k_top_p(input_logits_torch[_], k, p) + end.record() + torch.cuda.synchronize() + torch_time_taken = start.elapsed_time(end) / num_runs + + input_logits_triton = [logits.clone().detach() for _ in range(num_runs)] + + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(num_runs): 
+ input_logits_triton[_] = apply_top_k_top_p_triton( + input_logits_triton[_], k, p) + end.record() + torch.cuda.synchronize() + triton_time_taken = start.elapsed_time(end) / num_runs + + return torch_time_taken, triton_time_taken + + +if __name__ == "__main__": + date_str = datetime.now().strftime("%Y%m%d_%H%M%S") + + batch_size_list = [2**i for i in range(0, 11)] # 1 to 1024 + vocab_size_list = [4096, 16384, 65536, 262144, 102400, 128256] + p_list = [None, "RAND", 0.4, 0.7, 0.9, 0.95, 0.99] + k_list = [None, "RAND", 5, 10, 50, 100, 200, 300, 3000] + + log_file = f"triton_topk_topp_test_{date_str}.log" + csv_file = f"triton_topk_topp_test_{date_str}.csv" + + print_to_log(y_str("Testing TopKTopPSampler with Triton"), log_file) + print_to_log(y_str("batch_size_list:") + f"{batch_size_list}", log_file) + print_to_log(y_str("vocab_size_list:") + f"{vocab_size_list}", log_file) + print_to_log(y_str("p_list:") + f"{p_list}", log_file) + print_to_log(y_str("k_list:") + f"{k_list}", log_file) + print_to_log(y_str("log_file:") + f"{log_file}", log_file) + print_to_log(y_str("csv_file:") + f"{csv_file}", log_file) + + with open(csv_file, "w") as f: + f.write("dist_generator,batch_size,vocab_size,p,k,is_correct," + "torch_time_taken,triton_time_taken,speedup\n") + + for batch_size, vocab_size, p, k in product(batch_size_list, + vocab_size_list, p_list, + k_list): + if p is None and k is None: + continue + + logits_randn = torch.randn(batch_size, vocab_size, device="cuda") * 10 + logits_list = [("RANDN", logits_randn)] + + if p == "RAND": + p_tensor = torch.rand((batch_size, ), device="cuda") * 0.95 + 0.05 + elif p is not None: + p_tensor = torch.full((batch_size, ), p, device="cuda") + else: + p_tensor = None + + if k == "RAND": + k_tensor = torch.randint(1, + vocab_size, (batch_size, ), + device="cuda") + elif k is not None: + k_tensor = torch.full((batch_size, ), k, device="cuda") + else: + k_tensor = None + + for dist_generator, logits in logits_list: + 
print_to_log(y_str("--------------------------------"), log_file) + print_to_log( + g_str("Testing ") + f"{dist_generator}" + + y_str(" with batch_size: ") + f"{batch_size}" + + y_str(" vocab_size: ") + f"{vocab_size}" + y_str(" p: ") + + f"{p}" + y_str(" k: ") + f"{k}", log_file) + is_correct = test_accuracy(logits, k_tensor, p_tensor) + if not is_correct: + print_to_log( + f"Error: logits are not close for batch_size: {batch_size}," + f" vocab_size: {vocab_size}, dist_generator: " + f"{dist_generator}, p: {p}, k: {k}", log_file) + torch_time_taken, triton_time_taken = test_time( + logits, k_tensor, p_tensor) + print_to_log( + b_str("torch_time_taken: ") + f"{torch_time_taken}", log_file) + print_to_log( + b_str("triton_time_taken: ") + f"{triton_time_taken}", + log_file) + print_to_log( + g_str("Triton Speedup over Torch: ") + + f"{torch_time_taken / triton_time_taken:.8f}x", log_file) + with open(csv_file, "a") as f: + f.write(f"{dist_generator},{batch_size},{vocab_size},{p},{k}," + f"{is_correct},{torch_time_taken},{triton_time_taken}," + f"{torch_time_taken / triton_time_taken:.8f}\n") + print_to_log(y_str("--------------------------------\n"), log_file) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index cdfae8f1319e..81b313a19d10 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -662,6 +662,16 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, ############### START OF TOP-K CODE ############### k = tl.load(K + row_id) + max_range = max_logit + min_range = min_logit + if num_outliers > k: + max_range = max_logit + min_range = outlier_pivot + search_addr = PROBS_ROW + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) + if k != N: # All tokens are valid # THERE IS NO DUPLICATE LOGIT MANAGEMENT FOR THIS TOP-K @@ -672,16 +682,6 @@ def _topk_topp_kernel(LOGITS, 
PROBS, K, P, B, SIGMA: tl.constexpr, # IMPLEMENTATION AND IMPLEMENT THE DUPLICATE LOGIT MANAGEMENT # USING THE FORCE_REMOVE_LOGIT VARIABLE. - max_range = max_logit - min_range = min_logit - if num_outliers > k: - max_range = max_logit - min_range = outlier_pivot - search_addr = PROBS_ROW - search_range = tl.cast(num_outliers, tl.int32) - search_iters = tl.cast( - (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) - # Second passes: Quaternary search for pivots (nlog_4(n)) num_iters = 0 while k_pivot == -float('inf') and num_iters < 18: From 9b3cf75aacde92b8978aad3752c6a54ff0b60e4a Mon Sep 17 00:00:00 2001 From: js_park Date: Sat, 27 Sep 2025 23:06:51 -0700 Subject: [PATCH 31/99] Added graph generation Signed-off-by: js_park --- graph.py | 272 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 272 insertions(+) create mode 100644 graph.py diff --git a/graph.py b/graph.py new file mode 100644 index 000000000000..a247b9599666 --- /dev/null +++ b/graph.py @@ -0,0 +1,272 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import matplotlib.pyplot as plt +import pandas as pd + +input_path = './triton_topk_topp_test_20250927_222024.csv' + + +def load_and_parse_data(csv_file): + """Load CSV data and parse it into a structured format.""" + df = pd.read_csv(csv_file, dtype={'p': str, 'k': str}) + print(df.head()) + print(df.columns) + print(df.info()) + print(df.describe()) + print(df.isnull().sum()) + print(df.duplicated().sum()) + print(df.shape) + print(df.head()) + return df + + +def get_filtered_data(df, vocab_size, p_val, k_val): + """Filter data for specific vocab_size, p, and k values.""" + # Handle None values properly + if p_val is None: + p_condition = df['p'] == "NO" + else: + p_condition = df['p'] == str(p_val) + + if k_val is None: + k_condition = df['k'] == "NO" + else: + k_condition = df['k'] == str(k_val) + + filtered_df = df[(df['vocab_size'] == vocab_size) & p_condition 
+ & k_condition].copy() + + return filtered_df.sort_values('batch_size') + + +def create_speedup_plots(): + """Create 4x4 grid of speedup vs batch size plots.""" + # Load data + csv_file = input_path + df = load_and_parse_data(csv_file) + + # Define the parameter combinations for columns + column_configs = [{ + 'p': None, + 'k': 200, + 'title': 'P=None, K=200' + }, { + 'p': 0.9, + 'k': None, + 'title': 'P=0.9, K=None' + }, { + 'p': 0.9, + 'k': 200, + 'title': 'P=0.9, K=200' + }, { + 'p': 'RAND', + 'k': 'RAND', + 'title': 'P=RAND, K=RAND' + }] + + # Define vocab sizes for rows + vocab_sizes = [4096, 65536, 102400, 128256] + + # We'll calculate y-axis limits per subplot now + + # Create figure with subplots + fig, axes = plt.subplots(4, 4, figsize=(20, 16)) + fig.suptitle('Speedup vs Batch Size', fontsize=20, fontweight='bold') + + # Plot each combination + for row, vocab_size in enumerate(vocab_sizes): + for col, config in enumerate(column_configs): + ax = axes[row, col] + + # Get filtered data + data = get_filtered_data(df, vocab_size, config['p'], config['k']) + + if not data.empty: + # Calculate y-axis limit for this specific subplot + local_max_speedup = data['speedup'].max() + local_y_max = local_max_speedup * 1.1 if local_max_speedup > 0 else 10.0 + + # Plot speedup vs batch size + ax.plot(data['batch_size'], + data['speedup'], + 'bo-', + linewidth=2, + markersize=6) + ax.set_xscale('log', base=2) + ax.set_ylim( + 0.0, local_y_max) # Set y-axis range from 0 to local max + ax.grid(True, alpha=0.3) + + # Add horizontal line at speedup=1 + ax.axhline(y=1, + color='red', + linestyle='--', + linewidth=2, + alpha=0.7, + label='Speedup=1') + + # Set labels and title + if row == 3: # Bottom row + ax.set_xlabel('Batch Size', fontsize=12) + if col == 0: # Left column + ax.set_ylabel('Speedup', fontsize=12) + + # Set title for top row + if row == 0: + ax.set_title(config['title'], + fontsize=14, + fontweight='bold') + + # Add vocab size label on the left + if col == 0: + 
vocab_size_str = f'Vocab Size {vocab_size}' + + ax.text(-0.2, + 0.5, + vocab_size_str, + transform=ax.transAxes, + fontsize=14, + fontweight='bold', + ha='center', + va='center', + rotation=90) + + # Set reasonable axis limits + if len(data) > 0: + batch_sizes = data['batch_size'].values + + ax.set_xlim(batch_sizes.min() * 0.8, + batch_sizes.max() * 1.2) + # Y-axis is already set to 0-10 above + + # Format x-axis ticks + ax.set_xticks( + [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]) + ax.set_xticklabels( + [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, ""]) + + # Add legend only to the first subplot + if row == 0 and col == 0: + ax.legend(loc='upper left') + + else: + # No data available - use default y-axis range + default_y_max = 10.0 + ax.text(0.5, + 0.5, + 'No Data\nAvailable', + transform=ax.transAxes, + fontsize=12, + ha='center', + va='center', + color='red') + ax.set_xlim(1, 2048) + ax.set_ylim( + 0.0, default_y_max) # Set y-axis range from 0 to default + ax.set_xscale('log', base=2) + ax.grid(True, alpha=0.3) + + # Add horizontal line at speedup=1 + ax.axhline(y=1, + color='red', + linestyle='--', + linewidth=2, + alpha=0.7) + + if row == 3: # Bottom row + ax.set_xlabel('Batch Size', fontsize=12) + if col == 0: # Left column + ax.set_ylabel('Speedup', fontsize=12) + if row == 0: + ax.set_title(config['title'], + fontsize=12, + fontweight='bold') + if col == 0: + ax.text(-0.2, + 0.5, + f'Vocab Size {vocab_size}', + transform=ax.transAxes, + fontsize=14, + fontweight='bold', + ha='center', + va='center', + rotation=90) + + # Adjust layout + plt.tight_layout() + plt.subplots_adjust(top=0.93, left=0.08) + + # Save the plot + output_file = './speedup_analysis.png' + plt.savefig(output_file, dpi=300, bbox_inches='tight') + print(f"Speedup analysis plot saved to: {output_file}") + + # Show the plot + plt.show() + + return fig + + +def print_data_summary(): + """Print a summary of the available data.""" + csv_file = input_path + df = 
load_and_parse_data(csv_file) + + print("Data Summary:") + print(f"Total rows: {len(df)}") + print(f"Unique batch sizes: {sorted(df['batch_size'].unique())}") + print(f"Unique vocab sizes: {sorted(df['vocab_size'].unique())}") + print( + f"Unique p values: {sorted([p for p in df['p'].unique() if p != 'nan'])}" + ) + print( + f"Unique k values: {sorted([k for k in df['k'].unique() if k != 'nan'])}" + ) + print() + + # Check data availability for each configuration + column_configs = [{ + 'p': None, + 'k': 200, + 'title': 'P=None, K=200' + }, { + 'p': 0.9, + 'k': None, + 'title': 'P=0.9, K=None' + }, { + 'p': None, + 'k': 200, + 'title': 'P=0.9, K=200' + }, { + 'p': "RAND", + 'k': "RAND", + 'title': 'P=RAND, K=RAND' + }] + vocab_sizes = [4096, 32768, 102400, 128256] + + print("Data availability matrix:") + print("Rows: Vocab sizes, Columns: Parameter combinations") + print("Values: Number of data points available") + print() + + header = f"{'Vocab Size':<12}" + for config in column_configs: + header += f"{config['title']:<15}" + print(header) + print("-" * len(header)) + + for vocab_size in vocab_sizes: + row = f"{vocab_size:<12}" + for config in column_configs: + data = get_filtered_data(df, vocab_size, config['p'], config['k']) + row += f"{len(data):<15}" + print(row) + + +if __name__ == "__main__": + # Print data summary first + print_data_summary() + print("\n" + "=" * 80 + "\n") + + # Create the plots + create_speedup_plots() From 4235295a2130cd703fa1023f401acdd01057ec50 Mon Sep 17 00:00:00 2001 From: js_park Date: Sat, 27 Sep 2025 23:07:37 -0700 Subject: [PATCH 32/99] Removed fallback Signed-off-by: js_park --- graph.py | 272 ------------------------ test_triton_topk_topp.py | 182 ---------------- vllm/v1/sample/ops/topk_topp_sampler.py | 5 - 3 files changed, 459 deletions(-) delete mode 100644 graph.py delete mode 100644 test_triton_topk_topp.py diff --git a/graph.py b/graph.py deleted file mode 100644 index a247b9599666..000000000000 --- a/graph.py +++ 
/dev/null @@ -1,272 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import matplotlib.pyplot as plt -import pandas as pd - -input_path = './triton_topk_topp_test_20250927_222024.csv' - - -def load_and_parse_data(csv_file): - """Load CSV data and parse it into a structured format.""" - df = pd.read_csv(csv_file, dtype={'p': str, 'k': str}) - print(df.head()) - print(df.columns) - print(df.info()) - print(df.describe()) - print(df.isnull().sum()) - print(df.duplicated().sum()) - print(df.shape) - print(df.head()) - return df - - -def get_filtered_data(df, vocab_size, p_val, k_val): - """Filter data for specific vocab_size, p, and k values.""" - # Handle None values properly - if p_val is None: - p_condition = df['p'] == "NO" - else: - p_condition = df['p'] == str(p_val) - - if k_val is None: - k_condition = df['k'] == "NO" - else: - k_condition = df['k'] == str(k_val) - - filtered_df = df[(df['vocab_size'] == vocab_size) & p_condition - & k_condition].copy() - - return filtered_df.sort_values('batch_size') - - -def create_speedup_plots(): - """Create 4x4 grid of speedup vs batch size plots.""" - # Load data - csv_file = input_path - df = load_and_parse_data(csv_file) - - # Define the parameter combinations for columns - column_configs = [{ - 'p': None, - 'k': 200, - 'title': 'P=None, K=200' - }, { - 'p': 0.9, - 'k': None, - 'title': 'P=0.9, K=None' - }, { - 'p': 0.9, - 'k': 200, - 'title': 'P=0.9, K=200' - }, { - 'p': 'RAND', - 'k': 'RAND', - 'title': 'P=RAND, K=RAND' - }] - - # Define vocab sizes for rows - vocab_sizes = [4096, 65536, 102400, 128256] - - # We'll calculate y-axis limits per subplot now - - # Create figure with subplots - fig, axes = plt.subplots(4, 4, figsize=(20, 16)) - fig.suptitle('Speedup vs Batch Size', fontsize=20, fontweight='bold') - - # Plot each combination - for row, vocab_size in enumerate(vocab_sizes): - for col, config in enumerate(column_configs): - ax = axes[row, 
col] - - # Get filtered data - data = get_filtered_data(df, vocab_size, config['p'], config['k']) - - if not data.empty: - # Calculate y-axis limit for this specific subplot - local_max_speedup = data['speedup'].max() - local_y_max = local_max_speedup * 1.1 if local_max_speedup > 0 else 10.0 - - # Plot speedup vs batch size - ax.plot(data['batch_size'], - data['speedup'], - 'bo-', - linewidth=2, - markersize=6) - ax.set_xscale('log', base=2) - ax.set_ylim( - 0.0, local_y_max) # Set y-axis range from 0 to local max - ax.grid(True, alpha=0.3) - - # Add horizontal line at speedup=1 - ax.axhline(y=1, - color='red', - linestyle='--', - linewidth=2, - alpha=0.7, - label='Speedup=1') - - # Set labels and title - if row == 3: # Bottom row - ax.set_xlabel('Batch Size', fontsize=12) - if col == 0: # Left column - ax.set_ylabel('Speedup', fontsize=12) - - # Set title for top row - if row == 0: - ax.set_title(config['title'], - fontsize=14, - fontweight='bold') - - # Add vocab size label on the left - if col == 0: - vocab_size_str = f'Vocab Size {vocab_size}' - - ax.text(-0.2, - 0.5, - vocab_size_str, - transform=ax.transAxes, - fontsize=14, - fontweight='bold', - ha='center', - va='center', - rotation=90) - - # Set reasonable axis limits - if len(data) > 0: - batch_sizes = data['batch_size'].values - - ax.set_xlim(batch_sizes.min() * 0.8, - batch_sizes.max() * 1.2) - # Y-axis is already set to 0-10 above - - # Format x-axis ticks - ax.set_xticks( - [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]) - ax.set_xticklabels( - [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, ""]) - - # Add legend only to the first subplot - if row == 0 and col == 0: - ax.legend(loc='upper left') - - else: - # No data available - use default y-axis range - default_y_max = 10.0 - ax.text(0.5, - 0.5, - 'No Data\nAvailable', - transform=ax.transAxes, - fontsize=12, - ha='center', - va='center', - color='red') - ax.set_xlim(1, 2048) - ax.set_ylim( - 0.0, default_y_max) # Set y-axis range from 0 to 
default - ax.set_xscale('log', base=2) - ax.grid(True, alpha=0.3) - - # Add horizontal line at speedup=1 - ax.axhline(y=1, - color='red', - linestyle='--', - linewidth=2, - alpha=0.7) - - if row == 3: # Bottom row - ax.set_xlabel('Batch Size', fontsize=12) - if col == 0: # Left column - ax.set_ylabel('Speedup', fontsize=12) - if row == 0: - ax.set_title(config['title'], - fontsize=12, - fontweight='bold') - if col == 0: - ax.text(-0.2, - 0.5, - f'Vocab Size {vocab_size}', - transform=ax.transAxes, - fontsize=14, - fontweight='bold', - ha='center', - va='center', - rotation=90) - - # Adjust layout - plt.tight_layout() - plt.subplots_adjust(top=0.93, left=0.08) - - # Save the plot - output_file = './speedup_analysis.png' - plt.savefig(output_file, dpi=300, bbox_inches='tight') - print(f"Speedup analysis plot saved to: {output_file}") - - # Show the plot - plt.show() - - return fig - - -def print_data_summary(): - """Print a summary of the available data.""" - csv_file = input_path - df = load_and_parse_data(csv_file) - - print("Data Summary:") - print(f"Total rows: {len(df)}") - print(f"Unique batch sizes: {sorted(df['batch_size'].unique())}") - print(f"Unique vocab sizes: {sorted(df['vocab_size'].unique())}") - print( - f"Unique p values: {sorted([p for p in df['p'].unique() if p != 'nan'])}" - ) - print( - f"Unique k values: {sorted([k for k in df['k'].unique() if k != 'nan'])}" - ) - print() - - # Check data availability for each configuration - column_configs = [{ - 'p': None, - 'k': 200, - 'title': 'P=None, K=200' - }, { - 'p': 0.9, - 'k': None, - 'title': 'P=0.9, K=None' - }, { - 'p': None, - 'k': 200, - 'title': 'P=0.9, K=200' - }, { - 'p': "RAND", - 'k': "RAND", - 'title': 'P=RAND, K=RAND' - }] - vocab_sizes = [4096, 32768, 102400, 128256] - - print("Data availability matrix:") - print("Rows: Vocab sizes, Columns: Parameter combinations") - print("Values: Number of data points available") - print() - - header = f"{'Vocab Size':<12}" - for config in 
column_configs: - header += f"{config['title']:<15}" - print(header) - print("-" * len(header)) - - for vocab_size in vocab_sizes: - row = f"{vocab_size:<12}" - for config in column_configs: - data = get_filtered_data(df, vocab_size, config['p'], config['k']) - row += f"{len(data):<15}" - print(row) - - -if __name__ == "__main__": - # Print data summary first - print_data_summary() - print("\n" + "=" * 80 + "\n") - - # Create the plots - create_speedup_plots() diff --git a/test_triton_topk_topp.py b/test_triton_topk_topp.py deleted file mode 100644 index f425f87834ed..000000000000 --- a/test_triton_topk_topp.py +++ /dev/null @@ -1,182 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from datetime import datetime -from itertools import product - -import regex as re -import torch - -from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p, - apply_top_k_top_p_triton) - - -def g_str(s): - return "\033[32m" + s + "\033[0m" - - -def r_str(s): - return "\033[31m" + s + "\033[0m" - - -def y_str(s): - return "\033[33m" + s + "\033[0m" - - -def b_str(s): - return "\033[34m" + s + "\033[0m" - - -def print_to_log(s, log_file): - print(s) - # Remove the color codes - s = re.sub(r"\033\[[0-9;]*m", "", s) - with open(log_file, "a") as f: - f.write(s + "\n") - - -def test_accuracy(logits, k, p): - input_logits_torch = logits.clone().detach() - input_logits_triton = logits.clone().detach() - original_logits = apply_top_k_top_p(input_logits_torch, k, p) - triton_logits = apply_top_k_top_p_triton(input_logits_triton, k, p) - - torch.cuda.synchronize() - is_correct = torch.allclose(triton_logits, original_logits) - - if not is_correct: - print_to_log(r_str("Error: logits are not close"), log_file) - error_mask = torch.abs(triton_logits - original_logits) > 1e-5 - error_rows = torch.where(error_mask)[0] - error_rows = torch.unique(error_rows) - num_error_rows = error_rows.shape[0] - error_cols = 
torch.where(error_mask)[1] - error_cols = torch.unique(error_cols) - num_error_cols = error_cols.shape[0] - print_to_log(f"num_error_rows: {num_error_rows} - {error_rows}", - log_file) - print_to_log(f"num_error_cols: {num_error_cols}", log_file) - row_to_show = 5 if num_error_rows > 5 else num_error_rows - logits_to_show = torch.sort(triton_logits[error_rows], - descending=True).values - logits_to_show = logits_to_show[:row_to_show, :20] - print_to_log(f"logits: {logits_to_show}", log_file) - original_logits_to_show = \ - torch.sort(original_logits[error_rows], descending=True).values - original_logits_to_show = original_logits_to_show[:row_to_show, :20] - print_to_log(f"original_logits: {original_logits_to_show}", log_file) - - return is_correct - - -def test_time(logits, k, p, num_runs=30, num_warmup=5): - # We must clone the logits for each run to avoid modifying the original - warmup_tensor = logits.clone().detach() - for _ in range(num_warmup): - apply_top_k_top_p(warmup_tensor, k, p) - apply_top_k_top_p_triton(warmup_tensor, k, p) - torch.cuda.synchronize() - - input_logits_torch = [logits.clone().detach() for _ in range(num_runs)] - - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() - start.record() - for _ in range(num_runs): - input_logits_torch[_] = apply_top_k_top_p(input_logits_torch[_], k, p) - end.record() - torch.cuda.synchronize() - torch_time_taken = start.elapsed_time(end) / num_runs - - input_logits_triton = [logits.clone().detach() for _ in range(num_runs)] - - torch.cuda.synchronize() - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - for _ in range(num_runs): - input_logits_triton[_] = apply_top_k_top_p_triton( - input_logits_triton[_], k, p) - end.record() - torch.cuda.synchronize() - triton_time_taken = start.elapsed_time(end) / num_runs - - return torch_time_taken, triton_time_taken - - -if __name__ == "__main__": 
- date_str = datetime.now().strftime("%Y%m%d_%H%M%S") - - batch_size_list = [2**i for i in range(0, 11)] # 1 to 1024 - vocab_size_list = [4096, 16384, 65536, 262144, 102400, 128256] - p_list = [None, "RAND", 0.4, 0.7, 0.9, 0.95, 0.99] - k_list = [None, "RAND", 5, 10, 50, 100, 200, 300, 3000] - - log_file = f"triton_topk_topp_test_{date_str}.log" - csv_file = f"triton_topk_topp_test_{date_str}.csv" - - print_to_log(y_str("Testing TopKTopPSampler with Triton"), log_file) - print_to_log(y_str("batch_size_list:") + f"{batch_size_list}", log_file) - print_to_log(y_str("vocab_size_list:") + f"{vocab_size_list}", log_file) - print_to_log(y_str("p_list:") + f"{p_list}", log_file) - print_to_log(y_str("k_list:") + f"{k_list}", log_file) - print_to_log(y_str("log_file:") + f"{log_file}", log_file) - print_to_log(y_str("csv_file:") + f"{csv_file}", log_file) - - with open(csv_file, "w") as f: - f.write("dist_generator,batch_size,vocab_size,p,k,is_correct," - "torch_time_taken,triton_time_taken,speedup\n") - - for batch_size, vocab_size, p, k in product(batch_size_list, - vocab_size_list, p_list, - k_list): - if p is None and k is None: - continue - - logits_randn = torch.randn(batch_size, vocab_size, device="cuda") * 10 - logits_list = [("RANDN", logits_randn)] - - if p == "RAND": - p_tensor = torch.rand((batch_size, ), device="cuda") * 0.95 + 0.05 - elif p is not None: - p_tensor = torch.full((batch_size, ), p, device="cuda") - else: - p_tensor = None - - if k == "RAND": - k_tensor = torch.randint(1, - vocab_size, (batch_size, ), - device="cuda") - elif k is not None: - k_tensor = torch.full((batch_size, ), k, device="cuda") - else: - k_tensor = None - - for dist_generator, logits in logits_list: - print_to_log(y_str("--------------------------------"), log_file) - print_to_log( - g_str("Testing ") + f"{dist_generator}" + - y_str(" with batch_size: ") + f"{batch_size}" + - y_str(" vocab_size: ") + f"{vocab_size}" + y_str(" p: ") + - f"{p}" + y_str(" k: ") + f"{k}", log_file) 
- is_correct = test_accuracy(logits, k_tensor, p_tensor) - if not is_correct: - print_to_log( - f"Error: logits are not close for batch_size: {batch_size}," - f" vocab_size: {vocab_size}, dist_generator: " - f"{dist_generator}, p: {p}, k: {k}", log_file) - torch_time_taken, triton_time_taken = test_time( - logits, k_tensor, p_tensor) - print_to_log( - b_str("torch_time_taken: ") + f"{torch_time_taken}", log_file) - print_to_log( - b_str("triton_time_taken: ") + f"{triton_time_taken}", - log_file) - print_to_log( - g_str("Triton Speedup over Torch: ") + - f"{torch_time_taken / triton_time_taken:.8f}x", log_file) - with open(csv_file, "a") as f: - f.write(f"{dist_generator},{batch_size},{vocab_size},{p},{k}," - f"{is_correct},{torch_time_taken},{triton_time_taken}," - f"{torch_time_taken / triton_time_taken:.8f}\n") - print_to_log(y_str("--------------------------------\n"), log_file) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 81b313a19d10..a25d88bcd009 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -117,11 +117,6 @@ def forward_triton( p: Optional[torch.Tensor], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - if logits.shape[0] < 32: - # Curreny Triton implementation is not optimized for small batch - # sizes, as it launches a single program for a single batch. - return self.forward_native(logits, generators, k, p) - logits = self.apply_top_k_top_p_triton(logits, k, p) logits_to_return = None if self.logprobs_mode == "processed_logits": From 344c3e4a84e8fbb0e8968df11784f4bd7caf2f93 Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 28 Sep 2025 00:19:58 -0700 Subject: [PATCH 33/99] Added Gemini's suggestions, removed triton autotune. 
Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 47 +++++++++---------------- 1 file changed, 17 insertions(+), 30 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index a25d88bcd009..1f372f61cc0d 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -244,6 +244,7 @@ def apply_top_k_top_p_triton( device_prop = torch.cuda.get_device_properties(logits.device) NUM_PROGRAMS = device_prop.multi_processor_count + BLOCK_SIZE = 4096 SIGMA = 2.15 # Top 0.03 outliers if k is not None and p is None: @@ -251,30 +252,19 @@ def apply_top_k_top_p_triton( -float('inf'), device=logits.device) _topk_kernel[(NUM_PROGRAMS, )](logits, probs, k, batch_size, SIGMA, - vocab_size) + vocab_size, BLOCK_SIZE) elif k is None and p is not None: probs = torch.full_like(logits, -float('inf'), device=logits.device) probs_2 = torch.full_like(logits, -float('inf'), device=logits.device) _topp_kernel[(NUM_PROGRAMS, )](logits, probs, probs_2, p, batch_size, - SIGMA, vocab_size) + SIGMA, vocab_size, BLOCK_SIZE) elif k is not None and p is not None: probs = torch.full_like(logits, -float('inf'), device=logits.device) _topk_topp_kernel[(NUM_PROGRAMS, )](logits, probs, k, p, batch_size, - SIGMA, vocab_size) + SIGMA, vocab_size, BLOCK_SIZE) return logits -def triton_get_configs(): - return [ - triton.Config({'BLOCK_SIZE': B}, num_stages=s, num_warps=w) - for B in [4096, 8192, 16384] for s in [1, 2, 3, 4] for w in [4, 8, 16] - ] - - -@triton.autotune( - configs=triton_get_configs(), - key=['N'], -) @triton.jit def _topk_kernel(LOGITS, PROBS, K, B, SIGMA: tl.constexpr, N: tl.constexpr, BLOCK_SIZE: tl.constexpr): @@ -387,7 +377,7 @@ def _topk_kernel(LOGITS, PROBS, K, B, SIGMA: tl.constexpr, N: tl.constexpr, max_range = k_pivot_2 num_iters += 1 - if num_iters >= 18: + if num_iters >= 18 or tl.abs(min_range - max_range) < 1e-8: k_pivot = k_pivot_0 # Third pass: Apply top-k mask @@ 
-401,10 +391,6 @@ def _topk_kernel(LOGITS, PROBS, K, B, SIGMA: tl.constexpr, N: tl.constexpr, tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) -@triton.autotune( - configs=triton_get_configs(), - key=['N'], -) @triton.jit def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, SIGMA: tl.constexpr, N: tl.constexpr, BLOCK_SIZE: tl.constexpr): @@ -418,8 +404,8 @@ def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, SIGMA: tl.constexpr, p_pivot = -float('inf') LOGITS_ROW = LOGITS + row_id * N - PROBS_ROW = PROBS + row_id * N - PROBS_2_ROW = PROBS_2 + row_id * N + PROBS_ROW = PROBS + pid * N + PROBS_2_ROW = PROBS_2 + pid * N search_addr = PROBS_ROW search_range = N @@ -579,8 +565,11 @@ def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, SIGMA: tl.constexpr, other=-float('inf')) if force_remove_logit != -float('inf'): - force_remove_mask = tl.abs(logits_blk - - force_remove_logit) < 1e-5 + # Force remove duplicates + tolerance = 1e-5 * tl.maximum( + 1.0, tl.abs(force_remove_logit)) + force_remove_mask = tl.abs( + logits_blk - force_remove_logit) < tolerance force_remove_count = tl.cumsum( force_remove_mask) + current_num_force_remove force_remove_count_mask = \ @@ -596,10 +585,6 @@ def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, SIGMA: tl.constexpr, tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) -@triton.autotune( - configs=triton_get_configs(), - key=['N'], -) @triton.jit def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, N: tl.constexpr, BLOCK_SIZE: tl.constexpr): @@ -611,7 +596,7 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, p_pivot = -float('inf') LOGITS_ROW = LOGITS + row_id * N - PROBS_ROW = PROBS + row_id * N + PROBS_ROW = PROBS + pid * N search_addr = LOGITS_ROW search_range = N @@ -720,7 +705,7 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, max_range = k_pivot_2 num_iters += 1 - if num_iters >= 18: + if num_iters >= 18 or tl.abs(min_range - max_range) < 1e-8: k_pivot = k_pivot_0 ############### END OF TOP-K 
CODE ############### @@ -847,8 +832,10 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, if force_remove_logit != -float('inf'): # Force remove duplicates + tolerance = 1e-5 * tl.maximum(1.0, + tl.abs(force_remove_logit)) force_remove_mask = tl.abs(logits_blk - - force_remove_logit) < 1e-5 + force_remove_logit) < tolerance force_remove_count = tl.cumsum( force_remove_mask) + current_num_force_remove force_remove_count_mask = \ From da1b1e6f8f959fb460d899cb0761900c7023bb2d Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 28 Sep 2025 00:52:16 -0700 Subject: [PATCH 34/99] Fixed warps and stages Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 39 ++++++++++++++++++++----- 1 file changed, 32 insertions(+), 7 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 1f372f61cc0d..01ad9d23e1c7 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -244,24 +244,49 @@ def apply_top_k_top_p_triton( device_prop = torch.cuda.get_device_properties(logits.device) NUM_PROGRAMS = device_prop.multi_processor_count - BLOCK_SIZE = 4096 + BLOCK_SIZE = 16384 SIGMA = 2.15 # Top 0.03 outliers + NUM_WARPS = 16 + NUM_STAGES = 3 if k is not None and p is None: probs = torch.full((NUM_PROGRAMS, vocab_size), -float('inf'), device=logits.device) - _topk_kernel[(NUM_PROGRAMS, )](logits, probs, k, batch_size, SIGMA, - vocab_size, BLOCK_SIZE) + _topk_kernel[(NUM_PROGRAMS, )](logits, + probs, + k, + batch_size, + SIGMA, + vocab_size, + BLOCK_SIZE, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES) elif k is None and p is not None: probs = torch.full_like(logits, -float('inf'), device=logits.device) probs_2 = torch.full_like(logits, -float('inf'), device=logits.device) - _topp_kernel[(NUM_PROGRAMS, )](logits, probs, probs_2, p, batch_size, - SIGMA, vocab_size, BLOCK_SIZE) + _topp_kernel[(NUM_PROGRAMS, )](logits, + probs, + probs_2, + p, + batch_size, + SIGMA, 
+ vocab_size, + BLOCK_SIZE, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES) elif k is not None and p is not None: probs = torch.full_like(logits, -float('inf'), device=logits.device) - _topk_topp_kernel[(NUM_PROGRAMS, )](logits, probs, k, p, batch_size, - SIGMA, vocab_size, BLOCK_SIZE) + _topk_topp_kernel[(NUM_PROGRAMS, )](logits, + probs, + k, + p, + batch_size, + SIGMA, + vocab_size, + BLOCK_SIZE, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES) return logits From 289c2ba8a462f53b4d21313408bc3aab441a7f35 Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 28 Sep 2025 00:58:27 -0700 Subject: [PATCH 35/99] Fixed scratchpads Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 01ad9d23e1c7..763a231b4579 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -248,11 +248,12 @@ def apply_top_k_top_p_triton( SIGMA = 2.15 # Top 0.03 outliers NUM_WARPS = 16 NUM_STAGES = 3 + probs = torch.full((NUM_PROGRAMS, vocab_size), + -float('inf'), + device=logits.device) if k is not None and p is None: - probs = torch.full((NUM_PROGRAMS, vocab_size), - -float('inf'), - device=logits.device) + _topk_kernel[(NUM_PROGRAMS, )](logits, probs, k, @@ -263,8 +264,7 @@ def apply_top_k_top_p_triton( num_warps=NUM_WARPS, num_stages=NUM_STAGES) elif k is None and p is not None: - probs = torch.full_like(logits, -float('inf'), device=logits.device) - probs_2 = torch.full_like(logits, -float('inf'), device=logits.device) + probs_2 = torch.full_like(probs, -float('inf'), device=logits.device) _topp_kernel[(NUM_PROGRAMS, )](logits, probs, probs_2, @@ -276,7 +276,6 @@ def apply_top_k_top_p_triton( num_warps=NUM_WARPS, num_stages=NUM_STAGES) elif k is not None and p is not None: - probs = torch.full_like(logits, -float('inf'), device=logits.device) 
_topk_topp_kernel[(NUM_PROGRAMS, )](logits, probs, k, From 5b0b1e6e706a34058deb47d5e1aff61c3511b4e2 Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 28 Sep 2025 00:59:34 -0700 Subject: [PATCH 36/99] Fixed scratchpads Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 763a231b4579..122fcb8f7664 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -245,7 +245,7 @@ def apply_top_k_top_p_triton( device_prop = torch.cuda.get_device_properties(logits.device) NUM_PROGRAMS = device_prop.multi_processor_count BLOCK_SIZE = 16384 - SIGMA = 2.15 # Top 0.03 outliers + SIGMA = 2.15 # Top 0.03 outliers - Maybe dynamically adjust based on K? NUM_WARPS = 16 NUM_STAGES = 3 probs = torch.full((NUM_PROGRAMS, vocab_size), @@ -253,7 +253,6 @@ def apply_top_k_top_p_triton( device=logits.device) if k is not None and p is None: - _topk_kernel[(NUM_PROGRAMS, )](logits, probs, k, From 5e6156cf275899de34d16e2e64344b4355c68c62 Mon Sep 17 00:00:00 2001 From: Sunga Kim Date: Wed, 22 Oct 2025 18:34:10 -0700 Subject: [PATCH 37/99] Init Sunga's correct triton top_k top_p implementation Signed-off-by: Sunga Kim --- compare.py | 410 ++++++++++++++++++++++++ vllm/v1/sample/ops/topk_topp_sampler.py | 363 +++++++++++++++++++++ 2 files changed, 773 insertions(+) create mode 100644 compare.py diff --git a/compare.py b/compare.py new file mode 100644 index 000000000000..2fdc8238935b --- /dev/null +++ b/compare.py @@ -0,0 +1,410 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from datetime import datetime +from itertools import product + +import regex as re +import torch + +from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p, + apply_top_k_top_p_triton, + apply_top_k_top_p_test, + + ) + + +def g_str(s): + 
return "\033[32m" + s + "\033[0m" + + +def r_str(s): + return "\033[31m" + s + "\033[0m" + + +def y_str(s): + return "\033[33m" + s + "\033[0m" + + +def b_str(s): + return "\033[34m" + s + "\033[0m" + + +def print_to_log(s, log_file): + print(s) + # Remove the color codes + s = re.sub(r"\033\[[0-9;]*m", "", s) + with open(log_file, "a") as f: + f.write(s + "\n") + + +def test_accuracy(logits, k, p, func_list): + input_logit_list = [logits.clone().detach() for i in range(len(func_list))] + original_logits = func_list[0](input_logit_list[0], k, p) + output_correct_list = [] + for i in range(1, len(func_list)): + output_logits = func_list[i](input_logit_list[i], k, p) + + torch.cuda.synchronize() + is_correct = torch.allclose(original_logits, output_logits) + output_correct_list.append(is_correct) + func_name = func_list[i].__name__ + + if not is_correct: + print_to_log(r_str(f"Error: logits are not close on {i} - " + f"{func_name}"), log_file) + error_mask = torch.abs(output_logits - original_logits) > 1e-5 + error_rows = torch.where(error_mask)[0] + error_rows = torch.unique(error_rows) + num_error_rows = error_rows.shape[0] + error_cols = torch.where(error_mask)[1] + error_cols = torch.unique(error_cols) + num_error_cols = error_cols.shape[0] + print_to_log(f"num_error_rows: {num_error_rows} - {error_rows}", + log_file) + print_to_log(f"num_error_cols: {num_error_cols}", log_file) + row_to_show = 5 if num_error_rows > 5 else num_error_rows + logits_to_show = torch.sort(output_logits[error_rows], + descending=True).values + logits_to_show = logits_to_show[:row_to_show, :20] + print_to_log(f"logits: {logits_to_show}", log_file) + original_logits_to_show = \ + torch.sort(original_logits[error_rows], descending=True).values + original_logits_to_show = original_logits_to_show[:row_to_show, :20] + print_to_log(f"original_logits: {original_logits_to_show}", log_file) + + return output_correct_list + + +def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): + # 
We must clone the logits for each run to avoid modifying the original + warmup_tensor = logits.clone().detach() + for _ in range(num_warmup): + test_func(warmup_tensor, k, p) + torch.cuda.synchronize() + + input_logits = [logits.clone().detach() for _ in range(num_runs)] + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start.record() + for _ in range(num_runs): + input_logits[_] = test_func(input_logits[_], k, p) + end.record() + torch.cuda.synchronize() + time_taken = start.elapsed_time(end) / num_runs + + return time_taken + + +if __name__ == "__main__": + date_str = datetime.now().strftime("%Y%m%d_%H%M%S") + + batch_size_list = [64, 128, 1024] + vocab_size_list = [4096, 16384] + p_list = [None, "RAND", 0.4, 0.7, 0.9, 0.95, 0.99] + k_list = [None, "RAND", 5, 10, 50, 100, 200, 300, 3000] + func_list = [apply_top_k_top_p, apply_top_k_top_p_triton, apply_top_k_top_p_triton] + + log_file = f"triton_topk_topp_test_{date_str}.log" + csv_file = f"triton_topk_topp_test_{date_str}.csv" + + print_to_log(y_str("Testing TopKTopPSampler with Triton"), log_file) + print_to_log(y_str("batch_size_list:") + f"{batch_size_list}", log_file) + print_to_log(y_str("vocab_size_list:") + f"{vocab_size_list}", log_file) + print_to_log(y_str("p_list:") + f"{p_list}", log_file) + print_to_log(y_str("k_list:") + f"{k_list}", log_file) + + print_to_log(y_str("log_file:") + f"{log_file}", log_file) + print_to_log(y_str("csv_file:") + f"{csv_file}", log_file) + + with open(csv_file, "w") as f: + f.write("dist_generator,batch_size,vocab_size,p,k,triton_correct,test_correct" + "torch_time_taken,triton_time_taken,test_time_taken,triton_speedup,test_speedup\n") + + for batch_size, vocab_size, p, k in product(batch_size_list, + vocab_size_list, p_list, + k_list): + if p is None and k is None: + continue + + logits_randn = torch.randn(batch_size, vocab_size, device="cuda") * 10 + logits_list = [("RANDN", logits_randn)] + + if 
p == "RAND": + p_tensor = torch.rand((batch_size, ), device="cuda") * 0.95 + 0.05 + elif p is not None: + p_tensor = torch.full((batch_size, ), p, device="cuda") + else: + p_tensor = None + + if k == "RAND": + k_tensor = torch.randint(1, + vocab_size, (batch_size, ), + device="cuda") + elif k is not None: + k_tensor = torch.full((batch_size, ), k, device="cuda") + else: + k_tensor = None + + for dist_generator, logits in logits_list: + print_to_log(y_str("--------------------------------"), log_file) + print_to_log( + g_str("Testing ") + f"{dist_generator}" + + y_str(" with batch_size: ") + f"{batch_size}" + + y_str(" vocab_size: ") + f"{vocab_size}" + y_str(" p: ") + + f"{p}" + y_str(" k: ") + f"{k}", log_file) + correct_list = test_accuracy(logits, k_tensor, p_tensor, func_list) + for i in range(len(func_list) - 1): + is_correct = correct_list[i] + if not is_correct: + print_to_log( + f"Error: logits are not close for function {func_list[i + 1].__name__}," + f" batch_size: {batch_size}," + f" vocab_size: {vocab_size}, dist_generator: " + f"{dist_generator}, p: {p}, k: {k}", log_file) + time_list = [] + for func in func_list: + time_taken = test_time(logits, k_tensor, p_tensor, test_func=func) + time_list.append(time_taken) + print_to_log( + b_str("torch_time_taken: ") + f"{time_list[0]}", log_file) + print_to_log( + b_str("triton_time_taken: ") + f"{time_list[1]}", + log_file) + print_to_log( + b_str("test_time_taken: ") + f"{time_list[2]}", log_file) + print_to_log( + g_str("Triton Speedup over Torch: ") + + f"{time_list[0] / time_list[1]:.8f}x", log_file) + print_to_log( + y_str("Test Speedup over Torch: ") + + f"{time_list[0] / time_list[2]:.8f}x", log_file) + with open(csv_file, "a") as f: + f.write(f"{dist_generator},{batch_size},{vocab_size},{p},{k}," + f"{correct_list[0]},{correct_list[1]},{time_list[0]},{time_list[1]},{time_list[2]}," + f"{time_list[0] / time_list[1]:.8f}, {time_list[0] / time_list[2]:.8f}\n") + 
print_to_log(y_str("--------------------------------\n"), log_file) + +"""# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from datetime import datetime +from itertools import product +import regex as re +import torch + +print("Torch version:", torch.__version__) +print("CUDA available:", torch.cuda.is_available()) +print("Default device:", torch.cuda.current_device()) + +# --- MODIFIED IMPORTS --- +# We need all the component kernels to time them individually +from vllm.v1.sample.ops.topk_topp_sampler import ( + apply_top_k_top_p, + apply_top_k_with_pivot_filter, # Used for accuracy check + apply_top_k_only, # This is the baseline AND our Kernel 2 + top_k_pivot_and_sort, # Kernel 1 + scatter_topk_kernel # Kernel 3 +) +print("All kernels imported successfully") + +x = torch.randn(2, 5, device="cuda") +y = apply_top_k_only(x, k=torch.tensor([2,2], device="cuda")) +print("apply_top_k_only ran successfully, output:", y) + + +def g_str(s): return "\033[32m" + s + "\033[0m" +def r_str(s): return "\033[31m" + s + "\033[0m" +def y_str(s): return "\033[33m" + s + "\033[0m" +def b_str(s): return "\033[34m" + s + "\033[0m" + +def print_to_log(s, log_file): + print(s) + s = re.sub(r"\033[[0-9;]*m", "", s) + with open(log_file, "a") as f: + f.write(s + "\n") + +# --- UNCHANGED --- +# test_accuracy still runs the *full* pipeline to check for correctness +def test_accuracy(logits, k, log_file): + input_logits_torch = logits.clone().detach() + input_logits_triton = logits.clone().detach() + + original_logits = apply_top_k_only(input_logits_torch, k) + triton_pivot_logits = apply_top_k_with_pivot_filter(input_logits_triton, k) + + torch.cuda.synchronize() + is_correct = torch.allclose(original_logits, triton_pivot_logits) + + if not is_correct: + print_to_log(r_str("Error: logits are not close"), log_file) + + return is_correct + +# --- REWRITTEN test_time FUNCTION --- +def test_time(logits, k, num_runs=30, 
num_warmup=5): + + batch_size, vocab_size = logits.shape + + # --- Warmup --- + for _ in range(num_warmup): + warmup_tensor_torch = logits.clone().detach() + apply_top_k_only(warmup_tensor_torch, k) + + warmup_tensor_triton = logits.clone().detach() + apply_top_k_with_pivot_filter(warmup_tensor_triton, k) + torch.cuda.synchronize() + + # --- 1. Baseline `apply_top_k_only` timing --- + start_torch = torch.cuda.Event(enable_timing=True) + end_torch = torch.cuda.Event(enable_timing=True) + + start_torch.record() + for i in range(num_runs): + input_tensor = logits.clone().detach() + apply_top_k_only(input_tensor, k) + end_torch.record() + torch.cuda.synchronize() + apply_top_k_time = start_torch.elapsed_time(end_torch) / num_runs + + # --- 2. Triton Kernel 2 (Sort) Timing --- + + # Events for Kernel 2 + start_k2 = torch.cuda.Event(enable_timing=True) + end_k2 = torch.cuda.Event(enable_timing=True) + + # Kernel 2 time accumulator + triton_k2_time_acc = 0.0 + + for i in range(num_runs): + if (k == vocab_size).all(): + continue + + # 1. Setup + input_tensor = logits.clone().detach() + probs = torch.full_like(input_tensor, -float('inf')) + l = torch.empty((batch_size,), device=input_tensor.device, dtype=torch.int32) + idx_tensor = torch.full_like(input_tensor, -1, dtype=torch.int) + + BLOCK_SIZE = 1024 + SIGMA = 2.0 + grid_pivot = (batch_size,) + + # 2. Run Kernel 1 (Pivot) - *No timer* + top_k_pivot_and_sort[grid_pivot]( + input_tensor, probs, l, idx_tensor, k, batch_size, + SIGMA=SIGMA, VOCAB_SIZE=vocab_size, BLOCK_SIZE=BLOCK_SIZE, + ) + + torch.cuda.synchronize() + max_l = torch.max(l).item() + outliers = probs[:, :max_l] + outliers_idx = idx_tensor[:, :max_l] + k_pinned = torch.minimum(k, l) + + # 4. 
Time Kernel 2 (Sort) + start_k2.record() + apply_top_k_only(outliers, k_pinned) + end_k2.record() + + torch.cuda.synchronize() + triton_k2_time_acc += start_k2.elapsed_time(end_k2) + + triton_sort_only_time = triton_k2_time_acc / num_runs + + return apply_top_k_time, triton_sort_only_time + + +def main(): + print("Starting compare.py...") + date_str = datetime.now().strftime("%Y%m%d_%H%M%S") + + #batch_size_list = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] # Up to 512 + #vocab_size_list = [4096, 16384, 65536, 262144, 102400] + #k_list = [None, "RAND", 5, 10, 50, 100, 200, 300, 3000] + + batch_size_list = [1, 2, 4, 8] + vocab_size_list = [4096, 16384] + k_list = [None, "RAND", 5, 10, 50, 100, 200, 300, 3000] + + + log_file = f"triton_topk_topp_test_{date_str}.log" + csv_file = f"triton_topk_topp_test_{date_str}.csv" + + print_to_log(y_str("Testing TopKTopPSampler with Triton"), log_file) + print_to_log(y_str("batch_size_list:") + f"{batch_size_list}", log_file) + print_to_log(y_str("vocab_size_list:") + f"{vocab_size_list}", log_file) + print_to_log(y_str("k_list:") + f"{k_list}", log_file) + print_to_log(y_str("log_file:") + f"{log_file}", log_file) + print_to_log(y_str("csv_file:") + f"{csv_file}", log_file) + + # --- MODIFIED CSV HEADER --- + with open(csv_file, "w") as f: + f.write("dist_generator,batch_size,vocab_size,k,is_correct," + "apply_top_k_time,triton_sort_only_time,speedup_vs_baseline\n") + + for batch_size, vocab_size, k in product(batch_size_list, + vocab_size_list, + k_list): + + logits_randn = torch.randn(batch_size, vocab_size, device="cuda") * 10 + logits_list = [("RANDN", logits_randn)] + + if k == "RAND": + k_tensor = torch.randint(1, + vocab_size, (batch_size,), + device="cuda") + elif k is not None: + k_val = min(k, vocab_size) # Ensure k is not > vocab_size + k_tensor = torch.full((batch_size,), k_val, device="cuda") + else: + k_tensor = torch.full((batch_size,), vocab_size, device="cuda") + + for dist_generator, logits in logits_list: + 
print_to_log(y_str("--------------------------------"), log_file) + print_to_log( + g_str("Testing ") + f"{dist_generator}" + + y_str(" with batch_size: ") + f"{batch_size}" + + y_str(" vocab_size: ") + f"{vocab_size}" + + y_str(" k: ") + f"{k}", log_file) + + is_correct = test_accuracy(logits, k_tensor, log_file) + if not is_correct: + print_to_log( + r_str(f"Error: logits are not close for batch_size: {batch_size}, " + f"vocab_size: {vocab_size}, dist_generator: {dist_generator}, k: {k}"), + log_file) + + # --- MODIFIED TIMING CALL --- + apply_top_k_time, triton_sort_only_time = test_time(logits, k_tensor) + + print_to_log( + b_str("apply_top_k_time (Baseline): ") + f"{apply_top_k_time}", log_file) + print_to_log( + b_str("triton_sort_only_time (Kernel 2): ") + f"{triton_sort_only_time}", + log_file) + + # --- THIS IS THE FIX --- + # Handle the k: None case where triton_sort_only_time is 0.0 + if triton_sort_only_time > 0: + speedup = apply_top_k_time / triton_sort_only_time + speedup_str = f"{speedup:.8f}x" + else: + # 'k: None' case, speedup is not applicable (N/A) + speedup = 0.0 + speedup_str = "N/A (passthrough)" + # --- END FIX --- + + print_to_log( + g_str("Triton Sort Speedup vs. 
Full Baseline: ") + + speedup_str, log_file) + + # Write to CSV + with open(csv_file, "a") as f: + f.write(f"{dist_generator},{batch_size},{vocab_size},{k}," + f"{is_correct},{apply_top_k_time},{triton_sort_only_time}," + f"{speedup:.8f}\n") # Still write the float for CSV + print_to_log(y_str("--------------------------------\n"), log_file) + +if __name__ == "__main__": + main()""" \ No newline at end of file diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 868a33e6d10d..da84f115762d 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -7,6 +7,7 @@ import triton import triton.language as tl from packaging import version +from typing import Optional from vllm import envs from vllm.config.model import LogprobsMode @@ -971,3 +972,365 @@ def flashinfer_sample( ) return next_token_ids.view(-1) + +@triton.jit +def _topp_kernel_sorted( + LOGITS, PROBS, PROBS_2, P, B, SIGMA: tl.constexpr, + N: tl.constexpr, BLOCK_SIZE: tl.constexpr +): + """Modified top-p kernel with sort-equivalent tie-breaking + and re-enabled outlier optimization. 
+ """ + NUM_TILES: tl.constexpr = (N + BLOCK_SIZE - 1) // BLOCK_SIZE + pid = tl.program_id(0) + num_programs = tl.num_programs(0) + + for row_id in tl.range(pid, B, num_programs): + p = tl.load(P + row_id) + if p != 1.0: # All tokens are valid + + p_pivot = -float('inf') + + LOGITS_ROW = LOGITS + row_id * N + PROBS_ROW = PROBS + pid * N + PROBS_2_ROW = PROBS_2 + pid * N # <-- RE-ADDED + + # Default search params + search_addr = PROBS_ROW + search_range = N + search_iters = NUM_TILES + + max_logit = -float('inf') + min_logit = float('inf') + + force_remove_logit = -float('inf') + num_force_remove = tl.zeros((), dtype=tl.uint32) + + # --- ZEROTH PASS (RE-ADDED) --- + # Compute *exact* avg and std + sum_logits = 0.0 + sum_sq_logits = 0.0 + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < N + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=0.0) + sum_logits += tl.sum(tl.where(mask_n, logits_blk, 0.0)) + sum_sq_logits += tl.sum(tl.where(mask_n, logits_blk * logits_blk, 0.0)) + + avg_logit = sum_logits / N + sq_avg_logit = sum_sq_logits / N + std_logit = tl.sqrt(tl.maximum(0.0, sq_avg_logit - avg_logit * avg_logit)) + outlier_pivot = avg_logit + SIGMA * std_logit # <-- RE-ADDED + num_outliers = tl.zeros((), dtype=tl.uint32) # <-- RE-ADDED + sum_outlier_probs = 0.0 # <-- RE-ADDED + + sum_exp_logits = 0.0 + + # First pass: compute max and min logits + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + logits_blk = tl.load(LOGITS_ROW + offs_n, + mask=mask_n, + other=-float('inf')) # Use -inf + max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + + # Second pass: Calculate exp logits and sum + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + + probs_blk = tl.load(LOGITS_ROW + offs_n, + mask=mask_n, + 
other=-float('inf')) + probs_blk = probs_blk - max_logit + probs_blk = tl.exp(probs_blk) + sum_exp_logits += tl.sum(probs_blk) + tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) + + # --- OUTLIER_PROB (RE-ADDED) --- + outlier_prob = tl.exp(outlier_pivot - max_logit) / sum_exp_logits + + # Third pass: Calculate final probs AND get outliers + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + + probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) + probs_blk = probs_blk / sum_exp_logits + tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) + + # --- OUTLIER MASKING LOGIC (RE-ADDED) --- + outlier_mask = (probs_blk > outlier_prob) & mask_n + sum_outlier_probs += tl.sum(outlier_mask * probs_blk) + num_blk_outliers = tl.sum(outlier_mask) + cumulative_pos = tl.cast( + tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + num_outliers += num_blk_outliers + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(PROBS_2_ROW + write_pos, probs_blk, mask=outlier_mask) + + + max_range = tl.exp(max_logit - max_logit) / sum_exp_logits + min_range = tl.exp(min_logit - max_logit) / sum_exp_logits + + if sum_outlier_probs > p: + min_range = outlier_prob + search_addr = PROBS_2_ROW + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) + + second_max_logit = -float('inf') + num_iters = 0 + p_pivots_sum_0 = 0.0 # --> total prob including all equivalent min + min_larger_0 = 1.0 # --> prob of tie-breaking min + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + # Binary search for p_pivot + while p_pivot == -float('inf') and num_iters < 32: + p_pivot_0 = (max_range - min_range) * 1.0 / 2.0 + min_range + p_pivots_sum_0 = 0.0 + + min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < 
search_range + probs_blk = tl.load(search_addr + offs_n, + mask=mask_n, + other=0.0) + + masked_larger_0 = tl.where(probs_blk > p_pivot_0, + probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, + tl.min(masked_larger_0)) + + p_pivots_sum_0 += tl.sum(probs_blk * + (probs_blk > p_pivot_0)) + + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + probs_blk = tl.load(search_addr + offs_n, + mask=mask_n, + other=0.0) + + num_min_larger_0 += tl.sum( + tl.abs(probs_blk - min_larger_0) < 1e-7) + + if p_pivots_sum_0 >= p: + if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: + p_pivot = p_pivot_0 + else: + min_range = p_pivot_0 + else: + max_range = p_pivot_0 + + num_iters += 1 + if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-8: + p_pivot = p_pivot_0 + + if p_pivot >= max_logit: + p_pivot = second_max_logit + elif num_min_larger_0 > 1: + num_force_remove = tl.cast((p_pivots_sum_0 - p) / min_larger_0, + tl.uint32) # --> number of probs to be removed + force_remove_logit = tl.log( + min_larger_0 * sum_exp_logits) + max_logit + + p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit + + # Apply mask with (non-sort-equivalent) tie-breaking + current_num_removed = tl.zeros((), dtype=tl.uint32) + if p_pivot != -float('inf'): + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < N + logits_blk = tl.load(LOGITS_ROW + offs_n, + mask=mask_n, + other=-float('inf')) + + if force_remove_logit != -float('inf'): + # Match PyTorch's non-sort-equivalent tie-breaking + tolerance = 1e-5 * tl.maximum(1.0, tl.abs(force_remove_logit)) + is_tie = tl.abs(logits_blk - force_remove_logit) < tolerance + tie_position = tl.cumsum(is_tie) - 1 + current_num_removed + should_remove = is_tie & (tie_position < num_force_remove) + logits_blk = tl.where(should_remove, -float('inf'), logits_blk) + current_num_removed += tl.sum(is_tie) + + # Standard threshold masking + 
tolerance = 1e-6 * tl.maximum(1.0, tl.abs(p_pivot)) + logits_blk = tl.where(logits_blk >= (p_pivot - tolerance), logits_blk, + -float('inf')) + + tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) + +def apply_top_p_sorted_equivalent( + logits: torch.Tensor, + p: torch.Tensor, + sigma: float = 3.0, +) -> torch.Tensor: + """Apply top-p using binary search (no sort!) with sort-equivalent results. + + Args: + logits: [B, N] logits tensor + p: [B] top-p thresholds + sigma: Standard deviation multiplier for outlier detection + Returns: + Modified logits, equivalent to sorted top-p version + """ + B, N = logits.shape + device = logits.device + + BLOCK_SIZE = triton.next_power_of_2(min(N, 1024)) + num_warps = 4 if BLOCK_SIZE < 2048 else 8 + + probs = torch.empty((B, N), device=device, dtype=torch.float32) + probs_2 = torch.empty((B, N), device=device, dtype=torch.float32) + + grid = (B,) + _topp_kernel_sorted[grid]( + logits, + probs, + probs_2, + p, + B, + SIGMA=sigma, + N=N, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + + return logits + +def apply_top_k_top_p_test( + logits: torch.Tensor, + k: torch.Tensor | None, + p: torch.Tensor | None, +) -> torch.Tensor: + """Optimized implementation combining torch.topk and binary search kernel. + """ + if p is None: + if k is None: + return logits + return apply_top_k_only(logits, k) + # Apply top-k filter first if needed + if k is not None: + logits = apply_top_k_only(logits, k) + + # Apply top-p using binary search (no sort!) 
+ return apply_top_p_sorted_equivalent(logits, p) + +"""@triton.jit +def top_p_filter_triton(LOGITS, PROBS, l, idx_tensor, K, B, SIGMA:tl.constexpr, VOCAB_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr): + + #Agressively filters logits using pivot-based approach before top-k, in order to minimize the amount of sorting required for top k + + pid = tl.program_id(0) + num_programs = tl.num_programs(0) + NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE + + for row_id in tl.range(pid, B, num_programs): + k = tl.load(K + row_id) + if k != VOCAB_SIZE: # All tokens are valid + LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE + OUT_ROW = OUT + row_id * VOCAB_SIZE + IDX_ROW = IDX + row_id * VOCAB_SIZE + + sum_logits = 0.0 + sum_sq_logits = 0.0 + + for i in range(NUM_TILES): + offs = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < VOCAB_SIZE + vals = tl.load(LOGITS_ROW + offs, mask=mask, other=0.0) + sum_logits += tl.sum(vals, where=mask) + sum_sq_logits += tl.sum(vals * vals, where=mask) + + mean = sum_logits / VOCAB_SIZE + var = sum_sq_logits / VOCAB_SIZE - mean * mean + std = tl.sqrt(tl.maximum(var, 0.0)) + threshold = mean + SIGMA * std + + count = 0 + for i in range(NUM_TILES): + offs = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < VOCAB_SIZE + vals = tl.load(LOGITS_ROW + offs, mask=mask, other=0.0) + keep_mask = vals > threshold + + # Write filtered logits + out_vals = tl.where(keep_mask, vals, -float("inf")) + tl.store(OUT_ROW + offs, out_vals, mask=mask) + + # Write kept indices contiguously + new_idx = tl.where(keep_mask, offs + i * BLOCK_SIZE, -1) + kept_idx = new_idx[keep_mask] + num_kept = tl.sum(keep_mask, where=mask) + + # store valid indices sequentially + if num_kept > 0: + write_offs = count + tl.arange(0, num_kept) + tl.store(IDX_ROW + write_offs, kept_idx) + count += num_kept + + # Record number of kept logits + tl.store(L + row_id, count) + + +def apply_top_p_filtered( + logits: torch.Tensor, + k: torch.Tensor, +) -> 
torch.Tensor: + + # Applies top-p using filtering + + batch_size, vocab_size = logits.shape + + probs = torch.empty_like(logits) + l = torch.zeros((batch_size,), device=logits.device, dtype=torch.int32) + idx_tensor = torch.empty_like(logits, dtype=torch.int) + + BLOCK_SIZE = 1024 + SIGMA = 2.0 + + grid = lambda meta: ((batch_size + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'], ) + top_p_filter_triton[grid]( + logits, + probs, + l, + idx_tensor, + k, + batch_size, + SIGMA=SIGMA, + VOCAB_SIZE=vocab_size, + BLOCK_SIZE=BLOCK_SIZE, + ) + + max_l = torch.max(l).item() + filtered_logits = probs[:, :max_l] + logits = apply_top_k_only(logits, k) + return logits + +def apply_top_k_top_p_test2( + logits: torch.Tensor, + k: torch.Tensor | None, + p: torch.Tensor | None, +) -> torch.Tensor: + + # Filter out the outliers + + if p is None: + if k is None: + return logits + return apply_top_k_only(logits, k) + # Apply top-k filter first if needed + if k is not None: + logits = apply_top_k_only(logits, k) + + # Apply top-p using binary search (no sort!) 
+ logits = apply_top_p_filtered(logits, p)""" From 7401ead12aaa36708a7e1fafc7ed9dde63fa4ab4 Mon Sep 17 00:00:00 2001 From: Sunga Kim Date: Thu, 23 Oct 2025 12:19:36 -0700 Subject: [PATCH 38/99] initial commit --- compare.py | 6 +- vllm/v1/sample/ops/test1.py | 15 + vllm/v1/sample/ops/topk_topp_sampler.py | 570 +++++++++++++++++------- 3 files changed, 423 insertions(+), 168 deletions(-) create mode 100644 vllm/v1/sample/ops/test1.py diff --git a/compare.py b/compare.py index 2fdc8238935b..04f5e5676ea2 100644 --- a/compare.py +++ b/compare.py @@ -8,8 +8,8 @@ from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p, apply_top_k_top_p_triton, - apply_top_k_top_p_test, - + apply_top_k_top_p_test2 + ) @@ -103,7 +103,7 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): vocab_size_list = [4096, 16384] p_list = [None, "RAND", 0.4, 0.7, 0.9, 0.95, 0.99] k_list = [None, "RAND", 5, 10, 50, 100, 200, 300, 3000] - func_list = [apply_top_k_top_p, apply_top_k_top_p_triton, apply_top_k_top_p_triton] + func_list = [apply_top_k_top_p, apply_top_k_top_p_triton, apply_top_k_top_p_test2] log_file = f"triton_topk_topp_test_{date_str}.log" csv_file = f"triton_topk_topp_test_{date_str}.csv" diff --git a/vllm/v1/sample/ops/test1.py b/vllm/v1/sample/ops/test1.py new file mode 100644 index 000000000000..fb78fb676159 --- /dev/null +++ b/vllm/v1/sample/ops/test1.py @@ -0,0 +1,15 @@ +import torch + +# Create a randomly initialized 5x3 tensor +x = torch.rand(5, 3) +print("Random Tensor:\n", x) + +# Check if CUDA is available and print the result +cuda_available = torch.cuda.is_available() +print("\nCUDA available:", cuda_available) + +# If CUDA is available, you can also try moving a tensor to the GPU +if cuda_available: + device = torch.device("cuda") + y = torch.ones(2, 2, device=device) + print("\nTensor on GPU:\n", y) \ No newline at end of file diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 
da84f115762d..d74d1bc30e5a 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -49,18 +49,18 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None: "Falling back to default sampling implementation." ) self.forward = self.forward_native - elif envs.VLLM_USE_FLASHINFER_SAMPLER is not False: - # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for - # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by + elif envs.VLLM_USE_FLASHIVOCAB_SIZEFER_SAMPLER is not False: + # VOCAB_SIZEOTE(woosuk): The V0 sampler doesn't use FlashInfer for + # sampling unless VLLM_USE_FLASHIVOCAB_SIZEFER_SAMPLER=1 (i.e., by # default it is unused). For backward compatibility, we set - # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and + # `VLLM_USE_FLASHIVOCAB_SIZEFER_SAMPLER` as None by default and # interpret it differently in V0 and V1 samplers: In V0, # None means False, while in V1, None means True. This is # why we use the condition - # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here. + # `envs.VLLM_USE_FLASHIVOCAB_SIZEFER_SAMPLER is not False` here. logger.info_once("Using FlashInfer for top-p & top-k sampling.") self.forward = self.forward_cuda - elif envs.VLLM_USE_TRITON_SAMPLER is not False: + elif envs.VLLM_USE_TRITOVOCAB_SIZE_SAMPLER is not False: logger.info_once( "Using Triton for top-p & top-k sampling.") self.forward = self.forward_triton @@ -69,11 +69,11 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None: "FlashInfer is available, but it is not enabled. " "Falling back to the PyTorch-native implementation of " "top-p & top-k sampling. For the best performance, " - "please set VLLM_USE_FLASHINFER_SAMPLER=1." + "please set VLLM_USE_FLASHIVOCAB_SIZEFER_SAMPLER=1." 
) self.forward = self.forward_native else: - if envs.VLLM_USE_TRITON_SAMPLER is not False: + if envs.VLLM_USE_TRITOVOCAB_SIZE_SAMPLER is not False: logger.info_once( "Using Triton for top-p & top-k sampling.") self.forward = self.forward_triton @@ -179,7 +179,7 @@ def forward_cpu( elif self.logprobs_mode == "processed_logprobs": logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32) - # Note: this is a workaround for + # VOCAB_SIZEote: this is a workaround for # https://github.com/pytorch/pytorch/pull/151218 @torch.compile(dynamic=True) def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor: @@ -252,28 +252,28 @@ def apply_top_k_top_p_triton( batch_size, vocab_size = logits.shape device_prop = torch.cuda.get_device_properties(logits.device) - NUM_PROGRAMS = device_prop.multi_processor_count + VOCAB_SIZEUM_PROGRAMS = device_prop.multi_processor_count BLOCK_SIZE = 16384 SIGMA = 2.15 # Top 0.03 outliers - Maybe dynamically adjust based on K? - NUM_WARPS = 16 - NUM_STAGES = 3 - probs = torch.full((NUM_PROGRAMS, vocab_size), + VOCAB_SIZEUM_WARPS = 16 + VOCAB_SIZEUM_STAGES = 3 + probs = torch.full((VOCAB_SIZEUM_PROGRAMS, vocab_size), -float('inf'), device=logits.device) if k is not None and p is None: - _topk_kernel[(NUM_PROGRAMS, )](logits, + _topk_kernel[(VOCAB_SIZEUM_PROGRAMS, )](logits, probs, k, batch_size, SIGMA, vocab_size, BLOCK_SIZE, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES) + num_warps=VOCAB_SIZEUM_WARPS, + num_stages=VOCAB_SIZEUM_STAGES) elif k is None and p is not None: probs_2 = torch.full_like(probs, -float('inf'), device=logits.device) - _topp_kernel[(NUM_PROGRAMS, )](logits, + _topp_kernel[(VOCAB_SIZEUM_PROGRAMS, )](logits, probs, probs_2, p, @@ -281,10 +281,10 @@ def apply_top_k_top_p_triton( SIGMA, vocab_size, BLOCK_SIZE, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES) + num_warps=VOCAB_SIZEUM_WARPS, + num_stages=VOCAB_SIZEUM_STAGES) elif k is not None and p is not None: - _topk_topp_kernel[(NUM_PROGRAMS, )](logits, + 
_topk_topp_kernel[(VOCAB_SIZEUM_PROGRAMS, )](logits, probs, k, p, @@ -292,48 +292,48 @@ def apply_top_k_top_p_triton( SIGMA, vocab_size, BLOCK_SIZE, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES) + num_warps=VOCAB_SIZEUM_WARPS, + num_stages=VOCAB_SIZEUM_STAGES) return logits @triton.jit -def _topk_kernel(LOGITS, PROBS, K, B, SIGMA: tl.constexpr, N: tl.constexpr, +def _topk_kernel(LOGITS, PROBS, K, B, SIGMA: tl.constexpr, VOCAB_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(0) num_programs = tl.num_programs(0) - NUM_TILES: tl.constexpr = (N + BLOCK_SIZE - 1) // BLOCK_SIZE + NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE for row_id in tl.range(pid, B, num_programs): k = tl.load(K + row_id) - if k != N: # All tokens are valid + if k != VOCAB_SIZE: # All tokens are valid - # THERE IS NO DUPLICATE LOGIT MANAGEMENT FOR THIS TOP-K - # CURRENT IMPLEMENTATION INCLUDES ALL DUPLICATE LOGITS, - # WHICH MAY RETURN MORE THAN K LOGITS, - # FOLLOWING THE IMPLEMENTATION in apply_top_k_only(). - # IF YOU NEED EXACTLY K LOGITS, PLEASE REFER TO THE TOP-P - # IMPLEMENTATION AND IMPLEMENT THE DUPLICATE LOGIT MANAGEMENT - # USING THE FORCE_REMOVE_LOGIT VARIABLE + # THERE IS VOCAB_SIZEO DUPLICATE LOGIT MAVOCAB_SIZEAGEMEVOCAB_SIZET FOR THIS TOP-K + # CURREVOCAB_SIZET IMPLEMEVOCAB_SIZETATIOVOCAB_SIZE IVOCAB_SIZECLUDES ALL DUPLICATE LOGITS, + # WHICH MAY RETURVOCAB_SIZE MORE THAVOCAB_SIZE K LOGITS, + # FOLLOWIVOCAB_SIZEG THE IMPLEMEVOCAB_SIZETATIOVOCAB_SIZE in apply_top_k_only(). 
+ # IF YOU VOCAB_SIZEEED EXACTLY K LOGITS, PLEASE REFER TO THE TOP-P + # IMPLEMEVOCAB_SIZETATIOVOCAB_SIZE AVOCAB_SIZED IMPLEMEVOCAB_SIZET THE DUPLICATE LOGIT MAVOCAB_SIZEAGEMEVOCAB_SIZET + # USIVOCAB_SIZEG THE FORCE_REMOVE_LOGIT VARIABLE k_pivot = -float('inf') - LOGITS_ROW = LOGITS + row_id * N - PROBS_ROW = PROBS + pid * N + LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE + PROBS_ROW = PROBS + pid * VOCAB_SIZE search_addr = LOGITS_ROW - search_range = N + search_range = VOCAB_SIZE search_iters = NUM_TILES max_logit = -float('inf') min_logit = float('inf') # Zeroth pass: Compute avg and std from a sample block - # May produce incorrect results if N < BLOCK_SIZE + # May produce incorrect results if VOCAB_SIZE < BLOCK_SIZE offs = tl.arange(0, BLOCK_SIZE) - mask_n = offs < N + mask_n = offs < VOCAB_SIZE logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) - avg_logit = tl.sum(logits_blk) / N - sq_avg_logit = tl.sum(logits_blk * logits_blk) / N + avg_logit = tl.sum(logits_blk) / VOCAB_SIZE + sq_avg_logit = tl.sum(logits_blk * logits_blk) / VOCAB_SIZE std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) outlier_pivot = avg_logit + SIGMA * std_logit @@ -414,9 +414,9 @@ def _topk_kernel(LOGITS, PROBS, K, B, SIGMA: tl.constexpr, N: tl.constexpr, # Third pass: Apply top-k mask if k_pivot != -float('inf'): - for i in range(0, NUM_TILES): + for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < N + mask_n = offs_n < VOCAB_SIZE logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n) mask = (logits_blk > k_pivot) logits_blk = tl.where(mask, logits_blk, -float('inf')) @@ -425,8 +425,8 @@ def _topk_kernel(LOGITS, PROBS, K, B, SIGMA: tl.constexpr, N: tl.constexpr, @triton.jit def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, SIGMA: tl.constexpr, - N: tl.constexpr, BLOCK_SIZE: tl.constexpr): - NUM_TILES: tl.constexpr = (N + BLOCK_SIZE - 1) // BLOCK_SIZE + VOCAB_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr): + NUM_TILES: 
tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE pid = tl.program_id(0) num_programs = tl.num_programs(0) for row_id in tl.range(pid, B, num_programs): @@ -435,12 +435,12 @@ def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, SIGMA: tl.constexpr, p_pivot = -float('inf') - LOGITS_ROW = LOGITS + row_id * N - PROBS_ROW = PROBS + pid * N - PROBS_2_ROW = PROBS_2 + pid * N + LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE + PROBS_ROW = PROBS + pid * VOCAB_SIZE + PROBS_2_ROW = PROBS_2 + pid * VOCAB_SIZE search_addr = PROBS_ROW - search_range = N + search_range = VOCAB_SIZE search_iters = NUM_TILES max_logit = -float('inf') @@ -452,13 +452,13 @@ def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, SIGMA: tl.constexpr, num_force_remove = tl.zeros((), dtype=tl.uint32) # Zeroth pass: Compute avg and std from a sample block - # May produce incorrect results if N < BLOCK_SIZE + # May produce incorrect results if VOCAB_SIZE < BLOCK_SIZE # OR all logits are the same offs = tl.arange(0, BLOCK_SIZE) - mask_n = offs < N + mask_n = offs < VOCAB_SIZE logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) - avg_logit = tl.sum(logits_blk) / N - sq_avg_logit = tl.sum(logits_blk * logits_blk) / N + avg_logit = tl.sum(logits_blk) / VOCAB_SIZE + sq_avg_logit = tl.sum(logits_blk * logits_blk) / VOCAB_SIZE std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) outlier_pivot = avg_logit + SIGMA * std_logit @@ -591,7 +591,7 @@ def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, SIGMA: tl.constexpr, if p_pivot != -float('inf'): for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < N + mask_n = offs_n < VOCAB_SIZE logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) @@ -619,19 +619,19 @@ def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, SIGMA: tl.constexpr, @triton.jit def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, - N: tl.constexpr, BLOCK_SIZE: tl.constexpr): - NUM_TILES: tl.constexpr = (N + BLOCK_SIZE - 1) // 
BLOCK_SIZE + VOCAB_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr): + NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE pid = tl.program_id(0) num_programs = tl.num_programs(0) for row_id in tl.range(pid, B, num_programs): k_pivot = -float('inf') p_pivot = -float('inf') - LOGITS_ROW = LOGITS + row_id * N - PROBS_ROW = PROBS + pid * N + LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE + PROBS_ROW = PROBS + pid * VOCAB_SIZE search_addr = LOGITS_ROW - search_range = N + search_range = VOCAB_SIZE search_iters = NUM_TILES max_logit = -float('inf') @@ -644,12 +644,12 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, num_force_remove = tl.zeros((), dtype=tl.uint32) # Zeroth pass: Compute avg and std from a sample block - # May produce incorrect results if N < BLOCK_SIZE + # May produce incorrect results if VOCAB_SIZE < BLOCK_SIZE offs = tl.arange(0, BLOCK_SIZE) - mask_n = offs < N + mask_n = offs < VOCAB_SIZE logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) - avg_logit = tl.sum(logits_blk) / N - sq_avg_logit = tl.sum(logits_blk * logits_blk) / N + avg_logit = tl.sum(logits_blk) / VOCAB_SIZE + sq_avg_logit = tl.sum(logits_blk * logits_blk) / VOCAB_SIZE std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) outlier_pivot = avg_logit + SIGMA * std_logit @@ -684,15 +684,15 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, search_iters = tl.cast( (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) - if k != N: # All tokens are valid + if k != VOCAB_SIZE: # All tokens are valid - # THERE IS NO DUPLICATE LOGIT MANAGEMENT FOR THIS TOP-K - # CURRENT IMPLEMENTATION INCLUDES ALL DUPLICATE LOGITS, - # WHICH MAY RETURN MORE THAN K LOGITS, - # FOLLOWING THE IMPLEMENTATION in apply_top_k_only(). - # IF YOU NEED EXACTLY K LOGITS, PLEASE REFER TO THE TOP-P - # IMPLEMENTATION AND IMPLEMENT THE DUPLICATE LOGIT MANAGEMENT - # USING THE FORCE_REMOVE_LOGIT VARIABLE. 
+ # THERE IS VOCAB_SIZEO DUPLICATE LOGIT MAVOCAB_SIZEAGEMEVOCAB_SIZET FOR THIS TOP-K + # CURREVOCAB_SIZET IMPLEMEVOCAB_SIZETATIOVOCAB_SIZE IVOCAB_SIZECLUDES ALL DUPLICATE LOGITS, + # WHICH MAY RETURVOCAB_SIZE MORE THAVOCAB_SIZE K LOGITS, + # FOLLOWIVOCAB_SIZEG THE IMPLEMEVOCAB_SIZETATIOVOCAB_SIZE in apply_top_k_only(). + # IF YOU VOCAB_SIZEEED EXACTLY K LOGITS, PLEASE REFER TO THE TOP-P + # IMPLEMEVOCAB_SIZETATIOVOCAB_SIZE AVOCAB_SIZED IMPLEMEVOCAB_SIZET THE DUPLICATE LOGIT MAVOCAB_SIZEAGEMEVOCAB_SIZET + # USIVOCAB_SIZEG THE FORCE_REMOVE_LOGIT VARIABLE. # Second passes: Quaternary search for pivots (nlog_4(n)) num_iters = 0 @@ -740,7 +740,7 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, if num_iters >= 18 or tl.abs(min_range - max_range) < 1e-8: k_pivot = k_pivot_0 - ############### END OF TOP-K CODE ############### + ############### EVOCAB_SIZED OF TOP-K CODE ############### ############### START OF TOP-P CODE ############### @@ -849,7 +849,7 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit - ############### END OF TOP-P CODE ############### + ############### EVOCAB_SIZED OF TOP-P CODE ############### # Sixth pass: Apply mask pivot = tl.maximum(k_pivot, p_pivot) @@ -857,7 +857,7 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, if pivot != -float('inf'): for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < N + mask_n = offs_n < VOCAB_SIZE logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) @@ -894,6 +894,14 @@ def apply_top_k_only( The logits tensor may be updated in-place. """ + max_top_k = k.max().item() + + # --- FIX: Handle k=0 edge case --- + # If the max k is 0, all rows are 0. Mask everything and exit. + if max_top_k == 0: + logits.fill_(-float("inf")) + return logits + no_top_k_mask = k == logits.shape[1] # Set non-top-k rows to 1 so that we can gather. 
k = k.masked_fill(no_top_k_mask, 1) @@ -918,7 +926,7 @@ def random_sample( causes CPU-GPU synchronization. """ q = torch.empty_like(probs) - # NOTE(woosuk): To batch-process the requests without their own seeds, + # VOCAB_SIZEOTE(woosuk): To batch-process the requests without their own seeds, # which is the common case, we first assume that every request does # not have its own seed. Then, we overwrite the values for the requests # that have their own seeds. @@ -944,11 +952,11 @@ def flashinfer_sample( However, this function is faster because it avoids sorting the logits tensor via rejection sampling. - NOTE: The outputs of this function do not necessarily match the outputs of + VOCAB_SIZEOTE: The outputs of this function do not necessarily match the outputs of the `random_sample` function. It only guarantees that the outputs are statistically equivalent. - NOTE: This function includes CPU-GPU synchronization, while `random_sample` + VOCAB_SIZEOTE: This function includes CPU-GPU synchronization, while `random_sample` does not. Call this function at the end of the forward pass to minimize the synchronization overhead. """ @@ -976,12 +984,12 @@ def flashinfer_sample( @triton.jit def _topp_kernel_sorted( LOGITS, PROBS, PROBS_2, P, B, SIGMA: tl.constexpr, - N: tl.constexpr, BLOCK_SIZE: tl.constexpr + VOCAB_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr ): """Modified top-p kernel with sort-equivalent tie-breaking and re-enabled outlier optimization. 
""" - NUM_TILES: tl.constexpr = (N + BLOCK_SIZE - 1) // BLOCK_SIZE + NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE pid = tl.program_id(0) num_programs = tl.num_programs(0) @@ -991,13 +999,13 @@ def _topp_kernel_sorted( p_pivot = -float('inf') - LOGITS_ROW = LOGITS + row_id * N - PROBS_ROW = PROBS + pid * N - PROBS_2_ROW = PROBS_2 + pid * N # <-- RE-ADDED + LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE + PROBS_ROW = PROBS + pid * VOCAB_SIZE + PROBS_2_ROW = PROBS_2 + pid * VOCAB_SIZE # <-- RE-ADDED # Default search params search_addr = PROBS_ROW - search_range = N + search_range = VOCAB_SIZE search_iters = NUM_TILES max_logit = -float('inf') @@ -1012,13 +1020,13 @@ def _topp_kernel_sorted( sum_sq_logits = 0.0 for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < N + mask_n = offs_n < VOCAB_SIZE logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=0.0) sum_logits += tl.sum(tl.where(mask_n, logits_blk, 0.0)) sum_sq_logits += tl.sum(tl.where(mask_n, logits_blk * logits_blk, 0.0)) - avg_logit = sum_logits / N - sq_avg_logit = sum_sq_logits / N + avg_logit = sum_logits / VOCAB_SIZE + sq_avg_logit = sum_sq_logits / VOCAB_SIZE std_logit = tl.sqrt(tl.maximum(0.0, sq_avg_logit - avg_logit * avg_logit)) outlier_pivot = avg_logit + SIGMA * std_logit # <-- RE-ADDED num_outliers = tl.zeros((), dtype=tl.uint32) # <-- RE-ADDED @@ -1052,7 +1060,7 @@ def _topp_kernel_sorted( # --- OUTLIER_PROB (RE-ADDED) --- outlier_prob = tl.exp(outlier_pivot - max_logit) / sum_exp_logits - # Third pass: Calculate final probs AND get outliers + # Third pass: Calculate final probs AVOCAB_SIZED get outliers for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range @@ -1061,7 +1069,7 @@ def _topp_kernel_sorted( probs_blk = probs_blk / sum_exp_logits tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) - # --- OUTLIER MASKING LOGIC (RE-ADDED) --- + # --- OUTLIER 
MASKIVOCAB_SIZEG LOGIC (RE-ADDED) --- outlier_mask = (probs_blk > outlier_prob) & mask_n sum_outlier_probs += tl.sum(outlier_mask * probs_blk) num_blk_outliers = tl.sum(outlier_mask) @@ -1148,7 +1156,7 @@ def _topp_kernel_sorted( if p_pivot != -float('inf'): for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < N + mask_n = offs_n < VOCAB_SIZE logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) @@ -1177,20 +1185,20 @@ def apply_top_p_sorted_equivalent( """Apply top-p using binary search (no sort!) with sort-equivalent results. Args: - logits: [B, N] logits tensor + logits: [B, VOCAB_SIZE] logits tensor p: [B] top-p thresholds sigma: Standard deviation multiplier for outlier detection Returns: Modified logits, equivalent to sorted top-p version """ - B, N = logits.shape + B, VOCAB_SIZE = logits.shape device = logits.device - BLOCK_SIZE = triton.next_power_of_2(min(N, 1024)) + BLOCK_SIZE = triton.next_power_of_2(min(VOCAB_SIZE, 1024)) num_warps = 4 if BLOCK_SIZE < 2048 else 8 - probs = torch.empty((B, N), device=device, dtype=torch.float32) - probs_2 = torch.empty((B, N), device=device, dtype=torch.float32) + probs = torch.empty((B, VOCAB_SIZE), device=device, dtype=torch.float32) + probs_2 = torch.empty((B, VOCAB_SIZE), device=device, dtype=torch.float32) grid = (B,) _topp_kernel_sorted[grid]( @@ -1200,7 +1208,7 @@ def apply_top_p_sorted_equivalent( p, B, SIGMA=sigma, - N=N, + VOCAB_SIZE=VOCAB_SIZE, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, ) @@ -1225,112 +1233,344 @@ def apply_top_k_top_p_test( # Apply top-p using binary search (no sort!) 
return apply_top_p_sorted_equivalent(logits, p) -"""@triton.jit -def top_p_filter_triton(LOGITS, PROBS, l, idx_tensor, K, B, SIGMA:tl.constexpr, VOCAB_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr): - - #Agressively filters logits using pivot-based approach before top-k, in order to minimize the amount of sorting required for top k - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +@triton.jit +def top_p_pivot_filter(LOGITS, L, PROBS, PROBS_2, idx_tensor, P, B, SIGMA:tl.constexpr, VOCAB_SIZE:tl.constexpr, BLOCK_SIZE:tl.constexpr, WIDEN_NUM: tl.constexpr): + NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE pid = tl.program_id(0) num_programs = tl.num_programs(0) - NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE - for row_id in tl.range(pid, B, num_programs): - k = tl.load(K + row_id) - if k != VOCAB_SIZE: # All tokens are valid + for row_id in tl.range(pid, B, num_programs): + p = tl.load(P + row_id) # fetches the p value of the row it is working on + if p != 1.0: # if p == 1, this becomes pointless ! 
+ p_pivot = -float('inf') + LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE - OUT_ROW = OUT + row_id * VOCAB_SIZE - IDX_ROW = IDX + row_id * VOCAB_SIZE + PROBS_ROW = PROBS + row_id * VOCAB_SIZE + PROBS_2_ROW = PROBS_2 + row_id * VOCAB_SIZE + IDX_ROW = idx_tensor + row_id * VOCAB_SIZE - sum_logits = 0.0 - sum_sq_logits = 0.0 + search_address = PROBS_ROW + search_range = VOCAB_SIZE + search_iters = NUM_TILES + + max_logit = -float('inf') + min_logit = float('inf') + + force_remove_logit = -float('inf') # for handling duplicate cases (edge case) + num_force_remove = tl.zeros((), dtype=tl.uint32) + + # First Pass: Compute avg and std from a sample block + offs = tl.arange(0, BLOCK_SIZE) + mask_n = offs < VOCAB_SIZE + logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) + avg_logit = tl.sum(logits_blk) / VOCAB_SIZE + sq_avg_logit = tl.sum(logits_blk * logits_blk) / VOCAB_SIZE + std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) + + outlier_pivot = avg_logit + SIGMA * std_logit + num_outliers = tl.zeros((), dtype=tl.uint32) + sum_outlier_probs = 0.0 + sum_exp_logits = 0.0 + + # ====== Second Pass: compute max and min logits ====== + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + logits_blk = tl.load(LOGITS_ROW + offs_n, + mask=mask_n, + other=avg_logit) + max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + + # ====== Third pass: Calculate exp logits and sum ====== + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + + probs_blk = tl.load(LOGITS_ROW + offs_n, + mask=mask_n, + other=-float('inf')) + probs_blk = probs_blk - max_logit + probs_blk = tl.exp(probs_blk) + sum_exp_logits += tl.sum(probs_blk) + tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) + outlier_prob = tl.exp(outlier_pivot - max_logit) / sum_exp_logits + + # ====== Fourth pass: Calculate 
probs and get outliers ====== + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + + probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) + probs_blk = probs_blk / sum_exp_logits + tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) + + outlier_mask = (probs_blk > outlier_prob) & mask_n + sum_outlier_probs += tl.sum(outlier_mask * probs_blk) + num_blk_outliers = tl.sum(outlier_mask) + cumulative_pos = tl.cast( + tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + num_outliers += num_blk_outliers + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(PROBS_2_ROW + write_pos, probs_blk, mask=outlier_mask) # stores the final probs after masking to PROBS_2 + + max_range = tl.exp(max_logit - max_logit) / sum_exp_logits + min_range = tl.exp(min_logit - max_logit) / sum_exp_logits + + if sum_outlier_probs > p: + min_range = outlier_prob + search_addr = PROBS_2_ROW + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) + + second_max_logit = -float('inf') + + num_iters = 0 + p_pivots_sum_0 = 0.0 + min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + # ====== Fifth Passes: Search for p_pivot(2log_2(n)) ====== + while p_pivot == -float('inf') and num_iters < 32: + p_pivot_0 = (max_range - min_range) * 1.0 / 2.0 + min_range + p_pivots_sum_0 = 0.0 + + min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + probs_blk = tl.load(search_address + offs_n, + mask=mask_n, + other=0.0) + + masked_larger_0 = tl.where(probs_blk > p_pivot_0, + probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, + tl.min(masked_larger_0)) + + p_pivots_sum_0 += tl.sum(probs_blk * + (probs_blk > p_pivot_0)) + + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE 
+ tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + probs_blk = tl.load(search_address + offs_n, + mask=mask_n, + other=0.0) + + num_min_larger_0 += tl.sum( + tl.abs(probs_blk - min_larger_0) < 1e-7) + + # Check if any of the pivots are equal to k + if p_pivots_sum_0 >= p: + if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: + p_pivot = p_pivot_0 + else: + min_range = p_pivot_0 + else: + max_range = p_pivot_0 + + num_iters += 1 + if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-8: + p_pivot = p_pivot_0 + + # At least one value should be greater than p_pivot + if p_pivot >= max_logit: + p_pivot = second_max_logit + elif num_min_larger_0 > 1: + # Force remove duplicates (p_pivot is made to include all + # duplicates if it falls on the duplicates) + num_force_remove = tl.cast((p_pivots_sum_0 - p) / min_larger_0, + tl.uint32) + force_remove_logit = tl.log( + min_larger_0 * sum_exp_logits) + max_logit + + p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit + p_pivot_logit = -float('inf') + + # -------- widen cutoff by 10% ---------------- + if p_pivot != -float('inf'): + # WIDEN_NUM (e.g., 90) / 100.0 = 0.9 + widened_prob = p_pivot * (WIDEN_NUM / 100.0) + # clamp widened_prob to <= max possible prob + widened_prob = tl.minimum(widened_prob, max_range) + p_pivot_logit = tl.log(widened_prob * sum_exp_logits) + max_logit + + current_num_force_remove = tl.zeros((), dtype=tl.uint32) + kept_write_pos = 0 + + if p_pivot != -float('inf'): + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk = tl.load(LOGITS_ROW + offs_n, + mask=mask_n, + other=-float('inf')) + + if force_remove_logit != -float('inf'): + # Force remove duplicates + tolerance = 1e-5 * tl.maximum(1.0, tl.abs(force_remove_logit)) + force_remove_mask = tl.abs( + logits_blk - force_remove_logit) < tolerance + force_remove_count = tl.cumsum(force_remove_mask) + current_num_force_remove + force_remove_count_mask = 
force_remove_count <= num_force_remove + force_remove_mask = force_remove_count_mask & force_remove_mask + logits_blk = tl.where(force_remove_mask, -float('inf'), logits_blk) + current_num_force_remove = tl.max(force_remove_count) - for i in range(NUM_TILES): - offs = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offs < VOCAB_SIZE - vals = tl.load(LOGITS_ROW + offs, mask=mask, other=0.0) - sum_logits += tl.sum(vals, where=mask) - sum_sq_logits += tl.sum(vals * vals, where=mask) - - mean = sum_logits / VOCAB_SIZE - var = sum_sq_logits / VOCAB_SIZE - mean * mean - std = tl.sqrt(tl.maximum(var, 0.0)) - threshold = mean + SIGMA * std - - count = 0 - for i in range(NUM_TILES): - offs = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offs < VOCAB_SIZE - vals = tl.load(LOGITS_ROW + offs, mask=mask, other=0.0) - keep_mask = vals > threshold - - # Write filtered logits - out_vals = tl.where(keep_mask, vals, -float("inf")) - tl.store(OUT_ROW + offs, out_vals, mask=mask) - - # Write kept indices contiguously - new_idx = tl.where(keep_mask, offs + i * BLOCK_SIZE, -1) - kept_idx = new_idx[keep_mask] - num_kept = tl.sum(keep_mask, where=mask) - - # store valid indices sequentially - if num_kept > 0: - write_offs = count + tl.arange(0, num_kept) - tl.store(IDX_ROW + write_offs, kept_idx) - count += num_kept - - # Record number of kept logits - tl.store(L + row_id, count) - - -def apply_top_p_filtered( + # Apply widened cutoff + keep_mask = logits_blk > p_pivot_logit + out_vals = tl.where(keep_mask, logits_blk, -float('inf')) + tl.store(LOGITS_ROW + offs_n, out_vals, mask=mask_n) + + # ====== keeping track of L and indx_tensor ====== + n_kept = tl.sum(keep_mask, dtype=tl.int32) + if n_kept > 0: + cpos = tl.cast(tl.cumsum(keep_mask) - 1 + kept_write_pos, tl.int32) + tl.store(IDX_ROW + cpos, offs_n, mask=keep_mask) + kept_write_pos += n_kept + tl.store(L + row_id, tl.cast(kept_write_pos, tl.int32)) + else: + IDX_ROW = idx_tensor + row_id * VOCAB_SIZE + LOGITS_ROW = LOGITS + 
row_id * VOCAB_SIZE + kept_write_pos = 0 + + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) + tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) + + n_kept = tl.sum(mask_n, dtype=tl.int32) + if n_kept > 0: + cpos = tl.cast(tl.cumsum(mask_n) - 1 + kept_write_pos, tl.int32) + tl.store(IDX_ROW + cpos, offs_n, mask=mask_n) + kept_write_pos += n_kept + + tl.store(L + row_id, tl.cast(kept_write_pos, tl.int32)) + +def apply_top_p_filtered ( logits: torch.Tensor, - k: torch.Tensor, + p: torch.Tensor, ) -> torch.Tensor: - - # Applies top-p using filtering - + """ + Applies top-p using pivot-based filtering + """ batch_size, vocab_size = logits.shape - + original_logits = logits.clone() probs = torch.empty_like(logits) + probs_2 = torch.full_like(probs, -float('inf'), device=logits.device) l = torch.zeros((batch_size,), device=logits.device, dtype=torch.int32) idx_tensor = torch.empty_like(logits, dtype=torch.int) BLOCK_SIZE = 1024 SIGMA = 2.0 + WIDEN_NUM = 90 grid = lambda meta: ((batch_size + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'], ) - top_p_filter_triton[grid]( + top_p_pivot_filter[grid]( logits, probs, + probs_2, l, idx_tensor, - k, + p, batch_size, SIGMA=SIGMA, VOCAB_SIZE=vocab_size, BLOCK_SIZE=BLOCK_SIZE, + WIDEN_NUM=WIDEN_NUM ) - max_l = torch.max(l).item() - filtered_logits = probs[:, :max_l] - logits = apply_top_k_only(logits, k) - return logits + max_l = torch.max(l) + + if max_l.item() == 0: + return torch.full((batch_size, vocab_size), -float('inf'), device=logits.device) + + outliers = idx_tensor[:, :max_l] + + # --- FIX 5: Gather from the *original* logits clone --- + filtered_logits = torch.gather(original_logits, 1, outliers) # shape [B, max_l] + + mask = torch.arange(max_l, device=logits.device).expand(batch_size, -1) < l.unsqueeze(1) + padded_logits = torch.where(mask, filtered_logits, -float('inf')) + 
+ # ====== sort the filtered top p ====== + logits_sort, sort_indices = torch.sort(padded_logits, dim=-1, descending=True) + logits_idx_sorted = torch.gather(outliers, 1, sort_indices) + + if torch.any(p < 1.0): + probs_sort = logits_sort.softmax(dim=-1) # Use dim=-1 for safety + probs_sum = torch.cumsum(probs_sort, dim=-1) + + top_p_mask = probs_sum > p.unsqueeze(dim=1) + top_p_mask[:, 0] = False + logits_sort.masked_fill_(top_p_mask, -float("inf")) + + final_logits = torch.full_like(logits, -float("inf")) + final_logits.scatter_(dim=1, index=logits_idx_sorted, src=logits_sort) + return final_logits + def apply_top_k_top_p_test2( logits: torch.Tensor, k: torch.Tensor | None, p: torch.Tensor | None, ) -> torch.Tensor: - - # Filter out the outliers - + """ + Uses pivot-based algorithm to filter --> sort + """ if p is None: if k is None: return logits return apply_top_k_only(logits, k) - # Apply top-k filter first if needed - if k is not None: - logits = apply_top_k_only(logits, k) - - # Apply top-p using binary search (no sort!) 
- logits = apply_top_p_filtered(logits, p)""" + top_p_logits = apply_top_p_filtered(logits.clone(), p) + if k is None: + return top_p_logits + top_k_logits = apply_top_k_only(logits, k) + return torch.maximum(top_p_logits, top_k_logits) \ No newline at end of file From d8fac6a2a325d547752cdd4878f38e48e4f0218c Mon Sep 17 00:00:00 2001 From: Sunga Kim Date: Sun, 26 Oct 2025 15:00:27 -0700 Subject: [PATCH 39/99] init commit --- compare.py | 20 +-- vllm/v1/sample/ops/topk_topp_sampler.py | 192 +++++++++--------------- 2 files changed, 82 insertions(+), 130 deletions(-) diff --git a/compare.py b/compare.py index 04f5e5676ea2..1e92577156c5 100644 --- a/compare.py +++ b/compare.py @@ -103,7 +103,7 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): vocab_size_list = [4096, 16384] p_list = [None, "RAND", 0.4, 0.7, 0.9, 0.95, 0.99] k_list = [None, "RAND", 5, 10, 50, 100, 200, 300, 3000] - func_list = [apply_top_k_top_p, apply_top_k_top_p_triton, apply_top_k_top_p_test2] + func_list = [apply_top_k_top_p, apply_top_k_top_p_test2] log_file = f"triton_topk_topp_test_{date_str}.log" csv_file = f"triton_topk_topp_test_{date_str}.csv" @@ -169,20 +169,20 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): print_to_log( b_str("torch_time_taken: ") + f"{time_list[0]}", log_file) print_to_log( - b_str("triton_time_taken: ") + f"{time_list[1]}", + b_str("test_time_taken: ") + f"{time_list[1]}", log_file) + # print_to_log( + # b_str("test_time_taken: ") + f"{time_list[2]}", log_file) print_to_log( - b_str("test_time_taken: ") + f"{time_list[2]}", log_file) - print_to_log( - g_str("Triton Speedup over Torch: ") + + g_str("test Speedup over Torch: ") + f"{time_list[0] / time_list[1]:.8f}x", log_file) - print_to_log( - y_str("Test Speedup over Torch: ") + - f"{time_list[0] / time_list[2]:.8f}x", log_file) + # print_to_log( + # y_str("Test Speedup over Torch: ") + + # f"{time_list[0] / time_list[2]:.8f}x", log_file) with open(csv_file, "a") as f: 
f.write(f"{dist_generator},{batch_size},{vocab_size},{p},{k}," - f"{correct_list[0]},{correct_list[1]},{time_list[0]},{time_list[1]},{time_list[2]}," - f"{time_list[0] / time_list[1]:.8f}, {time_list[0] / time_list[2]:.8f}\n") + f"{correct_list[0]},{time_list[0]}," + f"{time_list[0] / time_list[1]:.8f}\n") print_to_log(y_str("--------------------------------\n"), log_file) """# SPDX-License-Identifier: Apache-2.0 diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index d74d1bc30e5a..e59de16a8a1c 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -894,6 +894,8 @@ def apply_top_k_only( The logits tensor may be updated in-place. """ + if k is None: + return logits max_top_k = k.max().item() # --- FIX: Handle k=0 edge case --- @@ -1233,46 +1235,7 @@ def apply_top_k_top_p_test( # Apply top-p using binary search (no sort!) return apply_top_p_sorted_equivalent(logits, p) - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +# -------------------------------------------------------------------------------------- @triton.jit def top_p_pivot_filter(LOGITS, L, PROBS, PROBS_2, idx_tensor, P, B, SIGMA:tl.constexpr, VOCAB_SIZE:tl.constexpr, BLOCK_SIZE:tl.constexpr, WIDEN_NUM: tl.constexpr): @@ -1360,7 +1323,7 @@ def top_p_pivot_filter(LOGITS, L, PROBS, PROBS_2, idx_tensor, P, B, SIGMA:tl.con if sum_outlier_probs > p: min_range = outlier_prob - search_addr = PROBS_2_ROW + search_address = PROBS_2_ROW search_range = tl.cast(num_outliers, tl.int32) search_iters = tl.cast( (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) @@ -1429,7 +1392,7 @@ def top_p_pivot_filter(LOGITS, L, PROBS, PROBS_2, idx_tensor, P, B, SIGMA:tl.con force_remove_logit = tl.log( min_larger_0 * sum_exp_logits) + max_logit - p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit + # p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit p_pivot_logit = -float('inf') # -------- widen 
cutoff by 10% ---------------- @@ -1443,57 +1406,44 @@ def top_p_pivot_filter(LOGITS, L, PROBS, PROBS_2, idx_tensor, P, B, SIGMA:tl.con current_num_force_remove = tl.zeros((), dtype=tl.uint32) kept_write_pos = 0 - if p_pivot != -float('inf'): - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load(LOGITS_ROW + offs_n, - mask=mask_n, - other=-float('inf')) - - if force_remove_logit != -float('inf'): - # Force remove duplicates - tolerance = 1e-5 * tl.maximum(1.0, tl.abs(force_remove_logit)) - force_remove_mask = tl.abs( - logits_blk - force_remove_logit) < tolerance - force_remove_count = tl.cumsum(force_remove_mask) + current_num_force_remove - force_remove_count_mask = force_remove_count <= num_force_remove - force_remove_mask = force_remove_count_mask & force_remove_mask - logits_blk = tl.where(force_remove_mask, -float('inf'), logits_blk) - current_num_force_remove = tl.max(force_remove_count) - - # Apply widened cutoff - keep_mask = logits_blk > p_pivot_logit - out_vals = tl.where(keep_mask, logits_blk, -float('inf')) - tl.store(LOGITS_ROW + offs_n, out_vals, mask=mask_n) - - # ====== keeping track of L and indx_tensor ====== - n_kept = tl.sum(keep_mask, dtype=tl.int32) - if n_kept > 0: - cpos = tl.cast(tl.cumsum(keep_mask) - 1 + kept_write_pos, tl.int32) - tl.store(IDX_ROW + cpos, offs_n, mask=keep_mask) - kept_write_pos += n_kept - tl.store(L + row_id, tl.cast(kept_write_pos, tl.int32)) - else: - IDX_ROW = idx_tensor + row_id * VOCAB_SIZE - LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE - kept_write_pos = 0 - for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE + logits_blk = tl.load(LOGITS_ROW + offs_n, + mask=mask_n, + other=-float('inf')) + # probs_blk = tl.load(PROBS_ROW + offs_n, + # mask=mask_n, + # other=0.0) - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) - tl.store(LOGITS_ROW + offs_n, 
logits_blk, mask=mask_n) + if force_remove_logit != -float('inf'): + # Force remove duplicates + tolerance = 1e-5 * tl.maximum(1.0, tl.abs(force_remove_logit)) + force_remove_mask = tl.abs( + logits_blk - force_remove_logit) < tolerance + force_remove_count = tl.cumsum(force_remove_mask) + current_num_force_remove + force_remove_count_mask = force_remove_count <= num_force_remove + force_remove_mask = force_remove_count_mask & force_remove_mask + logits_blk = tl.where(force_remove_mask, -float('inf'), logits_blk) + current_num_force_remove = tl.max(force_remove_count) + + # Apply widened cutoff + keep_mask = logits_blk > p_pivot_logit + out_vals = tl.where(keep_mask, logits_blk, -float('inf')) + tl.store(LOGITS_ROW + offs_n, out_vals, mask=mask_n) - n_kept = tl.sum(mask_n, dtype=tl.int32) + # ====== keeping track of L and idx_tensor ====== + n_kept = tl.sum(keep_mask, dtype=tl.int32) if n_kept > 0: - cpos = tl.cast(tl.cumsum(mask_n) - 1 + kept_write_pos, tl.int32) - tl.store(IDX_ROW + cpos, offs_n, mask=mask_n) + cpos = tl.cast(tl.cumsum(keep_mask) - 1 + kept_write_pos, tl.int32) + write_idx = tl.where(keep_mask, cpos, 0) + + tl.store(IDX_ROW + write_idx, offs_n, mask=keep_mask) + # tl.store(PROBS_2_ROW + write_idx, probs_blk, mask=keep_mask) + kept_write_pos += n_kept - tl.store(L + row_id, tl.cast(kept_write_pos, tl.int32)) - + def apply_top_p_filtered ( logits: torch.Tensor, p: torch.Tensor, @@ -1502,22 +1452,28 @@ def apply_top_p_filtered ( Applies top-p using pivot-based filtering """ batch_size, vocab_size = logits.shape - original_logits = logits.clone() - probs = torch.empty_like(logits) - probs_2 = torch.full_like(probs, -float('inf'), device=logits.device) + + probs = torch.full((batch_size, vocab_size), -float('inf'), device=logits.device) + probs_2 = torch.full((batch_size, vocab_size), -float('inf'), device=logits.device) + l = torch.zeros((batch_size,), device=logits.device, dtype=torch.int32) - idx_tensor = torch.empty_like(logits, dtype=torch.int) + 
idx_tensor = torch.full_like(logits, 0, dtype=torch.int32) - BLOCK_SIZE = 1024 - SIGMA = 2.0 - WIDEN_NUM = 90 + BLOCK_SIZE = 2048 + SIGMA = 2.15 + NUM_WARPS = 16 + NUM_STAGES = 3 + WIDEN_NUM = 120 + + if not torch.any(p < 1.0): + return logits grid = lambda meta: ((batch_size + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'], ) top_p_pivot_filter[grid]( logits, + l, probs, probs_2, - l, idx_tensor, p, batch_size, @@ -1533,28 +1489,24 @@ def apply_top_p_filtered ( return torch.full((batch_size, vocab_size), -float('inf'), device=logits.device) outliers = idx_tensor[:, :max_l] - - # --- FIX 5: Gather from the *original* logits clone --- - filtered_logits = torch.gather(original_logits, 1, outliers) # shape [B, max_l] - - mask = torch.arange(max_l, device=logits.device).expand(batch_size, -1) < l.unsqueeze(1) - padded_logits = torch.where(mask, filtered_logits, -float('inf')) + filtered_logits = torch.gather(logits, 1, outliers) - # ====== sort the filtered top p ====== - logits_sort, sort_indices = torch.sort(padded_logits, dim=-1, descending=True) - logits_idx_sorted = torch.gather(outliers, 1, sort_indices) + filtered_logits_sort, sort_indices = torch.sort( + filtered_logits, dim=-1, descending=False + ) + outliers_sorted = torch.gather(outliers, 1, sort_indices) + filtered_probs_sort = filtered_logits_sort.softmax(dim=-1) - if torch.any(p < 1.0): - probs_sort = logits_sort.softmax(dim=-1) # Use dim=-1 for safety - probs_sum = torch.cumsum(probs_sort, dim=-1) + probs_sum = torch.cumsum(filtered_probs_sort, dim=-1, out=filtered_probs_sort) + top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) + top_p_mask[:, -1] = False - top_p_mask = probs_sum > p.unsqueeze(dim=1) - top_p_mask[:, 0] = False - logits_sort.masked_fill_(top_p_mask, -float("inf")) - - final_logits = torch.full_like(logits, -float("inf")) - final_logits.scatter_(dim=1, index=logits_idx_sorted, src=logits_sort) - return final_logits + filtered_logits_sort.masked_fill_(top_p_mask, -float("inf")) + + 
logits.fill_(-float("inf")) + logits.scatter_(dim=1, index=outliers_sorted, src=filtered_logits_sort) + + return logits def apply_top_k_top_p_test2( @@ -1565,12 +1517,12 @@ def apply_top_k_top_p_test2( """ Uses pivot-based algorithm to filter --> sort """ - if p is None: - if k is None: - return logits + if k is None and p is None: + return logits + elif p is None and k is not None: return apply_top_k_only(logits, k) - top_p_logits = apply_top_p_filtered(logits.clone(), p) - if k is None: - return top_p_logits - top_k_logits = apply_top_k_only(logits, k) - return torch.maximum(top_p_logits, top_k_logits) \ No newline at end of file + elif k is None and p is not None: + return apply_top_p_filtered(logits, p) + else: + logits_p = apply_top_p_filtered(logits, p) + return apply_top_k_only(logits_p, k) \ No newline at end of file From 1d349d3109f7d4582fe2f62dcaa1e3f28d676720 Mon Sep 17 00:00:00 2001 From: Sunga Kim Date: Sun, 26 Oct 2025 17:12:31 -0700 Subject: [PATCH 40/99] not working......... 
--- compare.py | 2 +- vllm/v1/sample/ops/topk_topp_sampler.py | 68 ++++++++++++++++++++----- 2 files changed, 57 insertions(+), 13 deletions(-) diff --git a/compare.py b/compare.py index 1e92577156c5..091695461f87 100644 --- a/compare.py +++ b/compare.py @@ -70,7 +70,7 @@ def test_accuracy(logits, k, p, func_list): torch.sort(original_logits[error_rows], descending=True).values original_logits_to_show = original_logits_to_show[:row_to_show, :20] print_to_log(f"original_logits: {original_logits_to_show}", log_file) - + assert False return output_correct_list diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index e59de16a8a1c..3123cd145f1d 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -1397,7 +1397,7 @@ def top_p_pivot_filter(LOGITS, L, PROBS, PROBS_2, idx_tensor, P, B, SIGMA:tl.con # -------- widen cutoff by 10% ---------------- if p_pivot != -float('inf'): - # WIDEN_NUM (e.g., 90) / 100.0 = 0.9 + # WIDEN_NUM widened_prob = p_pivot * (WIDEN_NUM / 100.0) # clamp widened_prob to <= max possible prob widened_prob = tl.minimum(widened_prob, max_range) @@ -1412,9 +1412,9 @@ def top_p_pivot_filter(LOGITS, L, PROBS, PROBS_2, idx_tensor, P, B, SIGMA:tl.con logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) - # probs_blk = tl.load(PROBS_ROW + offs_n, - # mask=mask_n, - # other=0.0) + probs_blk = tl.load(PROBS_ROW + offs_n, + mask=mask_n, + other=0.0) if force_remove_logit != -float('inf'): # Force remove duplicates @@ -1439,7 +1439,7 @@ def top_p_pivot_filter(LOGITS, L, PROBS, PROBS_2, idx_tensor, P, B, SIGMA:tl.con write_idx = tl.where(keep_mask, cpos, 0) tl.store(IDX_ROW + write_idx, offs_n, mask=keep_mask) - # tl.store(PROBS_2_ROW + write_idx, probs_blk, mask=keep_mask) + tl.store(PROBS_2_ROW + write_idx, probs_blk, mask=keep_mask) kept_write_pos += n_kept tl.store(L + row_id, tl.cast(kept_write_pos, tl.int32)) @@ -1451,19 +1451,26 @@ def 
apply_top_p_filtered ( """ Applies top-p using pivot-based filtering """ + # logits = torch.ones((10,), device=logits.device, dtype=torch.float32).view(1, -1) + logits_copy = logits.clone().detach() + # p = torch.full((logits.shape[0],), 0.65, dtype=torch.float32, device=logits.device) + # output = apply_top_k_top_p(logits_copy, None, p) + # print(f"original value = {output}") batch_size, vocab_size = logits.shape + # print(f"logits: {logits}", flush=True) probs = torch.full((batch_size, vocab_size), -float('inf'), device=logits.device) probs_2 = torch.full((batch_size, vocab_size), -float('inf'), device=logits.device) l = torch.zeros((batch_size,), device=logits.device, dtype=torch.int32) idx_tensor = torch.full_like(logits, 0, dtype=torch.int32) + BLOCK_SIZE = 2048 SIGMA = 2.15 NUM_WARPS = 16 NUM_STAGES = 3 - WIDEN_NUM = 120 + WIDEN_NUM = 0 if not torch.any(p < 1.0): return logits @@ -1482,30 +1489,67 @@ def apply_top_p_filtered ( BLOCK_SIZE=BLOCK_SIZE, WIDEN_NUM=WIDEN_NUM ) + # print(f"logits: {logits}", flush=True) + # print(f"l: {l}", flush=True) + # print(f"probs: {probs}", flush=True) + # print(f"probs_2: {probs_2}", flush=True) + # print(f"idx_tensor: {idx_tensor}", flush=True) max_l = torch.max(l) - if max_l.item() == 0: - return torch.full((batch_size, vocab_size), -float('inf'), device=logits.device) + # if max_l.item() == 0: + # return torch.full((batch_size, vocab_size), -float('inf'), device=logits.device) + + + # outliers = torch.arange(0, vocab_size, dtype=torch.int32, device=logits.device).unsqueeze(0).expand(logits.shape[0], -1) + # probs = torch.softmax(logits, dim=1) outliers = idx_tensor[:, :max_l] + full_probs = logits_copy.softmax(dim=-1) + filtered_full_probs = torch.gather(full_probs, 1, outliers) filtered_logits = torch.gather(logits, 1, outliers) + # print(f"outliers: {outliers}", flush=True) + # print(f"filtered_logits: {filtered_logits}", flush=True) + + probs = torch.gather(probs, 1, outliers) + cum_sum = torch.sum(probs, dim=1) + + # 
print(f"sum : {cum_sum}", flush=True) + filtered_logits_sort, sort_indices = torch.sort( filtered_logits, dim=-1, descending=False ) + + + + # print(f"filtered_logits_sort: {filtered_logits_sort}", flush=True) + outliers_sorted = torch.gather(outliers, 1, sort_indices) - filtered_probs_sort = filtered_logits_sort.softmax(dim=-1) + filtered_probs_sort = torch.gather(filtered_full_probs, 1, sort_indices) + + # print(f"outliers_sorted: {outliers_sorted}", flush=True) + # print(f"filtered_probs_sort: {filtered_logits_sort}", flush=True) + probs_sum = torch.cumsum(filtered_probs_sort, dim=-1, out=filtered_probs_sort) top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) top_p_mask[:, -1] = False + # print(f"probs_sum: {probs_sum}") + # print(f"top_p_mask = {top_p_mask}") + filtered_logits_sort.masked_fill_(top_p_mask, -float("inf")) + # print(f"top_p_mask = {filtered_logits_sort}") + + logits.fill_(-float("inf")) logits.scatter_(dim=1, index=outliers_sorted, src=filtered_logits_sort) - + + # print(f"final logits after scatter = {logits}") + + # assert False return logits @@ -1524,5 +1568,5 @@ def apply_top_k_top_p_test2( elif k is None and p is not None: return apply_top_p_filtered(logits, p) else: - logits_p = apply_top_p_filtered(logits, p) - return apply_top_k_only(logits_p, k) \ No newline at end of file + logits_k = apply_top_k_only(logits, k) + return apply_top_p_filtered(logits, p) \ No newline at end of file From b9a0c05304dd6ec87d43ca3579c0aa0c17da4387 Mon Sep 17 00:00:00 2001 From: Sunga Kim Date: Sun, 26 Oct 2025 17:19:58 -0700 Subject: [PATCH 41/99] working on it.... 
--- compare.py | 3 ++- vllm/v1/sample/ops/topk_topp_sampler.py | 20 ++++++++++---------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/compare.py b/compare.py index 091695461f87..fd51b473e210 100644 --- a/compare.py +++ b/compare.py @@ -102,7 +102,8 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): batch_size_list = [64, 128, 1024] vocab_size_list = [4096, 16384] p_list = [None, "RAND", 0.4, 0.7, 0.9, 0.95, 0.99] - k_list = [None, "RAND", 5, 10, 50, 100, 200, 300, 3000] + # k_list = [None, "RAND", 5, 10, 50, 100, 200, 300, 3000] + k_list = [None] func_list = [apply_top_k_top_p, apply_top_k_top_p_test2] log_file = f"triton_topk_topp_test_{date_str}.log" diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 3123cd145f1d..14b94f582b12 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -1416,16 +1416,16 @@ def top_p_pivot_filter(LOGITS, L, PROBS, PROBS_2, idx_tensor, P, B, SIGMA:tl.con mask=mask_n, other=0.0) - if force_remove_logit != -float('inf'): - # Force remove duplicates - tolerance = 1e-5 * tl.maximum(1.0, tl.abs(force_remove_logit)) - force_remove_mask = tl.abs( - logits_blk - force_remove_logit) < tolerance - force_remove_count = tl.cumsum(force_remove_mask) + current_num_force_remove - force_remove_count_mask = force_remove_count <= num_force_remove - force_remove_mask = force_remove_count_mask & force_remove_mask - logits_blk = tl.where(force_remove_mask, -float('inf'), logits_blk) - current_num_force_remove = tl.max(force_remove_count) + # if force_remove_logit != -float('inf'): + # # Force remove duplicates + # tolerance = 1e-5 * tl.maximum(1.0, tl.abs(force_remove_logit)) + # force_remove_mask = tl.abs( + # logits_blk - force_remove_logit) < tolerance + # force_remove_count = tl.cumsum(force_remove_mask) + current_num_force_remove + # force_remove_count_mask = force_remove_count <= num_force_remove + # force_remove_mask 
= force_remove_count_mask & force_remove_mask + # logits_blk = tl.where(force_remove_mask, -float('inf'), logits_blk) + # current_num_force_remove = tl.max(force_remove_count) # Apply widened cutoff keep_mask = logits_blk > p_pivot_logit From 71c59786f9d5701c43664e7b6988816a089e3566 Mon Sep 17 00:00:00 2001 From: Sunga Kim Date: Sun, 26 Oct 2025 17:32:30 -0700 Subject: [PATCH 42/99] working........python compare.py --- vllm/v1/sample/ops/topk_topp_sampler.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 14b94f582b12..23e4bfa56d24 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -1382,9 +1382,9 @@ def top_p_pivot_filter(LOGITS, L, PROBS, PROBS_2, idx_tensor, P, B, SIGMA:tl.con p_pivot = p_pivot_0 # At least one value should be greater than p_pivot - if p_pivot >= max_logit: - p_pivot = second_max_logit - elif num_min_larger_0 > 1: + # if p_pivot >= max_logit: + # p_pivot = second_max_logit + if True: # Force remove duplicates (p_pivot is made to include all # duplicates if it falls on the duplicates) num_force_remove = tl.cast((p_pivots_sum_0 - p) / min_larger_0, @@ -1470,7 +1470,7 @@ def apply_top_p_filtered ( SIGMA = 2.15 NUM_WARPS = 16 NUM_STAGES = 3 - WIDEN_NUM = 0 + WIDEN_NUM = 0 # ----------------------------> ???????? if not torch.any(p < 1.0): return logits @@ -1482,7 +1482,7 @@ def apply_top_p_filtered ( probs, probs_2, idx_tensor, - p, + p*1.1, batch_size, SIGMA=SIGMA, VOCAB_SIZE=vocab_size, From 115a98b3368337dfab0f32714e0bc4f7581d6cd8 Mon Sep 17 00:00:00 2001 From: Sunga Kim Date: Tue, 28 Oct 2025 10:19:31 -0700 Subject: [PATCH 43/99] ... 
--- vllm/v1/sample/ops/topk_topp_sampler.py | 32 +++++++++---------------- 1 file changed, 11 insertions(+), 21 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 23e4bfa56d24..90ffb3e066db 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -1238,7 +1238,7 @@ def apply_top_k_top_p_test( # -------------------------------------------------------------------------------------- @triton.jit -def top_p_pivot_filter(LOGITS, L, PROBS, PROBS_2, idx_tensor, P, B, SIGMA:tl.constexpr, VOCAB_SIZE:tl.constexpr, BLOCK_SIZE:tl.constexpr, WIDEN_NUM: tl.constexpr): +def top_p_pivot_filter(LOGITS, L, PROBS, PROBS_2, idx_tensor, P, B, SIGMA:tl.constexpr, VOCAB_SIZE:tl.constexpr, BLOCK_SIZE:tl.constexpr): NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE pid = tl.program_id(0) num_programs = tl.num_programs(0) @@ -1381,29 +1381,20 @@ def top_p_pivot_filter(LOGITS, L, PROBS, PROBS_2, idx_tensor, P, B, SIGMA:tl.con if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-8: p_pivot = p_pivot_0 - # At least one value should be greater than p_pivot - # if p_pivot >= max_logit: - # p_pivot = second_max_logit - if True: - # Force remove duplicates (p_pivot is made to include all - # duplicates if it falls on the duplicates) - num_force_remove = tl.cast((p_pivots_sum_0 - p) / min_larger_0, - tl.uint32) - force_remove_logit = tl.log( - min_larger_0 * sum_exp_logits) + max_logit + # Force remove duplicates (p_pivot is made to include all + # duplicates if it falls on the duplicates) + num_force_remove = tl.cast((p_pivots_sum_0 - p) / min_larger_0, + tl.uint32) + force_remove_logit = tl.log( + min_larger_0 * sum_exp_logits) + max_logit # p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit p_pivot_logit = -float('inf') # -------- widen cutoff by 10% ---------------- if p_pivot != -float('inf'): - # WIDEN_NUM - widened_prob = p_pivot * (WIDEN_NUM / 100.0) - # 
clamp widened_prob to <= max possible prob - widened_prob = tl.minimum(widened_prob, max_range) p_pivot_logit = tl.log(widened_prob * sum_exp_logits) + max_logit - - current_num_force_remove = tl.zeros((), dtype=tl.uint32) + # current_num_force_remove = tl.zeros((), dtype=tl.uint32) kept_write_pos = 0 for i in range(0, NUM_TILES): @@ -1465,16 +1456,16 @@ def apply_top_p_filtered ( l = torch.zeros((batch_size,), device=logits.device, dtype=torch.int32) idx_tensor = torch.full_like(logits, 0, dtype=torch.int32) - BLOCK_SIZE = 2048 SIGMA = 2.15 NUM_WARPS = 16 NUM_STAGES = 3 - WIDEN_NUM = 0 # ----------------------------> ???????? if not torch.any(p < 1.0): return logits + p_widened = torch.clamp(p * 1.2, max=0.999) + grid = lambda meta: ((batch_size + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'], ) top_p_pivot_filter[grid]( logits, @@ -1482,12 +1473,11 @@ def apply_top_p_filtered ( probs, probs_2, idx_tensor, - p*1.1, + p_widened, batch_size, SIGMA=SIGMA, VOCAB_SIZE=vocab_size, BLOCK_SIZE=BLOCK_SIZE, - WIDEN_NUM=WIDEN_NUM ) # print(f"logits: {logits}", flush=True) # print(f"l: {l}", flush=True) From f9b08f2232b0f0cfcd3138c82da90960874f3718 Mon Sep 17 00:00:00 2001 From: Sunga Kim Date: Tue, 28 Oct 2025 12:53:00 -0700 Subject: [PATCH 44/99] ... 
--- vllm/v1/sample/ops/topk_topp_sampler.py | 216 +++++++++++++++++++----- 1 file changed, 176 insertions(+), 40 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 90ffb3e066db..db1dac2ba502 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -1393,7 +1393,7 @@ def top_p_pivot_filter(LOGITS, L, PROBS, PROBS_2, idx_tensor, P, B, SIGMA:tl.con # -------- widen cutoff by 10% ---------------- if p_pivot != -float('inf'): - p_pivot_logit = tl.log(widened_prob * sum_exp_logits) + max_logit + p_pivot_logit = tl.log(p_pivot * sum_exp_logits) + max_logit # current_num_force_remove = tl.zeros((), dtype=tl.uint32) kept_write_pos = 0 @@ -1418,7 +1418,6 @@ def top_p_pivot_filter(LOGITS, L, PROBS, PROBS_2, idx_tensor, P, B, SIGMA:tl.con # logits_blk = tl.where(force_remove_mask, -float('inf'), logits_blk) # current_num_force_remove = tl.max(force_remove_count) - # Apply widened cutoff keep_mask = logits_blk > p_pivot_logit out_vals = tl.where(keep_mask, logits_blk, -float('inf')) tl.store(LOGITS_ROW + offs_n, out_vals, mask=mask_n) @@ -1442,16 +1441,11 @@ def apply_top_p_filtered ( """ Applies top-p using pivot-based filtering """ - # logits = torch.ones((10,), device=logits.device, dtype=torch.float32).view(1, -1) logits_copy = logits.clone().detach() - # p = torch.full((logits.shape[0],), 0.65, dtype=torch.float32, device=logits.device) - # output = apply_top_k_top_p(logits_copy, None, p) - # print(f"original value = {output}") batch_size, vocab_size = logits.shape - # print(f"logits: {logits}", flush=True) probs = torch.full((batch_size, vocab_size), -float('inf'), device=logits.device) - probs_2 = torch.full((batch_size, vocab_size), -float('inf'), device=logits.device) + probs_2 = torch.zeros((batch_size, vocab_size), device=logits.device) l = torch.zeros((batch_size,), device=logits.device, dtype=torch.int32) idx_tensor = torch.full_like(logits, 0, 
dtype=torch.int32) @@ -1464,7 +1458,7 @@ def apply_top_p_filtered ( if not torch.any(p < 1.0): return logits - p_widened = torch.clamp(p * 1.2, max=0.999) + p_widened = torch.clamp(p * 1.5, max=0.999) grid = lambda meta: ((batch_size + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'], ) top_p_pivot_filter[grid]( @@ -1479,67 +1473,81 @@ def apply_top_p_filtered ( VOCAB_SIZE=vocab_size, BLOCK_SIZE=BLOCK_SIZE, ) - # print(f"logits: {logits}", flush=True) - # print(f"l: {l}", flush=True) - # print(f"probs: {probs}", flush=True) - # print(f"probs_2: {probs_2}", flush=True) - # print(f"idx_tensor: {idx_tensor}", flush=True) max_l = torch.max(l) # if max_l.item() == 0: # return torch.full((batch_size, vocab_size), -float('inf'), device=logits.device) - - - # outliers = torch.arange(0, vocab_size, dtype=torch.int32, device=logits.device).unsqueeze(0).expand(logits.shape[0], -1) - # probs = torch.softmax(logits, dim=1) - + outliers = idx_tensor[:, :max_l] + + #this is for gather, which puts in garbage value for those not included. I wanted to mask it out. 
+ valid_mask = torch.arange(max_l, device=logits.device).unsqueeze(0) < l.unsqueeze(1) + print(f"p: {p[:3]}") + print(f"outliers shape: {outliers.shape}", flush=True) + print(f"outliers[:3]: {outliers[:3]}", flush=True) + print(f"l[:3]: {l[:3]}", flush=True) + print(f"valid mask: {valid_mask}") + full_probs = logits_copy.softmax(dim=-1) filtered_full_probs = torch.gather(full_probs, 1, outliers) filtered_logits = torch.gather(logits, 1, outliers) - # print(f"outliers: {outliers}", flush=True) - # print(f"filtered_logits: {filtered_logits}", flush=True) - probs = torch.gather(probs, 1, outliers) - cum_sum = torch.sum(probs, dim=1) + filtered_probs_sort = torch.where(valid_mask, filtered_full_probs, 0.0) + filtered_logits = torch.where(valid_mask, filtered_logits, torch.tensor(-float('inf'), device=logits.device)) - # print(f"sum : {cum_sum}", flush=True) - + print(f"filtered_probs[:3]: {filtered_probs_sort[:3]}", flush=True) + print(f"filtered_logits[:3]: {filtered_logits[:3]}", flush=True) filtered_logits_sort, sort_indices = torch.sort( filtered_logits, dim=-1, descending=False ) - - - # print(f"filtered_logits_sort: {filtered_logits_sort}", flush=True) - outliers_sorted = torch.gather(outliers, 1, sort_indices) - filtered_probs_sort = torch.gather(filtered_full_probs, 1, sort_indices) + filtered_probs_sort = torch.gather(filtered_probs_sort, 1, sort_indices) + valid_mask_sorted = torch.gather(valid_mask, 1, sort_indices) - # print(f"outliers_sorted: {outliers_sorted}", flush=True) - # print(f"filtered_probs_sort: {filtered_logits_sort}", flush=True) + print(f"outliers_sorted: {outliers_sorted}", flush=True) + print(f"filtered_probs_sort: {filtered_logits_sort}", flush=True) + probs_sum = torch.cumsum(filtered_probs_sort, dim=-1) - probs_sum = torch.cumsum(filtered_probs_sort, dim=-1, out=filtered_probs_sort) top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) + top_p_mask = top_p_mask | ~valid_mask_sorted top_p_mask[:, -1] = False - # print(f"probs_sum: 
{probs_sum}") - # print(f"top_p_mask = {top_p_mask}") + print(f"probs_sum: {probs_sum}") + print(f"top_p_mask = {top_p_mask}") + + print(f"p[0]: {p[0].item():.4f}") + print(f"1 - p[0]: {(1 - p[0]).item():.4f}") + print(f"Batch 0 probs_sum[-3:]: {probs_sum[0, -3:]}") + print(f"Batch 0 top_p_mask[-3:]: {top_p_mask[0, -3:]}") filtered_logits_sort.masked_fill_(top_p_mask, -float("inf")) - # print(f"top_p_mask = {filtered_logits_sort}") + print(f"top_p_mask = {filtered_logits_sort}") logits.fill_(-float("inf")) logits.scatter_(dim=1, index=outliers_sorted, src=filtered_logits_sort) - # print(f"final logits after scatter = {logits}") - - # assert False + print(f"final logits after scatter = {logits}") + + # ========================= error ============================================ + error_batches = [11, 20, 47, 50, 51, 59, 60, 61] + for i in error_batches[:3]: + print(f"\n=== Batch {i} ===") + print(f"l[{i}]: {l[i].item()}") + print(f"p[{i}]: {p[i].item():.4f}") + print(f"1-p[{i}]: {(1-p[i]).item():.4f}") + print(f"probs_sum[{i}, -5:]: {probs_sum[i, -5:]}") + print(f"filtered_probs_sort[{i}, -5:]: {filtered_probs_sort[i, -5:]}") + + prev_sum = probs_sum[i] - filtered_probs_sort[i] + print(f"prev_sum[{i}, -5:]: {prev_sum[-5:]}") + print(f"mask check: {(prev_sum > (1 - p[i]))[-5:]}") + print(f"top_p_mask[{i}, -5:]: {top_p_mask[i, -5:]}") return logits @@ -1559,4 +1567,132 @@ def apply_top_k_top_p_test2( return apply_top_p_filtered(logits, p) else: logits_k = apply_top_k_only(logits, k) - return apply_top_p_filtered(logits, p) \ No newline at end of file + return apply_top_p_filtered(logits, p) + + +# def apply_top_p_filtered ( +# logits: torch.Tensor, +# p: torch.Tensor, +# ) -> torch.Tensor: +# """ +# Applies top-p using pivot-based filtering +# """ +# # logits = torch.ones((10,), device=logits.device, dtype=torch.float32).view(1, -1) +# logits_copy = logits.clone().detach() +# # p = torch.full((logits.shape[0],), 0.65, dtype=torch.float32, device=logits.device) +# # 
output = apply_top_k_top_p(logits_copy, None, p) +# # print(f"original value = {output}") +# batch_size, vocab_size = logits.shape + +# # print(f"logits: {logits}", flush=True) +# probs = torch.full((batch_size, vocab_size), -float('inf'), device=logits.device) +# probs_2 = torch.full((batch_size, vocab_size), -float('inf'), device=logits.device) + +# l = torch.zeros((batch_size,), device=logits.device, dtype=torch.int32) +# idx_tensor = torch.full_like(logits, 0, dtype=torch.int32) + + +# BLOCK_SIZE = 2048 +# SIGMA = 2.15 +# NUM_WARPS = 16 +# NUM_STAGES = 3 +# WIDEN_NUM = 0 # ----------------------------> ???????? + +# if not torch.any(p < 1.0): +# return logits + +# grid = lambda meta: ((batch_size + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'], ) +# top_p_pivot_filter[grid]( +# logits, +# l, +# probs, +# probs_2, +# idx_tensor, +# p*1.1, +# batch_size, +# SIGMA=SIGMA, +# VOCAB_SIZE=vocab_size, +# BLOCK_SIZE=BLOCK_SIZE, +# WIDEN_NUM=WIDEN_NUM +# ) +# # print(f"logits: {logits}", flush=True) +# # print(f"l: {l}", flush=True) +# # print(f"probs: {probs}", flush=True) +# # print(f"probs_2: {probs_2}", flush=True) +# # print(f"idx_tensor: {idx_tensor}", flush=True) + +# max_l = torch.max(l) + +# # if max_l.item() == 0: +# # return torch.full((batch_size, vocab_size), -float('inf'), device=logits.device) + + +# # outliers = torch.arange(0, vocab_size, dtype=torch.int32, device=logits.device).unsqueeze(0).expand(logits.shape[0], -1) +# # probs = torch.softmax(logits, dim=1) + +# outliers = idx_tensor[:, :max_l] +# full_probs = logits_copy.softmax(dim=-1) +# filtered_full_probs = torch.gather(full_probs, 1, outliers) +# filtered_logits = torch.gather(logits, 1, outliers) +# # print(f"outliers: {outliers}", flush=True) +# # print(f"filtered_logits: {filtered_logits}", flush=True) + +# probs = torch.gather(probs, 1, outliers) +# cum_sum = torch.sum(probs, dim=1) + +# # print(f"sum : {cum_sum}", flush=True) + + +# filtered_logits_sort, sort_indices = torch.sort( +# 
filtered_logits, dim=-1, descending=False +# ) + + + +# # print(f"filtered_logits_sort: {filtered_logits_sort}", flush=True) + +# outliers_sorted = torch.gather(outliers, 1, sort_indices) +# filtered_probs_sort = torch.gather(filtered_full_probs, 1, sort_indices) + +# # print(f"outliers_sorted: {outliers_sorted}", flush=True) +# # print(f"filtered_probs_sort: {filtered_logits_sort}", flush=True) + + +# probs_sum = torch.cumsum(filtered_probs_sort, dim=-1, out=filtered_probs_sort) +# top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) +# top_p_mask[:, -1] = False + +# # print(f"probs_sum: {probs_sum}") +# # print(f"top_p_mask = {top_p_mask}") + +# filtered_logits_sort.masked_fill_(top_p_mask, -float("inf")) + +# # print(f"top_p_mask = {filtered_logits_sort}") + + +# logits.fill_(-float("inf")) +# logits.scatter_(dim=1, index=outliers_sorted, src=filtered_logits_sort) + +# # print(f"final logits after scatter = {logits}") + +# # assert False +# return logits + + +# def apply_top_k_top_p_test2( +# logits: torch.Tensor, +# k: torch.Tensor | None, +# p: torch.Tensor | None, +# ) -> torch.Tensor: +# """ +# Uses pivot-based algorithm to filter --> sort +# """ +# if k is None and p is None: +# return logits +# elif p is None and k is not None: +# return apply_top_k_only(logits, k) +# elif k is None and p is not None: +# return apply_top_p_filtered(logits, p) +# else: +# logits_k = apply_top_k_only(logits, k) +# return apply_top_p_filtered(logits, p) \ No newline at end of file From b8728dbc7da32276f0a1d6e152ed161e36f4587f Mon Sep 17 00:00:00 2001 From: Sunga Kim Date: Sat, 1 Nov 2025 15:24:10 -0700 Subject: [PATCH 45/99] slow but working --- compare.py | 16 ++- vllm/v1/sample/ops/topk_topp_sampler.py | 148 +++++++++++++++--------- 2 files changed, 110 insertions(+), 54 deletions(-) diff --git a/compare.py b/compare.py index fd51b473e210..0e7e959aebd4 100644 --- a/compare.py +++ b/compare.py @@ -2,10 +2,21 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project 
from datetime import datetime from itertools import product - -import regex as re import torch +torch.manual_seed(42) +torch.cuda.manual_seed(42) +torch.cuda.manual_seed_all(42) +import random +import numpy as np +random.seed(42) +np.random.seed(42) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + +import regex as re +import torch + from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p, apply_top_k_top_p_triton, apply_top_k_top_p_test2 @@ -51,6 +62,7 @@ def test_accuracy(logits, k, p, func_list): if not is_correct: print_to_log(r_str(f"Error: logits are not close on {i} - " + f"{func_name}"), log_file) + output_logits = apply_top_k_top_p_test2(logits, k, p, debug=True) error_mask = torch.abs(output_logits - original_logits) > 1e-5 error_rows = torch.where(error_mask)[0] error_rows = torch.unique(error_rows) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index db1dac2ba502..63b87323b099 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -1238,14 +1238,15 @@ def apply_top_k_top_p_test( # -------------------------------------------------------------------------------------- @triton.jit -def top_p_pivot_filter(LOGITS, L, PROBS, PROBS_2, idx_tensor, P, B, SIGMA:tl.constexpr, VOCAB_SIZE:tl.constexpr, BLOCK_SIZE:tl.constexpr): +def top_p_pivot_filter(LOGITS, L, MIN_IDX, PROBS, PROBS_2, idx_tensor, P, B, SIGMA:tl.constexpr, VOCAB_SIZE:tl.constexpr, BLOCK_SIZE:tl.constexpr): NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE pid = tl.program_id(0) num_programs = tl.num_programs(0) for row_id in tl.range(pid, B, num_programs): p = tl.load(P + row_id) # fetches the p value of the row it is working on - if p != 1.0: # if p == 1, this becomes pointless ! + # if p != 1.0: # if p == 1, this becomes pointless ! 
+ if True: p_pivot = -float('inf') LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE @@ -1259,6 +1260,7 @@ def top_p_pivot_filter(LOGITS, L, PROBS, PROBS_2, idx_tensor, P, B, SIGMA:tl.con max_logit = -float('inf') min_logit = float('inf') + min_logit_idx = -1 force_remove_logit = -float('inf') # for handling duplicate cases (edge case) num_force_remove = tl.zeros((), dtype=tl.uint32) @@ -1284,7 +1286,18 @@ def top_p_pivot_filter(LOGITS, L, PROBS, PROBS_2, idx_tensor, P, B, SIGMA:tl.con mask=mask_n, other=avg_logit) max_logit = tl.maximum(max_logit, tl.max(logits_blk)) - min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + logits_blk = tl.load(LOGITS_ROW + offs_n, + mask=mask_n, + other=float('inf')) + local_min_logit = tl.min(logits_blk) + local_min_logit_idx = tl.argmin(logits_blk, axis=0) + if local_min_logit < min_logit: + min_logit = local_min_logit + min_logit_idx = local_min_logit_idx + i * BLOCK_SIZE # ====== Third pass: Calculate exp logits and sum ====== for i in range(0, search_iters): @@ -1433,14 +1446,21 @@ def top_p_pivot_filter(LOGITS, L, PROBS, PROBS_2, idx_tensor, P, B, SIGMA:tl.con kept_write_pos += n_kept tl.store(L + row_id, tl.cast(kept_write_pos, tl.int32)) + tl.store(MIN_IDX + row_id, tl.cast(min_logit_idx, tl.int32)) + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n >= kept_write_pos + tl.store(IDX_ROW + offs_n, min_logit_idx, mask=mask_n) def apply_top_p_filtered ( logits: torch.Tensor, p: torch.Tensor, + debug: bool = False ) -> torch.Tensor: """ Applies top-p using pivot-based filtering """ + logits_copy = logits.clone().detach() batch_size, vocab_size = logits.shape @@ -1448,22 +1468,25 @@ def apply_top_p_filtered ( probs_2 = torch.zeros((batch_size, vocab_size), device=logits.device) l = torch.zeros((batch_size,), device=logits.device, dtype=torch.int32) - idx_tensor = 
torch.full_like(logits, 0, dtype=torch.int32) + min_idx = torch.zeros_like(l, dtype=torch.int32) + idx_tensor = torch.full_like(logits, -1, dtype=torch.int32) BLOCK_SIZE = 2048 SIGMA = 2.15 NUM_WARPS = 16 NUM_STAGES = 3 - if not torch.any(p < 1.0): - return logits + # if not torch.any(p < 1.0): + # return logits - p_widened = torch.clamp(p * 1.5, max=0.999) + p_widened = torch.clamp(p * 1.2, max=0.999) + # p_widened = torch.ones_like(p) grid = lambda meta: ((batch_size + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'], ) top_p_pivot_filter[grid]( logits, l, + min_idx, probs, probs_2, idx_tensor, @@ -1475,79 +1498,98 @@ def apply_top_p_filtered ( ) max_l = torch.max(l) + error_row = 859 if batch_size > 1000 else 2 + print(f"p = {p[error_row]}") # if max_l.item() == 0: # return torch.full((batch_size, vocab_size), -float('inf'), device=logits.device) outliers = idx_tensor[:, :max_l] - + #this is for gather, which puts in garbage value for those not included. I wanted to mask it out. valid_mask = torch.arange(max_l, device=logits.device).unsqueeze(0) < l.unsqueeze(1) - print(f"p: {p[:3]}") - print(f"outliers shape: {outliers.shape}", flush=True) - print(f"outliers[:3]: {outliers[:3]}", flush=True) - print(f"l[:3]: {l[:3]}", flush=True) - print(f"valid mask: {valid_mask}") + + print(f"\n=== STEP 1: After gathering outliers ===") + print(f"outliers[{error_row-2}:{error_row+2}, :10]: {outliers[error_row-2:error_row+2, :10]}") + print(f"outliers[{error_row-2}:{error_row+2}, -10:]: {outliers[error_row-2:error_row+2, -10:]}") full_probs = logits_copy.softmax(dim=-1) + filtered_full_probs = torch.gather(full_probs, 1, outliers) filtered_logits = torch.gather(logits, 1, outliers) filtered_probs_sort = torch.where(valid_mask, filtered_full_probs, 0.0) filtered_logits = torch.where(valid_mask, filtered_logits, torch.tensor(-float('inf'), device=logits.device)) - print(f"filtered_probs[:3]: {filtered_probs_sort[:3]}", flush=True) - print(f"filtered_logits[:3]: 
{filtered_logits[:3]}", flush=True) + print(f"\n=== STEP 2: After gathering logits ===") + print(f"filtered_logits[{error_row-2}:{error_row+2}, :10]: {filtered_logits[error_row-2:error_row+2, :10]}") + print(f"filtered_logits[{error_row-2}:{error_row+2}, -10:]: {filtered_logits[error_row-2:error_row+2, -10:]}") + + + non_outliers_mask = torch.ones_like(full_probs, dtype=torch.bool) + non_outliers_mask.scatter_(1, outliers, False) # False = pivoted tokens + sum_non_outliers = full_probs.masked_fill(~non_outliers_mask, 0.0).sum(dim=1) + print (f"min index shape = {min_idx.shape}") + pytorch_min_prob = torch.min(full_probs, dim=1).values + print (f"pytorch min prob = {pytorch_min_prob[error_row-2:error_row+2]}") + + sum_non_outliers += pytorch_min_prob + + + print (f"full_probs = {full_probs.sum(dim=1)[error_row-2:error_row+2]}") + print (f"sum non outliers = {sum_non_outliers[error_row-2:error_row+2]}") + print (f"min index probs = {full_probs[error_row-2:error_row+2, min_idx[error_row-2:error_row+2]]}") + + + # if vocab_size > 4000: + # print (f" min index probs = {full_probs[error_row-2:error_row+2, 1975]}") + filtered_logits_sort, sort_indices = torch.sort( filtered_logits, dim=-1, descending=False ) + print(f"\n=== STEP 3: After sorting ===") + print(f"filtered_logits_sort[{error_row-2}:{error_row+2}, -10:]: {filtered_logits_sort[error_row-2:error_row+2, -10:]}") + outliers_sorted = torch.gather(outliers, 1, sort_indices) + + print(f"\n=== STEP 4: outliers_sorted ===") + print(f"outliers_sorted[{error_row-2}:{error_row+2}, :1]: {outliers_sorted[error_row-2:error_row+2, :10]}") + filtered_probs_sort = torch.gather(filtered_probs_sort, 1, sort_indices) valid_mask_sorted = torch.gather(valid_mask, 1, sort_indices) - - print(f"outliers_sorted: {outliers_sorted}", flush=True) - print(f"filtered_probs_sort: {filtered_logits_sort}", flush=True) - probs_sum = torch.cumsum(filtered_probs_sort, dim=-1) - + probs_sum = sum_non_outliers.unsqueeze(1) + 
torch.cumsum(filtered_probs_sort, dim=-1) + print(f"========== probs sum ============= : {probs_sum[error_row-2:error_row+2, -10:]}") top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) top_p_mask = top_p_mask | ~valid_mask_sorted top_p_mask[:, -1] = False + print (f"top p mask {top_p_mask[error_row-2:error_row+2, -10:]}") + filtered_logits_sort.masked_fill_(top_p_mask, -float("inf")) - print(f"probs_sum: {probs_sum}") - print(f"top_p_mask = {top_p_mask}") - - print(f"p[0]: {p[0].item():.4f}") - print(f"1 - p[0]: {(1 - p[0]).item():.4f}") - print(f"Batch 0 probs_sum[-3:]: {probs_sum[0, -3:]}") - print(f"Batch 0 top_p_mask[-3:]: {top_p_mask[0, -3:]}") - - filtered_logits_sort.masked_fill_(top_p_mask, -float("inf")) - - print(f"top_p_mask = {filtered_logits_sort}") - - - logits.fill_(-float("inf")) + logits.fill_(-float("inf")) logits.scatter_(dim=1, index=outliers_sorted, src=filtered_logits_sort) - print(f"final logits after scatter = {logits}") - - # ========================= error ============================================ - error_batches = [11, 20, 47, 50, 51, 59, 60, 61] - for i in error_batches[:3]: - print(f"\n=== Batch {i} ===") - print(f"l[{i}]: {l[i].item()}") - print(f"p[{i}]: {p[i].item():.4f}") - print(f"1-p[{i}]: {(1-p[i]).item():.4f}") - print(f"probs_sum[{i}, -5:]: {probs_sum[i, -5:]}") - print(f"filtered_probs_sort[{i}, -5:]: {filtered_probs_sort[i, -5:]}") + + + # if True: + # print(f"\n=== DEBUG ROW 11 ===") + # print(f"l[11]: {l[11]}") + # print(f"valid tokens count: {valid_mask[11].sum()}") + # print(f"outliers_sorted[11, -10:]: {outliers_sorted[11, -10:]}") + # print(f"filtered_logits_sort[11, -10:]: {filtered_logits_sort[11, -10:]}") + # print(f"filtered_probs_sort[11, -10:]: {filtered_probs_sort[11, -10:]}") + # print(f"probs_sum[11, -10:]: {probs_sum[11, -10:]}") + # print(f"top_p_mask[11, -10:]: {top_p_mask[11, -10:]}") + # print(f"valid_mask_sorted[11, -10:]: {valid_mask_sorted[11, -10:]}") + + # print(f"\nBefore scatter, logits[11, :10]: 
{logits[11, :10]}") - prev_sum = probs_sum[i] - filtered_probs_sort[i] - print(f"prev_sum[{i}, -5:]: {prev_sum[-5:]}") - print(f"mask check: {(prev_sum > (1 - p[i]))[-5:]}") - print(f"top_p_mask[{i}, -5:]: {top_p_mask[i, -5:]}") + # non_inf_mask = logits[11] != -float('inf') + # non_inf_indices = torch.where(non_inf_mask)[0] + # print(f"After scatter, non-inf indices: {non_inf_indices}") + # print(f"After scatter, non-inf values: {logits[11, non_inf_indices]}") return logits @@ -1555,10 +1597,13 @@ def apply_top_k_top_p_test2( logits: torch.Tensor, k: torch.Tensor | None, p: torch.Tensor | None, + debug: bool = False ) -> torch.Tensor: """ Uses pivot-based algorithm to filter --> sort """ + if p.max() > 0.99: + return apply_top_k_top_p(logits, None, p) if k is None and p is None: return logits elif p is None and k is not None: @@ -1567,9 +1612,8 @@ def apply_top_k_top_p_test2( return apply_top_p_filtered(logits, p) else: logits_k = apply_top_k_only(logits, k) - return apply_top_p_filtered(logits, p) - - + return apply_top_p_filtered(logits_k, p, debug) + # def apply_top_p_filtered ( # logits: torch.Tensor, # p: torch.Tensor, From 953025e0d305d236cf8c03880d27557cc2a93c4a Mon Sep 17 00:00:00 2001 From: Sunga Kim Date: Thu, 13 Nov 2025 01:41:03 -0800 Subject: [PATCH 46/99] very slow --- compare.py | 2 +- vllm/v1/sample/ops/topk_topp_sampler.py | 879 ++++++++++++++---------- 2 files changed, 504 insertions(+), 377 deletions(-) diff --git a/compare.py b/compare.py index 0e7e959aebd4..e7586acd48ec 100644 --- a/compare.py +++ b/compare.py @@ -113,7 +113,7 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): batch_size_list = [64, 128, 1024] vocab_size_list = [4096, 16384] - p_list = [None, "RAND", 0.4, 0.7, 0.9, 0.95, 0.99] + p_list = [None, 0.4, 0.7, 0.9, 0.95, 0.99] # k_list = [None, "RAND", 5, 10, 50, 100, 200, 300, 3000] k_list = [None] func_list = [apply_top_k_top_p, apply_top_k_top_p_test2] diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py 
b/vllm/v1/sample/ops/topk_topp_sampler.py index 63b87323b099..3e257b9dacdb 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -1235,508 +1235,635 @@ def apply_top_k_top_p_test( # Apply top-p using binary search (no sort!) return apply_top_p_sorted_equivalent(logits, p) -# -------------------------------------------------------------------------------------- - -@triton.jit -def top_p_pivot_filter(LOGITS, L, MIN_IDX, PROBS, PROBS_2, idx_tensor, P, B, SIGMA:tl.constexpr, VOCAB_SIZE:tl.constexpr, BLOCK_SIZE:tl.constexpr): +# -------------------------------------------------------------------------- +@triton.jit +def top_p_pivot_filter( + LOGITS, + PROBS, + PROBS_IDX, + K_FILTER: tl.int32, + OUTPUT_LOGITS, + OUTPUT_INDICES, + B, # --> batch size + DEBUG_K_PIVOT, + DEBUG_WRITE_POS, + DEBUG_NUM_OUTLIERS, + SIGMA: tl.constexpr, + VOCAB_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + ): NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE + pid = tl.program_id(0) num_programs = tl.num_programs(0) - for row_id in tl.range(pid, B, num_programs): - p = tl.load(P + row_id) # fetches the p value of the row it is working on - # if p != 1.0: # if p == 1, this becomes pointless ! 
- if True: - p_pivot = -float('inf') + for row_id in tl.range(pid, B, num_programs): + k = K_FILTER + if k <= VOCAB_SIZE: + k_pivot = -float('inf') - LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE + LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE PROBS_ROW = PROBS + row_id * VOCAB_SIZE - PROBS_2_ROW = PROBS_2 + row_id * VOCAB_SIZE - IDX_ROW = idx_tensor + row_id * VOCAB_SIZE + PROBS_IDX_ROW = PROBS_IDX + row_id * VOCAB_SIZE + OUTPUT_LOGITS_ROW = OUTPUT_LOGITS + row_id * K_FILTER + OUTPUT_INDICES_ROW = OUTPUT_INDICES + row_id * K_FILTER - search_address = PROBS_ROW - search_range = VOCAB_SIZE + search_addr = LOGITS_ROW + search_range = VOCAB_SIZE search_iters = NUM_TILES - max_logit = -float('inf') - min_logit = float('inf') - min_logit_idx = -1 - - force_remove_logit = -float('inf') # for handling duplicate cases (edge case) - num_force_remove = tl.zeros((), dtype=tl.uint32) + max_logit = -float('inf') + min_logit = float('inf') - # First Pass: Compute avg and std from a sample block + # Zeroth pass: Compute avg and std from a sample block offs = tl.arange(0, BLOCK_SIZE) - mask_n = offs < VOCAB_SIZE + mask_n = offs < VOCAB_SIZE logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) - avg_logit = tl.sum(logits_blk) / VOCAB_SIZE - sq_avg_logit = tl.sum(logits_blk * logits_blk) / VOCAB_SIZE + probs_blk = tl.load(PROBS_ROW + offs, mask=mask_n, other=0.0) + valid_count = tl.sum(mask_n, dtype=tl.float32) + avg_logit = tl.sum(logits_blk) / VOCAB_SIZE # re-check + sq_avg_logit = tl.sum(logits_blk * logits_blk) / VOCAB_SIZE # re-check std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) outlier_pivot = avg_logit + SIGMA * std_logit num_outliers = tl.zeros((), dtype=tl.uint32) - sum_outlier_probs = 0.0 - sum_exp_logits = 0.0 - # ====== Second Pass: compute max and min logits ====== for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - logits_blk = tl.load(LOGITS_ROW + offs_n, + offs_n = i * BLOCK_SIZE + 
tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=avg_logit) - max_logit = tl.maximum(max_logit, tl.max(logits_blk)) - - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - logits_blk = tl.load(LOGITS_ROW + offs_n, - mask=mask_n, - other=float('inf')) - local_min_logit = tl.min(logits_blk) - local_min_logit_idx = tl.argmin(logits_blk, axis=0) - if local_min_logit < min_logit: - min_logit = local_min_logit - min_logit_idx = local_min_logit_idx + i * BLOCK_SIZE - - # ====== Third pass: Calculate exp logits and sum ====== - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - - probs_blk = tl.load(LOGITS_ROW + offs_n, - mask=mask_n, - other=-float('inf')) - probs_blk = probs_blk - max_logit - probs_blk = tl.exp(probs_blk) - sum_exp_logits += tl.sum(probs_blk) - tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) - outlier_prob = tl.exp(outlier_pivot - max_logit) / sum_exp_logits - - # ====== Fourth pass: Calculate probs and get outliers ====== - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) - probs_blk = probs_blk / sum_exp_logits - tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) - outlier_mask = (probs_blk > outlier_prob) & mask_n - sum_outlier_probs += tl.sum(outlier_mask * probs_blk) + max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + + outlier_mask = (logits_blk > outlier_pivot) & mask_n num_blk_outliers = tl.sum(outlier_mask) cumulative_pos = tl.cast( tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) num_outliers += num_blk_outliers - write_pos = tl.where(outlier_mask, cumulative_pos, -1) - tl.store(PROBS_2_ROW + write_pos, probs_blk, 
mask=outlier_mask) # stores the final probs after masking to PROBS_2 - max_range = tl.exp(max_logit - max_logit) / sum_exp_logits - min_range = tl.exp(min_logit - max_logit) / sum_exp_logits + write_idx = tl.where(outlier_mask, cumulative_pos, 0) + tl.store(PROBS_ROW + write_idx, logits_blk, mask=outlier_mask) + tl.store(PROBS_IDX_ROW + write_idx, offs_n, mask=outlier_mask) - if sum_outlier_probs > p: - min_range = outlier_prob - search_address = PROBS_2_ROW + max_range = max_logit + min_range = min_logit + if num_outliers > k: + max_range = max_logit + min_range = outlier_pivot + search_addr = PROBS_ROW search_range = tl.cast(num_outliers, tl.int32) search_iters = tl.cast( (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) - - second_max_logit = -float('inf') - + + # Second passes: Quaternary search for pivots (nlog_4(n)) num_iters = 0 - p_pivots_sum_0 = 0.0 - min_larger_0 = 1.0 - num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - - # ====== Fifth Passes: Search for p_pivot(2log_2(n)) ====== - while p_pivot == -float('inf') and num_iters < 32: - p_pivot_0 = (max_range - min_range) * 1.0 / 2.0 + min_range - p_pivots_sum_0 = 0.0 - - min_larger_0 = 1.0 - num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + tl.store(OUTPUT_LOGITS_ROW, 12345.0) + while k_pivot == -float('inf') and num_iters < 18: + k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range + k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) + k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) + k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - probs_blk = tl.load(search_address + offs_n, - mask=mask_n, - other=0.0) - - masked_larger_0 = tl.where(probs_blk > p_pivot_0, - probs_blk, 1.0) - min_larger_0 = tl.minimum(min_larger_0, - tl.min(masked_larger_0)) - - 
p_pivots_sum_0 += tl.sum(probs_blk * - (probs_blk > p_pivot_0)) + logits_blk = tl.load(search_addr + offs_n, + mask=mask_n, + other=-float('inf')) - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - probs_blk = tl.load(search_address + offs_n, - mask=mask_n, - other=0.0) + k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) + k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) + k_pivots_num_2 += tl.sum(logits_blk > k_pivot_2) - num_min_larger_0 += tl.sum( - tl.abs(probs_blk - min_larger_0) < 1e-7) - # Check if any of the pivots are equal to k - if p_pivots_sum_0 >= p: - if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: - p_pivot = p_pivot_0 - else: - min_range = p_pivot_0 - else: - max_range = p_pivot_0 + if k_pivots_num_0 == k: + k_pivot = k_pivot_0 + elif k_pivots_num_1 == k: + k_pivot = k_pivot_1 + elif k_pivots_num_2 == k: + k_pivot = k_pivot_2 + # If none of the pivots are equal to k, we update the range + elif k_pivots_num_2 > k: + min_range = k_pivot_2 + elif k_pivots_num_1 > k: + min_range = k_pivot_1 + elif k_pivots_num_0 > k: + min_range = k_pivot_0 + if k_pivots_num_0 < k: + max_range = k_pivot_0 + elif k_pivots_num_1 < k: + max_range = k_pivot_1 + elif k_pivots_num_2 < k: + max_range = k_pivot_2 num_iters += 1 - if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-8: - p_pivot = p_pivot_0 - - # Force remove duplicates (p_pivot is made to include all - # duplicates if it falls on the duplicates) - num_force_remove = tl.cast((p_pivots_sum_0 - p) / min_larger_0, - tl.uint32) - force_remove_logit = tl.log( - min_larger_0 * sum_exp_logits) + max_logit - - # p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit - p_pivot_logit = -float('inf') - - # -------- widen cutoff by 10% ---------------- - if p_pivot != -float('inf'): - p_pivot_logit = tl.log(p_pivot * sum_exp_logits) + max_logit - # current_num_force_remove = tl.zeros((), dtype=tl.uint32) - kept_write_pos = 0 - - for i in 
range(0, NUM_TILES): + if num_iters >= 18 or tl.abs(min_range - max_range) < 1e-8: + k_pivot = k_pivot_0 + # ============== Third pass : Apply top-k mask ================ + write_pos = tl.zeros((), dtype=tl.int32) + # if k_pivot != -float('inf'): + for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load(LOGITS_ROW + offs_n, - mask=mask_n, - other=-float('inf')) - probs_blk = tl.load(PROBS_ROW + offs_n, - mask=mask_n, - other=0.0) - - # if force_remove_logit != -float('inf'): - # # Force remove duplicates - # tolerance = 1e-5 * tl.maximum(1.0, tl.abs(force_remove_logit)) - # force_remove_mask = tl.abs( - # logits_blk - force_remove_logit) < tolerance - # force_remove_count = tl.cumsum(force_remove_mask) + current_num_force_remove - # force_remove_count_mask = force_remove_count <= num_force_remove - # force_remove_mask = force_remove_count_mask & force_remove_mask - # logits_blk = tl.where(force_remove_mask, -float('inf'), logits_blk) - # current_num_force_remove = tl.max(force_remove_count) - - keep_mask = logits_blk > p_pivot_logit - out_vals = tl.where(keep_mask, logits_blk, -float('inf')) - tl.store(LOGITS_ROW + offs_n, out_vals, mask=mask_n) - - # ====== keeping track of L and idx_tensor ====== + mask_n = offs_n < VOCAB_SIZE + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) + keep_mask = (logits_blk >= k_pivot) & mask_n n_kept = tl.sum(keep_mask, dtype=tl.int32) - if n_kept > 0: - cpos = tl.cast(tl.cumsum(keep_mask) - 1 + kept_write_pos, tl.int32) - write_idx = tl.where(keep_mask, cpos, 0) + cpos = tl.cumsum(keep_mask) -1 + write_pos + final_mask = keep_mask & (cpos < k) + write_idx = tl.where(final_mask, cpos, 0) + tl.store(OUTPUT_LOGITS_ROW + write_idx, logits_blk, mask=final_mask) + tl.store(OUTPUT_INDICES_ROW + write_idx, offs_n, mask=final_mask) + write_pos += tl.sum(final_mask, dtype=tl.int32) + # for temporary debugging + tl.store(DEBUG_K_PIVOT + 
row_id, k_pivot) + tl.store(DEBUG_WRITE_POS + row_id, write_pos) + tl.store(DEBUG_NUM_OUTLIERS + row_id, num_outliers) - tl.store(IDX_ROW + write_idx, offs_n, mask=keep_mask) - tl.store(PROBS_2_ROW + write_idx, probs_blk, mask=keep_mask) - kept_write_pos += n_kept - tl.store(L + row_id, tl.cast(kept_write_pos, tl.int32)) - tl.store(MIN_IDX + row_id, tl.cast(min_logit_idx, tl.int32)) - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n >= kept_write_pos - tl.store(IDX_ROW + offs_n, min_logit_idx, mask=mask_n) - def apply_top_p_filtered ( - logits: torch.Tensor, + logits: torch.Tensor, p: torch.Tensor, - debug: bool = False -) -> torch.Tensor: +) -> torch.Tensor: """ - Applies top-p using pivot-based filtering + Applies top p using pivot based filtering """ - - logits_copy = logits.clone().detach() batch_size, vocab_size = logits.shape - - probs = torch.full((batch_size, vocab_size), -float('inf'), device=logits.device) - probs_2 = torch.zeros((batch_size, vocab_size), device=logits.device) - l = torch.zeros((batch_size,), device=logits.device, dtype=torch.int32) - min_idx = torch.zeros_like(l, dtype=torch.int32) - idx_tensor = torch.full_like(logits, -1, dtype=torch.int32) - BLOCK_SIZE = 2048 SIGMA = 2.15 NUM_WARPS = 16 NUM_STAGES = 3 - # if not torch.any(p < 1.0): - # return logits - - p_widened = torch.clamp(p * 1.2, max=0.999) - # p_widened = torch.ones_like(p) - - grid = lambda meta: ((batch_size + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'], ) + if not torch.any(p < 1.0): + return logits + + # ================= to find the k filter value ==================== + max_p = p.max().item() + if max_p > 0.9: + k_filter = 6 * 1000 + elif max_p > 0.8: + k_filter = 5 * 1000 + elif max_p > 0.7: + k_filter = 4 * 1000 + elif max_p > 0.6: + k_filter = 4 * 1000 + elif max_p > 0.5: + k_filter = 4 * 1000 + elif max_p > 0.4: + k_filter = 4 * 1000 + else: + k_filter = 4 * 1000 + + k_filter = min(k_filter, vocab_size) + + 
filtered_logits = torch.full((batch_size, k_filter), -float('inf'), device=logits.device) + filtered_indices = torch.full((batch_size, k_filter), 0, dtype=torch.int32, device=logits.device) + + debug_k_pivot = torch.full((batch_size,), -float('inf'), dtype=torch.float32, device=logits.device) + debug_write_pos = torch.zeros((batch_size,), dtype=torch.int32, device=logits.device) + debug_num_outliers = torch.zeros((batch_size,), dtype=torch.int32, device=logits.device) + + probs = torch.empty((batch_size, vocab_size), device=logits.device, dtype=torch.float32) + probs_idx = torch.empty_like(probs, dtype=torch.int32) + + grid = (batch_size,) top_p_pivot_filter[grid]( - logits, - l, - min_idx, - probs, - probs_2, - idx_tensor, - p_widened, - batch_size, - SIGMA=SIGMA, + logits, # --> input + probs, # initial filtered + probs_idx, # initial filtered index + k_filter, # --> scalar + filtered_logits, # --> output, filtered + filtered_indices, # --> filtered logits indices + batch_size, + debug_k_pivot, + debug_write_pos, + debug_num_outliers, + SIGMA=SIGMA, VOCAB_SIZE=vocab_size, BLOCK_SIZE=BLOCK_SIZE, - ) - - max_l = torch.max(l) - error_row = 859 if batch_size > 1000 else 2 - print(f"p = {p[error_row]}") - - # if max_l.item() == 0: - # return torch.full((batch_size, vocab_size), -float('inf'), device=logits.device) - - outliers = idx_tensor[:, :max_l] - - #this is for gather, which puts in garbage value for those not included. I wanted to mask it out. 
- valid_mask = torch.arange(max_l, device=logits.device).unsqueeze(0) < l.unsqueeze(1) - - print(f"\n=== STEP 1: After gathering outliers ===") - print(f"outliers[{error_row-2}:{error_row+2}, :10]: {outliers[error_row-2:error_row+2, :10]}") - print(f"outliers[{error_row-2}:{error_row+2}, -10:]: {outliers[error_row-2:error_row+2, -10:]}") - - full_probs = logits_copy.softmax(dim=-1) + ) + print(f"k pivot: {debug_k_pivot}") + print(f"write pos: {debug_write_pos}") + print(f"filtered logits: {filtered_logits[0:10]}") + # this kernel outputs filtered_logits and filtered_indices of shape (batch_size, k_filter) + logits_sort, sort_indices = filtered_logits.sort(dim=-1, descending=False) + logits_sort_indices = torch.gather(filtered_indices, -1, sort_indices) - filtered_full_probs = torch.gather(full_probs, 1, outliers) - filtered_logits = torch.gather(logits, 1, outliers) + logits_softmax = logits.softmax(dim=-1) + sorted_probs = torch.gather(logits_softmax, -1, logits_sort_indices) - filtered_probs_sort = torch.where(valid_mask, filtered_full_probs, 0.0) - filtered_logits = torch.where(valid_mask, filtered_logits, torch.tensor(-float('inf'), device=logits.device)) + sum_probs = sorted_probs.sum(dim=-1) - print(f"\n=== STEP 2: After gathering logits ===") - print(f"filtered_logits[{error_row-2}:{error_row+2}, :10]: {filtered_logits[error_row-2:error_row+2, :10]}") - print(f"filtered_logits[{error_row-2}:{error_row+2}, -10:]: {filtered_logits[error_row-2:error_row+2, -10:]}") + # ========================== Debugging =========================================== + print("filtered_logits[0,:10]:", filtered_logits[0,:10]) + print("logits sort: ", logits_sort[0, 10]) - non_outliers_mask = torch.ones_like(full_probs, dtype=torch.bool) - non_outliers_mask.scatter_(1, outliers, False) # False = pivoted tokens - sum_non_outliers = full_probs.masked_fill(~non_outliers_mask, 0.0).sum(dim=1) - print (f"min index shape = {min_idx.shape}") - pytorch_min_prob = torch.min(full_probs, 
dim=1).values - print (f"pytorch min prob = {pytorch_min_prob[error_row-2:error_row+2]}") + print("logits_sort_indices[0,:10]:", logits_sort_indices[0,:10]) - sum_non_outliers += pytorch_min_prob + print("sorted probs:", sorted_probs[0,:10]) + # if torch.any(sum_probs < p): + # print("edge case..........") + # return apply_top_k_top_p(logits, k=None, p=p) - print (f"full_probs = {full_probs.sum(dim=1)[error_row-2:error_row+2]}") - print (f"sum non outliers = {sum_non_outliers[error_row-2:error_row+2]}") - print (f"min index probs = {full_probs[error_row-2:error_row+2, min_idx[error_row-2:error_row+2]]}") + probs_sum = torch.cumsum(sorted_probs, dim=-1) + sum_non_outliers = (1.0 - sum_probs).unsqueeze(-1) + probs_sum = probs_sum + sum_non_outliers + top_p_mask = probs_sum <= (1 - p.unsqueeze(dim=-1)) + top_p_mask[:, -1] = False + logits_sort.masked_fill_(top_p_mask, -float("inf")) - - # if vocab_size > 4000: - # print (f" min index probs = {full_probs[error_row-2:error_row+2, 1975]}") + logits.fill_(-float("inf")) + logits.scatter_(dim=1, index=logits_sort_indices, src=logits_sort) + return logits - filtered_logits_sort, sort_indices = torch.sort( - filtered_logits, dim=-1, descending=False - ) - print(f"\n=== STEP 3: After sorting ===") - print(f"filtered_logits_sort[{error_row-2}:{error_row+2}, -10:]: {filtered_logits_sort[error_row-2:error_row+2, -10:]}") - - outliers_sorted = torch.gather(outliers, 1, sort_indices) - - print(f"\n=== STEP 4: outliers_sorted ===") - print(f"outliers_sorted[{error_row-2}:{error_row+2}, :1]: {outliers_sorted[error_row-2:error_row+2, :10]}") - - filtered_probs_sort = torch.gather(filtered_probs_sort, 1, sort_indices) - valid_mask_sorted = torch.gather(valid_mask, 1, sort_indices) - - probs_sum = sum_non_outliers.unsqueeze(1) + torch.cumsum(filtered_probs_sort, dim=-1) - print(f"========== probs sum ============= : {probs_sum[error_row-2:error_row+2, -10:]}") - top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) - top_p_mask = top_p_mask 
| ~valid_mask_sorted - top_p_mask[:, -1] = False - print (f"top p mask {top_p_mask[error_row-2:error_row+2, -10:]}") - filtered_logits_sort.masked_fill_(top_p_mask, -float("inf")) - - logits.fill_(-float("inf")) - logits.scatter_(dim=1, index=outliers_sorted, src=filtered_logits_sort) - - - - # if True: - # print(f"\n=== DEBUG ROW 11 ===") - # print(f"l[11]: {l[11]}") - # print(f"valid tokens count: {valid_mask[11].sum()}") - # print(f"outliers_sorted[11, -10:]: {outliers_sorted[11, -10:]}") - # print(f"filtered_logits_sort[11, -10:]: {filtered_logits_sort[11, -10:]}") - # print(f"filtered_probs_sort[11, -10:]: {filtered_probs_sort[11, -10:]}") - # print(f"probs_sum[11, -10:]: {probs_sum[11, -10:]}") - # print(f"top_p_mask[11, -10:]: {top_p_mask[11, -10:]}") - # print(f"valid_mask_sorted[11, -10:]: {valid_mask_sorted[11, -10:]}") - - # print(f"\nBefore scatter, logits[11, :10]: {logits[11, :10]}") + # @triton.jit +# def top_p_pivot_filter(LOGITS, L, MIN_IDX, PROBS, PROBS_2, idx_tensor, P, B, SIGMA:tl.constexpr, VOCAB_SIZE:tl.constexpr, BLOCK_SIZE:tl.constexpr): +# NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE +# pid = tl.program_id(0) +# num_programs = tl.num_programs(0) + +# for row_id in tl.range(pid, B, num_programs): +# p = tl.load(P + row_id) # fetches the p value of the row it is working on +# # if p != 1.0: # if p == 1, this becomes pointless ! 
+# if True: +# p_pivot = -float('inf') + +# LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE +# PROBS_ROW = PROBS + row_id * VOCAB_SIZE +# PROBS_2_ROW = PROBS_2 + row_id * VOCAB_SIZE +# IDX_ROW = idx_tensor + row_id * VOCAB_SIZE + +# search_address = PROBS_ROW +# search_range = VOCAB_SIZE +# search_iters = NUM_TILES + +# max_logit = -float('inf') +# min_logit = float('inf') +# min_logit_idx = -1 + +# force_remove_logit = -float('inf') # for handling duplicate cases (edge case) +# num_force_remove = tl.zeros((), dtype=tl.uint32) + +# # First Pass: Compute avg and std from a sample block +# offs = tl.arange(0, BLOCK_SIZE) +# mask_n = offs < VOCAB_SIZE +# logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) +# avg_logit = tl.sum(logits_blk) / VOCAB_SIZE +# sq_avg_logit = tl.sum(logits_blk * logits_blk) / VOCAB_SIZE +# std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) + +# outlier_pivot = avg_logit + SIGMA * std_logit +# num_outliers = tl.zeros((), dtype=tl.uint32) +# sum_outlier_probs = 0.0 +# sum_exp_logits = 0.0 + +# # ====== Second Pass: compute max and min logits ====== +# for i in range(0, search_iters): +# offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) +# mask_n = offs_n < search_range +# logits_blk = tl.load(LOGITS_ROW + offs_n, +# mask=mask_n, +# other=avg_logit) +# max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + +# for i in range(0, search_iters): +# offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) +# mask_n = offs_n < search_range +# logits_blk = tl.load(LOGITS_ROW + offs_n, +# mask=mask_n, +# other=float('inf')) +# local_min_logit = tl.min(logits_blk) +# local_min_logit_idx = tl.argmin(logits_blk, axis=0) +# if local_min_logit < min_logit: +# min_logit = local_min_logit +# min_logit_idx = local_min_logit_idx + i * BLOCK_SIZE + +# # ====== Third pass: Calculate exp logits and sum ====== +# for i in range(0, search_iters): +# offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) +# mask_n = offs_n < search_range + +# probs_blk = 
tl.load(LOGITS_ROW + offs_n, +# mask=mask_n, +# other=-float('inf')) +# probs_blk = probs_blk - max_logit +# probs_blk = tl.exp(probs_blk) +# sum_exp_logits += tl.sum(probs_blk) +# tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) +# outlier_prob = tl.exp(outlier_pivot - max_logit) / sum_exp_logits + +# # ====== Fourth pass: Calculate probs and get outliers ====== +# for i in range(0, search_iters): +# offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) +# mask_n = offs_n < search_range + +# probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) +# probs_blk = probs_blk / sum_exp_logits +# tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) + +# outlier_mask = (probs_blk > outlier_prob) & mask_n +# sum_outlier_probs += tl.sum(outlier_mask * probs_blk) +# num_blk_outliers = tl.sum(outlier_mask) +# cumulative_pos = tl.cast( +# tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) +# num_outliers += num_blk_outliers +# write_pos = tl.where(outlier_mask, cumulative_pos, -1) +# tl.store(PROBS_2_ROW + write_pos, probs_blk, mask=outlier_mask) # stores the final probs after masking to PROBS_2 + +# max_range = tl.exp(max_logit - max_logit) / sum_exp_logits +# min_range = tl.exp(min_logit - max_logit) / sum_exp_logits + +# if sum_outlier_probs > p: +# min_range = outlier_prob +# search_address = PROBS_2_ROW +# search_range = tl.cast(num_outliers, tl.int32) +# search_iters = tl.cast( +# (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) + +# second_max_logit = -float('inf') + +# num_iters = 0 +# p_pivots_sum_0 = 0.0 +# min_larger_0 = 1.0 +# num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + +# # ====== Fifth Passes: Search for p_pivot(2log_2(n)) ====== +# while p_pivot == -float('inf') and num_iters < 32: +# p_pivot_0 = (max_range - min_range) * 1.0 / 2.0 + min_range +# p_pivots_sum_0 = 0.0 + +# min_larger_0 = 1.0 +# num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + +# for i in range(0, search_iters): +# offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) +# 
mask_n = offs_n < search_range +# probs_blk = tl.load(search_address + offs_n, +# mask=mask_n, +# other=0.0) + +# masked_larger_0 = tl.where(probs_blk > p_pivot_0, +# probs_blk, 1.0) +# min_larger_0 = tl.minimum(min_larger_0, +# tl.min(masked_larger_0)) + +# p_pivots_sum_0 += tl.sum(probs_blk * +# (probs_blk > p_pivot_0)) + +# for i in range(0, search_iters): +# offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) +# mask_n = offs_n < search_range +# probs_blk = tl.load(search_address + offs_n, +# mask=mask_n, +# other=0.0) + +# num_min_larger_0 += tl.sum( +# tl.abs(probs_blk - min_larger_0) < 1e-7) + +# # Check if any of the pivots are equal to k +# if p_pivots_sum_0 >= p: +# if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: +# p_pivot = p_pivot_0 +# else: +# min_range = p_pivot_0 +# else: +# max_range = p_pivot_0 + +# num_iters += 1 +# if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-8: +# p_pivot = p_pivot_0 + +# # Force remove duplicates (p_pivot is made to include all +# # duplicates if it falls on the duplicates) +# num_force_remove = tl.cast((p_pivots_sum_0 - p) / min_larger_0, +# tl.uint32) +# force_remove_logit = tl.log( +# min_larger_0 * sum_exp_logits) + max_logit + +# # p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit +# p_pivot_logit = -float('inf') + +# # -------- widen cutoff by 10% ---------------- +# if p_pivot != -float('inf'): +# p_pivot_logit = tl.log(p_pivot * sum_exp_logits) + max_logit +# kept_write_pos = 0 + +# for i in range(0, NUM_TILES): +# offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) +# mask_n = offs_n < VOCAB_SIZE +# # padding all idx with min_logit_idx to avoid -1 or 0 being the padding value +# tl.store(IDX_ROW + offs_n, min_logit_idx, mask=mask_n) + +# logits_blk = tl.load(LOGITS_ROW + offs_n, +# mask=mask_n, +# other=-float('inf')) +# probs_blk = tl.load(PROBS_ROW + offs_n, +# mask=mask_n, +# other=0.0) + +# keep_mask = logits_blk > p_pivot_logit +# out_vals = tl.where(keep_mask, logits_blk, -float('inf')) 
+# tl.store(LOGITS_ROW + offs_n, out_vals, mask=mask_n) + +# # ====== keeping track of L and idx_tensor ====== +# n_kept = tl.sum(keep_mask, dtype=tl.int32) +# if n_kept > 0: +# cpos = tl.cast(tl.cumsum(keep_mask) - 1 + kept_write_pos, tl.int32) +# write_idx = tl.where(keep_mask, cpos, 0) + +# tl.store(IDX_ROW + write_idx, offs_n, mask=keep_mask) +# tl.store(PROBS_2_ROW + write_idx, probs_blk, mask=keep_mask) + +# kept_write_pos += n_kept +# tl.store(L + row_id, tl.cast(kept_write_pos, tl.int32)) +# tl.store(MIN_IDX + row_id, tl.cast(min_logit_idx, tl.int32)) + +# @triton.jit +# def apply_mask( +# filtered_logits_sort, +# temp_logits, +# filtered_probs_sort, +# temp_probs, +# sort_indices, +# outliers, +# outliers_sorted, +# B, +# MAX_L, +# BLOCK_SIZE: tl.constexpr, +# ): +# pid = tl.program_id(0) +# num_programs = tl.num_programs(0) # total number of programs launched + +# for row_id in tl.range(pid, B, num_programs): # each row +# SORTED_LOGITS_ROW = filtered_logits_sort + row_id * MAX_L +# SORTED_PROBS_ROW = filtered_probs_sort + row_id * MAX_L +# SORTED_IDX_ROW = sort_indices + row_id * MAX_L +# OUTLIERS_ROW = outliers + row_id * MAX_L + +# OUTPUT_LOGITS_ROW = temp_logits + row_id * MAX_L +# OUTPUT_PROBS_ROW = temp_probs + row_id * MAX_L +# OUTPUT_OUTLIERS_SORTED_ROW = outliers_sorted + row_id * MAX_L + +# num_tiles = (MAX_L + BLOCK_SIZE - 1) // BLOCK_SIZE # for copying - # non_inf_mask = logits[11] != -float('inf') - # non_inf_indices = torch.where(non_inf_mask)[0] - # print(f"After scatter, non-inf indices: {non_inf_indices}") - # print(f"After scatter, non-inf values: {logits[11, non_inf_indices]}") - return logits +# for i in range(0, num_tiles): # partitioning the row +# offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) +# mask_n = offs_n < MAX_L + +# sort_idx_blk = tl.load(SORTED_IDX_ROW + offs_n, mask=mask_n, other=0) +# probs_blk = tl.load(SORTED_PROBS_ROW + sort_idx_blk, mask=mask_n, other=0.0) +# logits_blk = tl.load(SORTED_LOGITS_ROW + sort_idx_blk, 
mask=mask_n, other=-float('inf')) +# outliers_blk = tl.load(OUTLIERS_ROW + sort_idx_blk, mask=mask_n, other=0) +# tl.store(OUTPUT_PROBS_ROW + offs_n, probs_blk, mask=mask_n) +# tl.store(OUTPUT_LOGITS_ROW + offs_n, logits_blk, mask=mask_n) +# tl.store(OUTPUT_OUTLIERS_SORTED_ROW + offs_n, outliers_blk, mask=mask_n) -def apply_top_k_top_p_test2( - logits: torch.Tensor, - k: torch.Tensor | None, - p: torch.Tensor | None, - debug: bool = False -) -> torch.Tensor: - """ - Uses pivot-based algorithm to filter --> sort - """ - if p.max() > 0.99: - return apply_top_k_top_p(logits, None, p) - if k is None and p is None: - return logits - elif p is None and k is not None: - return apply_top_k_only(logits, k) - elif k is None and p is not None: - return apply_top_p_filtered(logits, p) - else: - logits_k = apply_top_k_only(logits, k) - return apply_top_p_filtered(logits_k, p, debug) - # def apply_top_p_filtered ( # logits: torch.Tensor, # p: torch.Tensor, +# debug: bool = False # ) -> torch.Tensor: # """ # Applies top-p using pivot-based filtering # """ -# # logits = torch.ones((10,), device=logits.device, dtype=torch.float32).view(1, -1) # logits_copy = logits.clone().detach() -# # p = torch.full((logits.shape[0],), 0.65, dtype=torch.float32, device=logits.device) -# # output = apply_top_k_top_p(logits_copy, None, p) -# # print(f"original value = {output}") # batch_size, vocab_size = logits.shape -# # print(f"logits: {logits}", flush=True) # probs = torch.full((batch_size, vocab_size), -float('inf'), device=logits.device) -# probs_2 = torch.full((batch_size, vocab_size), -float('inf'), device=logits.device) +# probs_2 = torch.zeros((batch_size, vocab_size), device=logits.device) # l = torch.zeros((batch_size,), device=logits.device, dtype=torch.int32) -# idx_tensor = torch.full_like(logits, 0, dtype=torch.int32) +# min_idx = torch.zeros_like(l, dtype=torch.int32) +# idx_tensor = torch.full_like(logits, -1, dtype=torch.int32) - # BLOCK_SIZE = 2048 # SIGMA = 2.15 # NUM_WARPS = 
16 # NUM_STAGES = 3 -# WIDEN_NUM = 0 # ----------------------------> ???????? # if not torch.any(p < 1.0): # return logits +# p_widened = torch.clamp(p*1.01, max=0.999) + # grid = lambda meta: ((batch_size + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'], ) # top_p_pivot_filter[grid]( # logits, # l, +# min_idx, # probs, # probs_2, # idx_tensor, -# p*1.1, +# p_widened, # batch_size, # SIGMA=SIGMA, # VOCAB_SIZE=vocab_size, # BLOCK_SIZE=BLOCK_SIZE, -# WIDEN_NUM=WIDEN_NUM # ) -# # print(f"logits: {logits}", flush=True) -# # print(f"l: {l}", flush=True) -# # print(f"probs: {probs}", flush=True) -# # print(f"probs_2: {probs_2}", flush=True) -# # print(f"idx_tensor: {idx_tensor}", flush=True) - -# max_l = torch.max(l) -# # if max_l.item() == 0: -# # return torch.full((batch_size, vocab_size), -float('inf'), device=logits.device) - - -# # outliers = torch.arange(0, vocab_size, dtype=torch.int32, device=logits.device).unsqueeze(0).expand(logits.shape[0], -1) -# # probs = torch.softmax(logits, dim=1) +# max_l = torch.max(l).item() +# if max_l == 0: +# return torch.full((batch_size, vocab_size), -float('inf'), device=logits.device) + # outliers = idx_tensor[:, :max_l] -# full_probs = logits_copy.softmax(dim=-1) -# filtered_full_probs = torch.gather(full_probs, 1, outliers) -# filtered_logits = torch.gather(logits, 1, outliers) -# # print(f"outliers: {outliers}", flush=True) -# # print(f"filtered_logits: {filtered_logits}", flush=True) - -# probs = torch.gather(probs, 1, outliers) -# cum_sum = torch.sum(probs, dim=1) -# # print(f"sum : {cum_sum}", flush=True) +# # this is for gather, which puts in garbage value for those not included. I wanted to mask it out. 
+# valid_mask = torch.arange(max_l, device=logits.device).unsqueeze(0) < l.unsqueeze(1) +# filtered_full_probs = probs_2[:, :max_l] +# filtered_logits = torch.gather(logits, 1, outliers) + +# filtered_probs_sort = torch.where(valid_mask, filtered_full_probs, 0.0) +# filtered_logits = torch.where(valid_mask, filtered_logits, torch.tensor(-float('inf'), device=logits.device)) +# sum_outliers = (filtered_full_probs * valid_mask.float()).sum(dim=1) +# sum_non_outliers = 1.0 - sum_outliers +# # we must do sort in python, unfortunately..... # filtered_logits_sort, sort_indices = torch.sort( # filtered_logits, dim=-1, descending=False # ) +# # outliers_sorted = torch.gather(outliers, 1, sort_indices) +# # filtered_probs_sort = torch.gather(filtered_probs_sort, 1, sort_indices) + +# # =================== calling a new triton kernel =============================== +# temp_logits = torch.empty_like(filtered_logits_sort) +# temp_probs = torch.empty_like(filtered_probs_sort) +# outliers_sorted = torch.empty_like(outliers, dtype=torch.int32) + +# grid = (batch_size, ) +# apply_mask[grid]( +# filtered_logits_sort, +# temp_logits, +# filtered_probs_sort, +# temp_probs, +# sort_indices, +# outliers, +# outliers_sorted, +# batch_size, +# max_l, +# BLOCK_SIZE=BLOCK_SIZE, +# ) +# valid_mask_sorted = torch.gather(valid_mask, 1, sort_indices) - -# # print(f"filtered_logits_sort: {filtered_logits_sort}", flush=True) - -# outliers_sorted = torch.gather(outliers, 1, sort_indices) -# filtered_probs_sort = torch.gather(filtered_full_probs, 1, sort_indices) - -# # print(f"outliers_sorted: {outliers_sorted}", flush=True) -# # print(f"filtered_probs_sort: {filtered_logits_sort}", flush=True) - - -# probs_sum = torch.cumsum(filtered_probs_sort, dim=-1, out=filtered_probs_sort) +# probs_sum = sum_non_outliers.unsqueeze(1) + torch.cumsum(temp_probs, dim=-1) # top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) +# top_p_mask = top_p_mask | ~valid_mask_sorted # top_p_mask[:, -1] = False -# # 
print(f"probs_sum: {probs_sum}") -# # print(f"top_p_mask = {top_p_mask}") +# temp_logits.masked_fill_(top_p_mask, -float("inf")) -# filtered_logits_sort.masked_fill_(top_p_mask, -float("inf")) - -# # print(f"top_p_mask = {filtered_logits_sort}") - +# # outliers_sorted = torch.gather(outliers, 1, sort_indices) # logits.fill_(-float("inf")) -# logits.scatter_(dim=1, index=outliers_sorted, src=filtered_logits_sort) +# logits.scatter_(dim=1, index=outliers_sorted, src=temp_logits) -# # print(f"final logits after scatter = {logits}") - -# # assert False # return logits +# # valid_mask_sorted = torch.gather(valid_mask, 1, sort_indices) + +# # probs_sum = sum_non_outliers.unsqueeze(1) + torch.cumsum(filtered_probs_sort, dim=-1) +# # top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) +# # top_p_mask = top_p_mask | ~valid_mask_sorted +# # top_p_mask[:, -1] = False +# # filtered_logits_sort.masked_fill_(top_p_mask, -float("inf")) -# def apply_top_k_top_p_test2( -# logits: torch.Tensor, -# k: torch.Tensor | None, -# p: torch.Tensor | None, -# ) -> torch.Tensor: -# """ -# Uses pivot-based algorithm to filter --> sort -# """ -# if k is None and p is None: -# return logits -# elif p is None and k is not None: -# return apply_top_k_only(logits, k) -# elif k is None and p is not None: -# return apply_top_p_filtered(logits, p) -# else: -# logits_k = apply_top_k_only(logits, k) -# return apply_top_p_filtered(logits, p) \ No newline at end of file +# # outliers_sorted = torch.gather(outliers, 1, sort_indices) +# # logits.fill_(-float("inf")) +# # logits.scatter_(dim=1, index=outliers_sorted, src=filtered_logits_sort) + +# # return logits + +def apply_top_k_top_p_test2( + logits: torch.Tensor, + k: torch.Tensor | None, + p: torch.Tensor | None, + debug: bool = False +) -> torch.Tensor: + """ + Uses pivot-based algorithm to filter --> sort + """ + if p.max() > 0.99: + return apply_top_k_top_p(logits, None, p) + if k is None and p is None: + return logits + elif p is None and k is not 
None: + return apply_top_k_only(logits, k) + elif k is None and p is not None: + return apply_top_p_filtered(logits, p) + else: + logits_k = apply_top_k_only(logits, k) + return apply_top_p_filtered(logits_k, p, debug) + \ No newline at end of file From d1ca674ffebce336b75c9d2502d49ec24d7fb4b8 Mon Sep 17 00:00:00 2001 From: Sunga Kim Date: Thu, 13 Nov 2025 20:36:24 -0800 Subject: [PATCH 47/99] pushed? --- compare.py | 9 ++-- vllm/v1/sample/ops/topk_topp_sampler.py | 56 ++++--------------------- 2 files changed, 13 insertions(+), 52 deletions(-) diff --git a/compare.py b/compare.py index e7586acd48ec..5336840e98ae 100644 --- a/compare.py +++ b/compare.py @@ -88,9 +88,9 @@ def test_accuracy(logits, k, p, func_list): def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): # We must clone the logits for each run to avoid modifying the original - warmup_tensor = logits.clone().detach() + warmup_tensor = [logits.clone().detach() for _ in range(num_warmup)] for _ in range(num_warmup): - test_func(warmup_tensor, k, p) + test_func(warmup_tensor[_], k, p) torch.cuda.synchronize() input_logits = [logits.clone().detach() for _ in range(num_runs)] @@ -112,8 +112,8 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): date_str = datetime.now().strftime("%Y%m%d_%H%M%S") batch_size_list = [64, 128, 1024] - vocab_size_list = [4096, 16384] - p_list = [None, 0.4, 0.7, 0.9, 0.95, 0.99] + vocab_size_list = [4096, 16384, 65536] + p_list = [None, "RAND", 0.4, 0.7, 0.9, 0.95, 0.99] # k_list = [None, "RAND", 5, 10, 50, 100, 200, 300, 3000] k_list = [None] func_list = [apply_top_k_top_p, apply_top_k_top_p_test2] @@ -175,6 +175,7 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): f" batch_size: {batch_size}," f" vocab_size: {vocab_size}, dist_generator: " f"{dist_generator}, p: {p}, k: {k}", log_file) + print_to_log(f"Test accuracy passed! 
Now testing speedup...", log_file) time_list = [] for func in func_list: time_taken = test_time(logits, k_tensor, p_tensor, test_func=func) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 3e257b9dacdb..a2f4c35793dd 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -1245,9 +1245,6 @@ def top_p_pivot_filter( OUTPUT_LOGITS, OUTPUT_INDICES, B, # --> batch size - DEBUG_K_PIVOT, - DEBUG_WRITE_POS, - DEBUG_NUM_OUTLIERS, SIGMA: tl.constexpr, VOCAB_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr, @@ -1321,7 +1318,6 @@ def top_p_pivot_filter( # Second passes: Quaternary search for pivots (nlog_4(n)) num_iters = 0 - tl.store(OUTPUT_LOGITS_ROW, 12345.0) while k_pivot == -float('inf') and num_iters < 18: k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range @@ -1380,10 +1376,7 @@ def top_p_pivot_filter( tl.store(OUTPUT_LOGITS_ROW + write_idx, logits_blk, mask=final_mask) tl.store(OUTPUT_INDICES_ROW + write_idx, offs_n, mask=final_mask) write_pos += tl.sum(final_mask, dtype=tl.int32) - # for temporary debugging - tl.store(DEBUG_K_PIVOT + row_id, k_pivot) - tl.store(DEBUG_WRITE_POS + row_id, write_pos) - tl.store(DEBUG_NUM_OUTLIERS + row_id, num_outliers) + def apply_top_p_filtered ( @@ -1402,33 +1395,13 @@ def apply_top_p_filtered ( if not torch.any(p < 1.0): return logits - + # ================= to find the k filter value ==================== - max_p = p.max().item() - if max_p > 0.9: - k_filter = 6 * 1000 - elif max_p > 0.8: - k_filter = 5 * 1000 - elif max_p > 0.7: - k_filter = 4 * 1000 - elif max_p > 0.6: - k_filter = 4 * 1000 - elif max_p > 0.5: - k_filter = 4 * 1000 - elif max_p > 0.4: - k_filter = 4 * 1000 - else: - k_filter = 4 * 1000 - - k_filter = min(k_filter, vocab_size) + k_filter = int(vocab_size * 1/32) filtered_logits = torch.full((batch_size, k_filter), -float('inf'), device=logits.device) 
filtered_indices = torch.full((batch_size, k_filter), 0, dtype=torch.int32, device=logits.device) - debug_k_pivot = torch.full((batch_size,), -float('inf'), dtype=torch.float32, device=logits.device) - debug_write_pos = torch.zeros((batch_size,), dtype=torch.int32, device=logits.device) - debug_num_outliers = torch.zeros((batch_size,), dtype=torch.int32, device=logits.device) - probs = torch.empty((batch_size, vocab_size), device=logits.device, dtype=torch.float32) probs_idx = torch.empty_like(probs, dtype=torch.int32) @@ -1441,16 +1414,11 @@ def apply_top_p_filtered ( filtered_logits, # --> output, filtered filtered_indices, # --> filtered logits indices batch_size, - debug_k_pivot, - debug_write_pos, - debug_num_outliers, SIGMA=SIGMA, VOCAB_SIZE=vocab_size, BLOCK_SIZE=BLOCK_SIZE, ) - print(f"k pivot: {debug_k_pivot}") - print(f"write pos: {debug_write_pos}") - print(f"filtered logits: {filtered_logits[0:10]}") + # this kernel outputs filtered_logits and filtered_indices of shape (batch_size, k_filter) logits_sort, sort_indices = filtered_logits.sort(dim=-1, descending=False) logits_sort_indices = torch.gather(filtered_indices, -1, sort_indices) @@ -1460,18 +1428,10 @@ def apply_top_p_filtered ( sum_probs = sorted_probs.sum(dim=-1) - # ========================== Debugging =========================================== - print("filtered_logits[0,:10]:", filtered_logits[0,:10]) - - print("logits sort: ", logits_sort[0, 10]) - - print("logits_sort_indices[0,:10]:", logits_sort_indices[0,:10]) - - print("sorted probs:", sorted_probs[0,:10]) - - # if torch.any(sum_probs < p): - # print("edge case..........") - # return apply_top_k_top_p(logits, k=None, p=p) + if torch.any(sum_probs < p): + print(f"edge case --> fall back !") + assert False + return apply_top_k_top_p(logits, k=None, p=p) probs_sum = torch.cumsum(sorted_probs, dim=-1) sum_non_outliers = (1.0 - sum_probs).unsqueeze(-1) From 5697d83e47898823198ec5a6af933b9d25fdf9f4 Mon Sep 17 00:00:00 2001 From: js_park Date: 
Fri, 14 Nov 2025 03:09:41 -0800 Subject: [PATCH 48/99] Top-k working Signed-off-by: js_park --- compare.py | 135 +-- vllm/v1/sample/ops/topk_topp_sampler.py | 1143 +++++++++-------------- 2 files changed, 512 insertions(+), 766 deletions(-) diff --git a/compare.py b/compare.py index 5336840e98ae..2ab256a4b9ad 100644 --- a/compare.py +++ b/compare.py @@ -2,26 +2,23 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from datetime import datetime from itertools import product -import torch - -torch.manual_seed(42) -torch.cuda.manual_seed(42) -torch.cuda.manual_seed_all(42) -import random -import numpy as np -random.seed(42) -np.random.seed(42) -torch.backends.cudnn.deterministic = True -torch.backends.cudnn.benchmark = False +# torch.manual_seed(42) +# torch.cuda.manual_seed(42) +# torch.cuda.manual_seed_all(42) +# import random +# import numpy as np +# random.seed(42) +# np.random.seed(42) +# torch.backends.cudnn.deterministic = True +# torch.backends.cudnn.benchmark = False import regex as re -import torch - -from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p, - apply_top_k_top_p_triton, - apply_top_k_top_p_test2 - - ) +import torch + +from vllm.v1.sample.ops.topk_topp_sampler import ( + apply_top_k_top_p, + apply_top_k_top_p_test2, +) def g_str(s): @@ -56,30 +53,37 @@ def test_accuracy(logits, k, p, func_list): output_logits = func_list[i](input_logit_list[i], k, p) torch.cuda.synchronize() - is_correct = torch.allclose(original_logits, output_logits) + original_logits_bin = original_logits.view(torch.int32) + output_logits_bin = output_logits.view(torch.int32) + is_correct = torch.all(original_logits_bin == output_logits_bin) output_correct_list.append(is_correct) func_name = func_list[i].__name__ if not is_correct: - print_to_log(r_str(f"Error: logits are not close on {i} - " + f"{func_name}"), log_file) + print_to_log( + r_str(f"Error: logits are not close on {i} - " + f"{func_name}"), + log_file, + ) output_logits = 
apply_top_k_top_p_test2(logits, k, p, debug=True) - error_mask = torch.abs(output_logits - original_logits) > 1e-5 + error_mask = torch.abs(output_logits - original_logits) > 1e-16 error_rows = torch.where(error_mask)[0] error_rows = torch.unique(error_rows) num_error_rows = error_rows.shape[0] error_cols = torch.where(error_mask)[1] error_cols = torch.unique(error_cols) num_error_cols = error_cols.shape[0] - print_to_log(f"num_error_rows: {num_error_rows} - {error_rows}", - log_file) + print_to_log(f"num_error_rows: {num_error_rows} - {error_rows}", log_file) print_to_log(f"num_error_cols: {num_error_cols}", log_file) row_to_show = 5 if num_error_rows > 5 else num_error_rows - logits_to_show = torch.sort(output_logits[error_rows], - descending=True).values + logits_to_show = torch.sort( + output_logits[error_rows], descending=True + ).values + logits_to_show = logits_to_show[:row_to_show, :20] print_to_log(f"logits: {logits_to_show}", log_file) - original_logits_to_show = \ - torch.sort(original_logits[error_rows], descending=True).values + original_logits_to_show = torch.sort( + original_logits[error_rows], descending=True + ).values original_logits_to_show = original_logits_to_show[:row_to_show, :20] print_to_log(f"original_logits: {original_logits_to_show}", log_file) assert False @@ -112,10 +116,10 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): date_str = datetime.now().strftime("%Y%m%d_%H%M%S") batch_size_list = [64, 128, 1024] - vocab_size_list = [4096, 16384, 65536] - p_list = [None, "RAND", 0.4, 0.7, 0.9, 0.95, 0.99] - # k_list = [None, "RAND", 5, 10, 50, 100, 200, 300, 3000] - k_list = [None] + vocab_size_list = [4096, 16384, 65536, 128000, 262144] + # p_list = [None, "RAND", 0.4, 0.7, 0.9, 0.95, 0.99] + p_list = [None] + k_list = [None, "RAND", 5, 10, 50, 100, 200, 300, 3000] func_list = [apply_top_k_top_p, apply_top_k_top_p_test2] log_file = f"triton_topk_topp_test_{date_str}.log" @@ -131,12 +135,14 @@ def test_time(logits, k, p, 
test_func, num_runs=30, num_warmup=5):
     print_to_log(y_str("csv_file:") + f"{csv_file}", log_file)
 
     with open(csv_file, "w") as f:
-        f.write("dist_generator,batch_size,vocab_size,p,k,triton_correct,test_correct"
-                "torch_time_taken,triton_time_taken,test_time_taken,triton_speedup,test_speedup\n")
+        f.write(
+            "dist_generator,batch_size,vocab_size,p,k,triton_correct,test_correct,"
+            "torch_time_taken,triton_time_taken,test_time_taken,triton_speedup,test_speedup\n"
+        )
 
-    for batch_size, vocab_size, p, k in product(batch_size_list,
-                                                vocab_size_list, p_list,
-                                                k_list):
+    for batch_size, vocab_size, p, k in product(
+        batch_size_list, vocab_size_list, p_list, k_list
+    ):
         if p is None and k is None:
             continue
 
@@ -144,28 +150,34 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5):
         logits_list = [("RANDN", logits_randn)]
 
         if p == "RAND":
-            p_tensor = torch.rand((batch_size, ), device="cuda") * 0.95 + 0.05
+            p_tensor = torch.rand((batch_size,), device="cuda") * 0.95 + 0.05
         elif p is not None:
-            p_tensor = torch.full((batch_size, ), p, device="cuda")
+            p_tensor = torch.full((batch_size,), p, device="cuda")
         else:
             p_tensor = None
 
         if k == "RAND":
-            k_tensor = torch.randint(1,
-                                     vocab_size, (batch_size, ),
-                                     device="cuda")
+            k_tensor = torch.randint(1, vocab_size, (batch_size,), device="cuda")
         elif k is not None:
-            k_tensor = torch.full((batch_size, ), k, device="cuda")
+            k_tensor = torch.full((batch_size,), k, device="cuda")
         else:
             k_tensor = None
 
         for dist_generator, logits in logits_list:
             print_to_log(y_str("--------------------------------"), log_file)
             print_to_log(
-                g_str("Testing ") + f"{dist_generator}" +
-                y_str(" with batch_size: ") + f"{batch_size}" +
-                y_str(" vocab_size: ") + f"{vocab_size}" + y_str(" p: ") +
-                f"{p}" + y_str(" k: ") + f"{k}", log_file)
+                g_str("Testing ")
+                + f"{dist_generator}"
+                + y_str(" with batch_size: ")
+                + f"{batch_size}"
+                + y_str(" vocab_size: ")
+                + f"{vocab_size}"
+                + y_str(" p: ")
+                + f"{p}"
+                + y_str(" k: ")
+                + f"{k}",
+                
log_file, + ) correct_list = test_accuracy(logits, k_tensor, p_tensor, func_list) for i in range(len(func_list) - 1): is_correct = correct_list[i] @@ -174,29 +186,32 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): f"Error: logits are not close for function {func_list[i + 1].__name__}," f" batch_size: {batch_size}," f" vocab_size: {vocab_size}, dist_generator: " - f"{dist_generator}, p: {p}, k: {k}", log_file) - print_to_log(f"Test accuracy passed! Now testing speedup...", log_file) + f"{dist_generator}, p: {p}, k: {k}", + log_file, + ) + print_to_log("Test accuracy passed! Now testing speedup...", log_file) time_list = [] for func in func_list: time_taken = test_time(logits, k_tensor, p_tensor, test_func=func) time_list.append(time_taken) - print_to_log( - b_str("torch_time_taken: ") + f"{time_list[0]}", log_file) - print_to_log( - b_str("test_time_taken: ") + f"{time_list[1]}", - log_file) + print_to_log(b_str("torch_time_taken: ") + f"{time_list[0]}", log_file) + print_to_log(b_str("test_time_taken: ") + f"{time_list[1]}", log_file) # print_to_log( # b_str("test_time_taken: ") + f"{time_list[2]}", log_file) print_to_log( - g_str("test Speedup over Torch: ") + - f"{time_list[0] / time_list[1]:.8f}x", log_file) + g_str("test Speedup over Torch: ") + + f"{time_list[0] / time_list[1]:.8f}x", + log_file, + ) # print_to_log( # y_str("Test Speedup over Torch: ") + - # f"{time_list[0] / time_list[2]:.8f}x", log_file) + # f"{time_list[0] / time_list[2]:.8f}x", log_file) with open(csv_file, "a") as f: - f.write(f"{dist_generator},{batch_size},{vocab_size},{p},{k}," - f"{correct_list[0]},{time_list[0]}," - f"{time_list[0] / time_list[1]:.8f}\n") + f.write( + f"{dist_generator},{batch_size},{vocab_size},{p},{k}," + f"{correct_list[0]},{time_list[0]}," + f"{time_list[0] / time_list[1]:.8f}\n" + ) print_to_log(y_str("--------------------------------\n"), log_file) """# SPDX-License-Identifier: Apache-2.0 @@ -421,4 +436,4 @@ def main(): 
print_to_log(y_str("--------------------------------\n"), log_file) if __name__ == "__main__": - main()""" \ No newline at end of file + main()""" diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index a2f4c35793dd..a48186982d7a 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -7,7 +7,6 @@ import triton import triton.language as tl from packaging import version -from typing import Optional from vllm import envs from vllm.config.model import LogprobsMode @@ -61,8 +60,7 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None: logger.info_once("Using FlashInfer for top-p & top-k sampling.") self.forward = self.forward_cuda elif envs.VLLM_USE_TRITOVOCAB_SIZE_SAMPLER is not False: - logger.info_once( - "Using Triton for top-p & top-k sampling.") + logger.info_once("Using Triton for top-p & top-k sampling.") self.forward = self.forward_triton else: logger.warning_once( @@ -74,15 +72,15 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None: self.forward = self.forward_native else: if envs.VLLM_USE_TRITOVOCAB_SIZE_SAMPLER is not False: - logger.info_once( - "Using Triton for top-p & top-k sampling.") + logger.info_once("Using Triton for top-p & top-k sampling.") self.forward = self.forward_triton else: logger.warning_once( "FlashInfer is not available. Falling back to the " "PyTorch-native implementation of top-p & top-k " "sampling. For the best performance, please install " - "FlashInfer.") + "FlashInfer." 
+ ) self.forward = self.forward_native elif current_platform.is_cpu(): if current_platform.get_cpu_architecture() == CpuArchEnum.RISCV: @@ -120,10 +118,9 @@ def forward_triton( self, logits: torch.Tensor, generators: dict[int, torch.Generator], - k: Optional[torch.Tensor], - p: Optional[torch.Tensor], - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - + k: torch.Tensor | None, + p: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: logits = self.apply_top_k_top_p_triton(logits, k, p) logits_to_return = None if self.logprobs_mode == "processed_logits": @@ -200,6 +197,49 @@ def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor: return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return +def apply_top_k_top_p_with_index( + logits: torch.Tensor, + k: torch.Tensor | None, + p: torch.Tensor | None, +) -> torch.Tensor: + """Apply top-k and top-p masks to the logits. + + If a top-p is used, this function will sort the logits tensor, + which can be slow for large batches. + + The logits tensor may be updated in-place. + """ + if p is None: + if k is None: + return logits + + # Avoid sorting vocab for top-k only case. + return apply_top_k_only(logits, k) + + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) + + if k is not None: + # Apply top-k. + top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B + # Get all the top_k values. + top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) + top_k_mask = logits_sort < top_k_mask + logits_sort.masked_fill_(top_k_mask, -float("inf")) + + if p is not None: + # Apply top-p. + probs_sort = logits_sort.softmax(dim=-1) + probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort) + top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) + # at least one + top_p_mask[:, -1] = False + logits_sort.masked_fill_(top_p_mask, -float("inf")) + + # Re-sort the probabilities. 
+ logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) + return logits, logits_idx + + def apply_top_k_top_p( logits: torch.Tensor, k: torch.Tensor | None, @@ -245,68 +285,86 @@ def apply_top_k_top_p( def apply_top_k_top_p_triton( logits: torch.Tensor, - k: Optional[torch.Tensor], - p: Optional[torch.Tensor], + k: torch.Tensor | None, + p: torch.Tensor | None, ) -> torch.Tensor: - batch_size, vocab_size = logits.shape device_prop = torch.cuda.get_device_properties(logits.device) - VOCAB_SIZEUM_PROGRAMS = device_prop.multi_processor_count - BLOCK_SIZE = 16384 - SIGMA = 2.15 # Top 0.03 outliers - Maybe dynamically adjust based on K? + # VOCAB_SIZEUM_PROGRAMS = device_prop.multi_processor_count + VOCAB_SIZEUM_PROGRAMS = batch_size + BLOCK_SIZE = 8192 + SIGMA = 2 # Top 0.03 outliers - Maybe dynamically adjust based on K? VOCAB_SIZEUM_WARPS = 16 VOCAB_SIZEUM_STAGES = 3 - probs = torch.full((VOCAB_SIZEUM_PROGRAMS, vocab_size), - -float('inf'), - device=logits.device) - + probs = torch.full( + (VOCAB_SIZEUM_PROGRAMS, vocab_size), -float("inf"), device=logits.device + ) + debug_tensor = torch.full( + (VOCAB_SIZEUM_PROGRAMS, BLOCK_SIZE), -12.0, device=logits.device + ) if k is not None and p is None: - _topk_kernel[(VOCAB_SIZEUM_PROGRAMS, )](logits, - probs, - k, - batch_size, - SIGMA, - vocab_size, - BLOCK_SIZE, - num_warps=VOCAB_SIZEUM_WARPS, - num_stages=VOCAB_SIZEUM_STAGES) + _topk_kernel[(VOCAB_SIZEUM_PROGRAMS,)]( + logits, + probs, + debug_tensor, + k, + batch_size, + SIGMA, + vocab_size, + BLOCK_SIZE, + num_warps=VOCAB_SIZEUM_WARPS, + num_stages=VOCAB_SIZEUM_STAGES, + ) + + # print(f"debug_tensor: {debug_tensor[:, :13]}") elif k is None and p is not None: - probs_2 = torch.full_like(probs, -float('inf'), device=logits.device) - _topp_kernel[(VOCAB_SIZEUM_PROGRAMS, )](logits, - probs, - probs_2, - p, - batch_size, - SIGMA, - vocab_size, - BLOCK_SIZE, - num_warps=VOCAB_SIZEUM_WARPS, - num_stages=VOCAB_SIZEUM_STAGES) + probs_2 = torch.full_like(probs, 
-float("inf"), device=logits.device) + _topp_kernel[(VOCAB_SIZEUM_PROGRAMS,)]( + logits, + probs, + probs_2, + p, + batch_size, + SIGMA, + vocab_size, + BLOCK_SIZE, + num_warps=VOCAB_SIZEUM_WARPS, + num_stages=VOCAB_SIZEUM_STAGES, + ) elif k is not None and p is not None: - _topk_topp_kernel[(VOCAB_SIZEUM_PROGRAMS, )](logits, - probs, - k, - p, - batch_size, - SIGMA, - vocab_size, - BLOCK_SIZE, - num_warps=VOCAB_SIZEUM_WARPS, - num_stages=VOCAB_SIZEUM_STAGES) + _topk_topp_kernel[(VOCAB_SIZEUM_PROGRAMS,)]( + logits, + probs, + k, + p, + batch_size, + SIGMA, + vocab_size, + BLOCK_SIZE, + num_warps=VOCAB_SIZEUM_WARPS, + num_stages=VOCAB_SIZEUM_STAGES, + ) return logits @triton.jit -def _topk_kernel(LOGITS, PROBS, K, B, SIGMA: tl.constexpr, VOCAB_SIZE: tl.constexpr, - BLOCK_SIZE: tl.constexpr): +def _topk_kernel( + LOGITS, + PROBS, + DEBUG_TENSOR, + K, + B, + SIGMA: tl.constexpr, + VOCAB_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): pid = tl.program_id(0) num_programs = tl.num_programs(0) NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE for row_id in tl.range(pid, B, num_programs): k = tl.load(K + row_id) if k != VOCAB_SIZE: # All tokens are valid - # THERE IS VOCAB_SIZEO DUPLICATE LOGIT MAVOCAB_SIZEAGEMEVOCAB_SIZET FOR THIS TOP-K # CURREVOCAB_SIZET IMPLEMEVOCAB_SIZETATIOVOCAB_SIZE IVOCAB_SIZECLUDES ALL DUPLICATE LOGITS, # WHICH MAY RETURVOCAB_SIZE MORE THAVOCAB_SIZE K LOGITS, @@ -315,47 +373,55 @@ def _topk_kernel(LOGITS, PROBS, K, B, SIGMA: tl.constexpr, VOCAB_SIZE: tl.conste # IMPLEMEVOCAB_SIZETATIOVOCAB_SIZE AVOCAB_SIZED IMPLEMEVOCAB_SIZET THE DUPLICATE LOGIT MAVOCAB_SIZEAGEMEVOCAB_SIZET # USIVOCAB_SIZEG THE FORCE_REMOVE_LOGIT VARIABLE - k_pivot = -float('inf') - LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE - PROBS_ROW = PROBS + pid * VOCAB_SIZE - + PROBS_ROW = PROBS + row_id * VOCAB_SIZE + DEBUG_TENSOR_ROW = DEBUG_TENSOR + row_id * BLOCK_SIZE search_addr = LOGITS_ROW search_range = VOCAB_SIZE search_iters = NUM_TILES - max_logit = 
-float('inf') - min_logit = float('inf') + k_pivot = -float("inf") + max_logit = -float("inf") + min_logit = float("inf") # Zeroth pass: Compute avg and std from a sample block # May produce incorrect results if VOCAB_SIZE < BLOCK_SIZE offs = tl.arange(0, BLOCK_SIZE) mask_n = offs < VOCAB_SIZE + num_valid = tl.sum(mask_n) logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) - avg_logit = tl.sum(logits_blk) / VOCAB_SIZE - sq_avg_logit = tl.sum(logits_blk * logits_blk) / VOCAB_SIZE + avg_logit = tl.sum(logits_blk) / num_valid + sq_avg_logit = tl.sum(logits_blk * logits_blk) / num_valid std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) outlier_pivot = avg_logit + SIGMA * std_logit num_outliers = tl.zeros((), dtype=tl.uint32) + + sum_logit = tl.sum(logits_blk) + min_logit_value = tl.min(logits_blk) + max_logit_value = tl.max(logits_blk) + # First pass: compute max and min logits and gather outliers for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - logits_blk = tl.load(search_addr + offs_n, - mask=mask_n, - other=avg_logit) + logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=avg_logit) max_logit = tl.maximum(max_logit, tl.max(logits_blk)) min_logit = tl.minimum(min_logit, tl.min(logits_blk)) outlier_mask = (logits_blk > outlier_pivot) & mask_n num_blk_outliers = tl.sum(outlier_mask) cumulative_pos = tl.cast( - tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32 + ) num_outliers += num_blk_outliers write_pos = tl.where(outlier_mask, cumulative_pos, -1) tl.store(PROBS_ROW + write_pos, logits_blk, mask=outlier_mask) + tl.store(DEBUG_TENSOR_ROW + 6, num_outliers) + tl.store(DEBUG_TENSOR_ROW + 7, max_logit) + tl.store(DEBUG_TENSOR_ROW + 8, min_logit) + max_range = max_logit min_range = min_logit if num_outliers > k: @@ -364,11 +430,15 @@ def _topk_kernel(LOGITS, PROBS, K, B, SIGMA: tl.constexpr, VOCAB_SIZE: tl.conste 
search_addr = PROBS_ROW search_range = tl.cast(num_outliers, tl.int32) search_iters = tl.cast( - (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) + (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32 + ) + + tl.store(DEBUG_TENSOR_ROW + 9, max_range) + tl.store(DEBUG_TENSOR_ROW + 10, min_range) # Second passes: Quaternary search for pivots (nlog_4(n)) num_iters = 0 - while k_pivot == -float('inf') and num_iters < 18: + while k_pivot == -float("inf") and num_iters < 32: k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range @@ -379,9 +449,9 @@ def _topk_kernel(LOGITS, PROBS, K, B, SIGMA: tl.constexpr, VOCAB_SIZE: tl.conste for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - logits_blk = tl.load(search_addr + offs_n, - mask=mask_n, - other=-float('inf')) + logits_blk = tl.load( + search_addr + offs_n, mask=mask_n, other=-float("inf") + ) k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) @@ -409,31 +479,46 @@ def _topk_kernel(LOGITS, PROBS, K, B, SIGMA: tl.constexpr, VOCAB_SIZE: tl.conste max_range = k_pivot_2 num_iters += 1 - if num_iters >= 18 or tl.abs(min_range - max_range) < 1e-8: + if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-16: k_pivot = k_pivot_0 + tl.store(DEBUG_TENSOR_ROW + 11, num_iters) + tl.store(DEBUG_TENSOR_ROW + 12, k_pivot) + num_masked = tl.zeros((), dtype=tl.uint32) + # Third pass: Apply top-k mask - if k_pivot != -float('inf'): - for i in range(0, search_iters): + if k_pivot != -float("inf"): + for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n) - mask = (logits_blk > k_pivot) - logits_blk = tl.where(mask, logits_blk, -float('inf')) + mask = logits_blk > k_pivot + logits_blk = 
tl.where(mask, logits_blk, -float("inf")) tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) + num_masked += tl.sum(mask) + + tl.store(DEBUG_TENSOR_ROW + 13, num_masked) + offs_n = tl.arange(0, BLOCK_SIZE) @triton.jit -def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, SIGMA: tl.constexpr, - VOCAB_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr): +def _topp_kernel( + LOGITS, + PROBS, + PROBS_2, + P, + B, + SIGMA: tl.constexpr, + VOCAB_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE pid = tl.program_id(0) num_programs = tl.num_programs(0) for row_id in tl.range(pid, B, num_programs): p = tl.load(P + row_id) if p != 1.0: # All tokens are valid - - p_pivot = -float('inf') + p_pivot = -float("inf") LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE PROBS_ROW = PROBS + pid * VOCAB_SIZE @@ -443,12 +528,12 @@ def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, SIGMA: tl.constexpr, search_range = VOCAB_SIZE search_iters = NUM_TILES - max_logit = -float('inf') - min_logit = float('inf') + max_logit = -float("inf") + min_logit = float("inf") # The Pytorch version removes the earlier duplicates # if there are more than one duplicates - force_remove_logit = -float('inf') + force_remove_logit = -float("inf") num_force_remove = tl.zeros((), dtype=tl.uint32) # Zeroth pass: Compute avg and std from a sample block @@ -471,9 +556,7 @@ def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, SIGMA: tl.constexpr, for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - logits_blk = tl.load(LOGITS_ROW + offs_n, - mask=mask_n, - other=avg_logit) + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=avg_logit) max_logit = tl.maximum(max_logit, tl.max(logits_blk)) min_logit = tl.minimum(min_logit, tl.min(logits_blk)) @@ -482,9 +565,9 @@ def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, SIGMA: tl.constexpr, offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < 
search_range - probs_blk = tl.load(LOGITS_ROW + offs_n, - mask=mask_n, - other=-float('inf')) + probs_blk = tl.load( + LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") + ) probs_blk = probs_blk - max_logit probs_blk = tl.exp(probs_blk) sum_exp_logits += tl.sum(probs_blk) @@ -505,7 +588,8 @@ def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, SIGMA: tl.constexpr, sum_outlier_probs += tl.sum(outlier_mask * probs_blk) num_blk_outliers = tl.sum(outlier_mask) cumulative_pos = tl.cast( - tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32 + ) num_outliers += num_blk_outliers write_pos = tl.where(outlier_mask, cumulative_pos, -1) tl.store(PROBS_2_ROW + write_pos, probs_blk, mask=outlier_mask) @@ -518,9 +602,10 @@ def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, SIGMA: tl.constexpr, search_addr = PROBS_2_ROW search_range = tl.cast(num_outliers, tl.int32) search_iters = tl.cast( - (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) + (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32 + ) - second_max_logit = -float('inf') + second_max_logit = -float("inf") num_iters = 0 p_pivots_sum_0 = 0.0 @@ -528,7 +613,7 @@ def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, SIGMA: tl.constexpr, num_min_larger_0 = tl.zeros((), dtype=tl.uint32) # Fifth passes: Search for p_pivot (2log_2(n)) - while p_pivot == -float('inf') and num_iters < 32: + while p_pivot == -float("inf") and num_iters < 32: p_pivot_0 = (max_range - min_range) * 1.0 / 2.0 + min_range p_pivots_sum_0 = 0.0 @@ -538,27 +623,19 @@ def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, SIGMA: tl.constexpr, for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - probs_blk = tl.load(search_addr + offs_n, - mask=mask_n, - other=0.0) + probs_blk = tl.load(search_addr + offs_n, mask=mask_n, other=0.0) - masked_larger_0 = tl.where(probs_blk > p_pivot_0, - probs_blk, 1.0) - min_larger_0 = tl.minimum(min_larger_0, - 
tl.min(masked_larger_0)) + masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) - p_pivots_sum_0 += tl.sum(probs_blk * - (probs_blk > p_pivot_0)) + p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - probs_blk = tl.load(search_addr + offs_n, - mask=mask_n, - other=0.0) + probs_blk = tl.load(search_addr + offs_n, mask=mask_n, other=0.0) - num_min_larger_0 += tl.sum( - tl.abs(probs_blk - min_larger_0) < 1e-7) + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-7) # Check if any of the pivots are equal to k if p_pivots_sum_0 >= p: @@ -579,53 +656,62 @@ def _topp_kernel(LOGITS, PROBS, PROBS_2, P, B, SIGMA: tl.constexpr, elif num_min_larger_0 > 1: # Force remove duplicates (p_pivot is made to include all # duplicates if it falls on the duplicates) - num_force_remove = tl.cast((p_pivots_sum_0 - p) / min_larger_0, - tl.uint32) - force_remove_logit = tl.log( - min_larger_0 * sum_exp_logits) + max_logit + num_force_remove = tl.cast( + (p_pivots_sum_0 - p) / min_larger_0, tl.uint32 + ) + force_remove_logit = tl.log(min_larger_0 * sum_exp_logits) + max_logit p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit # Sixth pass: Apply mask current_num_force_remove = tl.zeros((), dtype=tl.uint32) - if p_pivot != -float('inf'): + if p_pivot != -float("inf"): for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load(LOGITS_ROW + offs_n, - mask=mask_n, - other=-float('inf')) + logits_blk = tl.load( + LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") + ) - if force_remove_logit != -float('inf'): + if force_remove_logit != -float("inf"): # Force remove duplicates - tolerance = 1e-5 * tl.maximum( - 1.0, tl.abs(force_remove_logit)) - force_remove_mask = tl.abs( - logits_blk - 
force_remove_logit) < tolerance - force_remove_count = tl.cumsum( - force_remove_mask) + current_num_force_remove - force_remove_count_mask = \ - force_remove_count <= num_force_remove - force_remove_mask = \ - force_remove_count_mask & force_remove_mask - logits_blk = tl.where(force_remove_mask, -float('inf'), - logits_blk) + tolerance = 1e-5 * tl.maximum(1.0, tl.abs(force_remove_logit)) + force_remove_mask = ( + tl.abs(logits_blk - force_remove_logit) < tolerance + ) + force_remove_count = ( + tl.cumsum(force_remove_mask) + current_num_force_remove + ) + force_remove_count_mask = force_remove_count <= num_force_remove + force_remove_mask = force_remove_count_mask & force_remove_mask + logits_blk = tl.where( + force_remove_mask, -float("inf"), logits_blk + ) current_num_force_remove = tl.max(force_remove_count) - logits_blk = tl.where(logits_blk > p_pivot, logits_blk, - -float('inf')) + logits_blk = tl.where( + logits_blk > p_pivot, logits_blk, -float("inf") + ) tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) @triton.jit -def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, - VOCAB_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr): +def _topk_topp_kernel( + LOGITS, + PROBS, + K, + P, + B, + SIGMA: tl.constexpr, + VOCAB_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE pid = tl.program_id(0) num_programs = tl.num_programs(0) for row_id in tl.range(pid, B, num_programs): - k_pivot = -float('inf') - p_pivot = -float('inf') + k_pivot = -float("inf") + p_pivot = -float("inf") LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE PROBS_ROW = PROBS + pid * VOCAB_SIZE @@ -634,13 +720,13 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, search_range = VOCAB_SIZE search_iters = NUM_TILES - max_logit = -float('inf') - min_logit = float('inf') - avg_logit = -float('inf') + max_logit = -float("inf") + min_logit = float("inf") + avg_logit = -float("inf") # The Pytorch version 
removes the earlier duplicates # if there are more than one duplicates - force_remove_logit = -float('inf') + force_remove_logit = -float("inf") num_force_remove = tl.zeros((), dtype=tl.uint32) # Zeroth pass: Compute avg and std from a sample block @@ -658,16 +744,15 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - logits_blk = tl.load(search_addr + offs_n, - mask=mask_n, - other=avg_logit) + logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=avg_logit) max_logit = tl.maximum(max_logit, tl.max(logits_blk)) min_logit = tl.minimum(min_logit, tl.min(logits_blk)) outlier_mask = (logits_blk > outlier_pivot) & mask_n num_blk_outliers = tl.sum(outlier_mask) cumulative_pos = tl.cast( - tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32 + ) num_outliers += num_blk_outliers write_pos = tl.where(outlier_mask, cumulative_pos, -1) tl.store(PROBS_ROW + write_pos, logits_blk, mask=outlier_mask) @@ -682,10 +767,10 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, search_addr = PROBS_ROW search_range = tl.cast(num_outliers, tl.int32) search_iters = tl.cast( - (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) + (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32 + ) if k != VOCAB_SIZE: # All tokens are valid - # THERE IS VOCAB_SIZEO DUPLICATE LOGIT MAVOCAB_SIZEAGEMEVOCAB_SIZET FOR THIS TOP-K # CURREVOCAB_SIZET IMPLEMEVOCAB_SIZETATIOVOCAB_SIZE IVOCAB_SIZECLUDES ALL DUPLICATE LOGITS, # WHICH MAY RETURVOCAB_SIZE MORE THAVOCAB_SIZE K LOGITS, @@ -696,7 +781,7 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, # Second passes: Quaternary search for pivots (nlog_4(n)) num_iters = 0 - while k_pivot == -float('inf') and num_iters < 18: + while k_pivot == -float("inf") and num_iters < 18: k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + 
min_range k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range @@ -707,9 +792,9 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - logits_blk = tl.load(search_addr + offs_n, - mask=mask_n, - other=-float('inf')) + logits_blk = tl.load( + search_addr + offs_n, mask=mask_n, other=-float("inf") + ) k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) @@ -746,8 +831,7 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, p = tl.load(P + row_id) if p != 1.0: # All tokens are valid - - second_max_logit = -float('inf') + second_max_logit = -float("inf") max_probs = 0.0 min_probs = 1.0 sum_exp_logits = 0.0 @@ -756,19 +840,17 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - probs_blk = tl.load(search_addr + offs_n, - mask=mask_n, - other=-float('inf')) - probs_blk = tl.where(probs_blk > k_pivot, probs_blk, - -float('inf')) + probs_blk = tl.load( + search_addr + offs_n, mask=mask_n, other=-float("inf") + ) + probs_blk = tl.where(probs_blk > k_pivot, probs_blk, -float("inf")) probs_blk = probs_blk - max_logit probs_blk = tl.exp(probs_blk) sum_exp_logits += tl.sum(probs_blk) tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) second_max_mask = probs_blk * (probs_blk < max_probs) - second_max_logit = tl.maximum(second_max_logit, - tl.max(second_max_mask)) + second_max_logit = tl.maximum(second_max_logit, tl.max(second_max_mask)) # Fourth pass: Compute probs (softmax) for i in range(0, search_iters): @@ -791,7 +873,7 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, num_min_larger_0 = tl.zeros((), dtype=tl.uint32) # Fifth passes: Search for p_pivot (2log_2(n)) - 
while p_pivot == -float('inf') and num_iters < 32: + while p_pivot == -float("inf") and num_iters < 32: p_pivot_0 = (max_range - min_range) * 1.0 / 2.0 + min_range p_pivots_sum_0 = 0.0 @@ -801,27 +883,19 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - probs_blk = tl.load(PROBS_ROW + offs_n, - mask=mask_n, - other=0.0) + probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) - masked_larger_0 = tl.where(probs_blk > p_pivot_0, - probs_blk, 1.0) - min_larger_0 = tl.minimum(min_larger_0, - tl.min(masked_larger_0)) + masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) - p_pivots_sum_0 += tl.sum(probs_blk * - (probs_blk > p_pivot_0)) + p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - probs_blk = tl.load(PROBS_ROW + offs_n, - mask=mask_n, - other=0.0) + probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) - num_min_larger_0 += tl.sum( - tl.abs(probs_blk - min_larger_0) < 1e-7) + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-7) # Check if any of the pivots are equal to k if p_pivots_sum_0 >= p: @@ -842,10 +916,10 @@ def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, elif num_min_larger_0 > 1: # Force remove duplicates (p_pivot is made to include all # duplicates if it falls on the duplicates) - num_force_remove = tl.cast((p_pivots_sum_0 - p) / min_larger_0, - tl.uint32) - force_remove_logit = tl.log( - min_larger_0 * sum_exp_logits) + max_logit + num_force_remove = tl.cast( + (p_pivots_sum_0 - p) / min_larger_0, tl.uint32 + ) + force_remove_logit = tl.log(min_larger_0 * sum_exp_logits) + max_logit p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit @@ -854,32 +928,29 @@ 
def _topk_topp_kernel(LOGITS, PROBS, K, P, B, SIGMA: tl.constexpr, # Sixth pass: Apply mask pivot = tl.maximum(k_pivot, p_pivot) current_num_force_remove = tl.zeros((), dtype=tl.uint32) - if pivot != -float('inf'): + if pivot != -float("inf"): for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load(LOGITS_ROW + offs_n, - mask=mask_n, - other=-float('inf')) + logits_blk = tl.load( + LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") + ) - if force_remove_logit != -float('inf'): + if force_remove_logit != -float("inf"): # Force remove duplicates - tolerance = 1e-5 * tl.maximum(1.0, - tl.abs(force_remove_logit)) - force_remove_mask = tl.abs(logits_blk - - force_remove_logit) < tolerance - force_remove_count = tl.cumsum( - force_remove_mask) + current_num_force_remove - force_remove_count_mask = \ - force_remove_count <= num_force_remove - force_remove_mask = \ - force_remove_count_mask & force_remove_mask - logits_blk = tl.where(force_remove_mask, -float('inf'), - logits_blk) + tolerance = 1e-5 * tl.maximum(1.0, tl.abs(force_remove_logit)) + force_remove_mask = ( + tl.abs(logits_blk - force_remove_logit) < tolerance + ) + force_remove_count = ( + tl.cumsum(force_remove_mask) + current_num_force_remove + ) + force_remove_count_mask = force_remove_count <= num_force_remove + force_remove_mask = force_remove_count_mask & force_remove_mask + logits_blk = tl.where(force_remove_mask, -float("inf"), logits_blk) current_num_force_remove = tl.max(force_remove_count) - logits_blk = tl.where(logits_blk > pivot, logits_blk, - -float('inf')) + logits_blk = tl.where(logits_blk > pivot, logits_blk, -float("inf")) tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) @@ -897,12 +968,12 @@ def apply_top_k_only( if k is None: return logits max_top_k = k.max().item() - + # --- FIX: Handle k=0 edge case --- # If the max k is 0, all rows are 0. Mask everything and exit. 
if max_top_k == 0: logits.fill_(-float("inf")) - return logits + return logits no_top_k_mask = k == logits.shape[1] # Set non-top-k rows to 1 so that we can gather. @@ -983,10 +1054,17 @@ def flashinfer_sample( return next_token_ids.view(-1) + @triton.jit def _topp_kernel_sorted( - LOGITS, PROBS, PROBS_2, P, B, SIGMA: tl.constexpr, - VOCAB_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr + LOGITS, + PROBS, + PROBS_2, + P, + B, + SIGMA: tl.constexpr, + VOCAB_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, ): """Modified top-p kernel with sort-equivalent tie-breaking and re-enabled outlier optimization. @@ -994,28 +1072,27 @@ def _topp_kernel_sorted( NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE pid = tl.program_id(0) num_programs = tl.num_programs(0) - + for row_id in tl.range(pid, B, num_programs): p = tl.load(P + row_id) if p != 1.0: # All tokens are valid - - p_pivot = -float('inf') + p_pivot = -float("inf") LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE PROBS_ROW = PROBS + pid * VOCAB_SIZE - PROBS_2_ROW = PROBS_2 + pid * VOCAB_SIZE # <-- RE-ADDED + PROBS_2_ROW = PROBS_2 + pid * VOCAB_SIZE # <-- RE-ADDED # Default search params search_addr = PROBS_ROW search_range = VOCAB_SIZE search_iters = NUM_TILES - max_logit = -float('inf') - min_logit = float('inf') + max_logit = -float("inf") + min_logit = float("inf") - force_remove_logit = -float('inf') + force_remove_logit = -float("inf") num_force_remove = tl.zeros((), dtype=tl.uint32) - + # --- ZEROTH PASS (RE-ADDED) --- # Compute *exact* avg and std sum_logits = 0.0 @@ -1030,9 +1107,9 @@ def _topp_kernel_sorted( avg_logit = sum_logits / VOCAB_SIZE sq_avg_logit = sum_sq_logits / VOCAB_SIZE std_logit = tl.sqrt(tl.maximum(0.0, sq_avg_logit - avg_logit * avg_logit)) - outlier_pivot = avg_logit + SIGMA * std_logit # <-- RE-ADDED + outlier_pivot = avg_logit + SIGMA * std_logit # <-- RE-ADDED num_outliers = tl.zeros((), dtype=tl.uint32) # <-- RE-ADDED - sum_outlier_probs = 0.0 # <-- RE-ADDED + 
sum_outlier_probs = 0.0 # <-- RE-ADDED sum_exp_logits = 0.0 @@ -1040,9 +1117,9 @@ def _topp_kernel_sorted( for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - logits_blk = tl.load(LOGITS_ROW + offs_n, - mask=mask_n, - other=-float('inf')) # Use -inf + logits_blk = tl.load( + LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") + ) # Use -inf max_logit = tl.maximum(max_logit, tl.max(logits_blk)) min_logit = tl.minimum(min_logit, tl.min(logits_blk)) @@ -1051,9 +1128,9 @@ def _topp_kernel_sorted( offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - probs_blk = tl.load(LOGITS_ROW + offs_n, - mask=mask_n, - other=-float('inf')) + probs_blk = tl.load( + LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") + ) probs_blk = probs_blk - max_logit probs_blk = tl.exp(probs_blk) sum_exp_logits += tl.sum(probs_blk) @@ -1070,18 +1147,18 @@ def _topp_kernel_sorted( probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) probs_blk = probs_blk / sum_exp_logits tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) - + # --- OUTLIER MASKIVOCAB_SIZEG LOGIC (RE-ADDED) --- outlier_mask = (probs_blk > outlier_prob) & mask_n sum_outlier_probs += tl.sum(outlier_mask * probs_blk) num_blk_outliers = tl.sum(outlier_mask) cumulative_pos = tl.cast( - tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32 + ) num_outliers += num_blk_outliers write_pos = tl.where(outlier_mask, cumulative_pos, -1) tl.store(PROBS_2_ROW + write_pos, probs_blk, mask=outlier_mask) - max_range = tl.exp(max_logit - max_logit) / sum_exp_logits min_range = tl.exp(min_logit - max_logit) / sum_exp_logits @@ -1090,16 +1167,17 @@ def _topp_kernel_sorted( search_addr = PROBS_2_ROW search_range = tl.cast(num_outliers, tl.int32) search_iters = tl.cast( - (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) + (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32 + ) - 
second_max_logit = -float('inf') + second_max_logit = -float("inf") num_iters = 0 - p_pivots_sum_0 = 0.0 # --> total prob including all equivalent min - min_larger_0 = 1.0 # --> prob of tie-breaking min + p_pivots_sum_0 = 0.0 # --> total prob including all equivalent min + min_larger_0 = 1.0 # --> prob of tie-breaking min num_min_larger_0 = tl.zeros((), dtype=tl.uint32) # Binary search for p_pivot - while p_pivot == -float('inf') and num_iters < 32: + while p_pivot == -float("inf") and num_iters < 32: p_pivot_0 = (max_range - min_range) * 1.0 / 2.0 + min_range p_pivots_sum_0 = 0.0 @@ -1109,27 +1187,19 @@ def _topp_kernel_sorted( for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - probs_blk = tl.load(search_addr + offs_n, - mask=mask_n, - other=0.0) + probs_blk = tl.load(search_addr + offs_n, mask=mask_n, other=0.0) - masked_larger_0 = tl.where(probs_blk > p_pivot_0, - probs_blk, 1.0) - min_larger_0 = tl.minimum(min_larger_0, - tl.min(masked_larger_0)) + masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) - p_pivots_sum_0 += tl.sum(probs_blk * - (probs_blk > p_pivot_0)) + p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - probs_blk = tl.load(search_addr + offs_n, - mask=mask_n, - other=0.0) + probs_blk = tl.load(search_addr + offs_n, mask=mask_n, other=0.0) - num_min_larger_0 += tl.sum( - tl.abs(probs_blk - min_larger_0) < 1e-7) + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-7) if p_pivots_sum_0 >= p: if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: @@ -1146,46 +1216,48 @@ def _topp_kernel_sorted( if p_pivot >= max_logit: p_pivot = second_max_logit elif num_min_larger_0 > 1: - num_force_remove = tl.cast((p_pivots_sum_0 - p) / min_larger_0, - tl.uint32) # --> number of 
probs to be removed - force_remove_logit = tl.log( - min_larger_0 * sum_exp_logits) + max_logit + num_force_remove = tl.cast( + (p_pivots_sum_0 - p) / min_larger_0, tl.uint32 + ) # --> number of probs to be removed + force_remove_logit = tl.log(min_larger_0 * sum_exp_logits) + max_logit p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit # Apply mask with (non-sort-equivalent) tie-breaking - current_num_removed = tl.zeros((), dtype=tl.uint32) - if p_pivot != -float('inf'): + current_num_removed = tl.zeros((), dtype=tl.uint32) + if p_pivot != -float("inf"): for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load(LOGITS_ROW + offs_n, - mask=mask_n, - other=-float('inf')) + logits_blk = tl.load( + LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") + ) - if force_remove_logit != -float('inf'): + if force_remove_logit != -float("inf"): # Match PyTorch's non-sort-equivalent tie-breaking tolerance = 1e-5 * tl.maximum(1.0, tl.abs(force_remove_logit)) is_tie = tl.abs(logits_blk - force_remove_logit) < tolerance tie_position = tl.cumsum(is_tie) - 1 + current_num_removed should_remove = is_tie & (tie_position < num_force_remove) - logits_blk = tl.where(should_remove, -float('inf'), logits_blk) + logits_blk = tl.where(should_remove, -float("inf"), logits_blk) current_num_removed += tl.sum(is_tie) # Standard threshold masking tolerance = 1e-6 * tl.maximum(1.0, tl.abs(p_pivot)) - logits_blk = tl.where(logits_blk >= (p_pivot - tolerance), logits_blk, - -float('inf')) - + logits_blk = tl.where( + logits_blk >= (p_pivot - tolerance), logits_blk, -float("inf") + ) + tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) + def apply_top_p_sorted_equivalent( logits: torch.Tensor, p: torch.Tensor, - sigma: float = 3.0, + sigma: float = 3.0, ) -> torch.Tensor: """Apply top-p using binary search (no sort!) with sort-equivalent results. 
- + Args: logits: [B, VOCAB_SIZE] logits tensor p: [B] top-p thresholds @@ -1195,35 +1267,35 @@ def apply_top_p_sorted_equivalent( """ B, VOCAB_SIZE = logits.shape device = logits.device - + BLOCK_SIZE = triton.next_power_of_2(min(VOCAB_SIZE, 1024)) num_warps = 4 if BLOCK_SIZE < 2048 else 8 - + probs = torch.empty((B, VOCAB_SIZE), device=device, dtype=torch.float32) - probs_2 = torch.empty((B, VOCAB_SIZE), device=device, dtype=torch.float32) - + probs_2 = torch.empty((B, VOCAB_SIZE), device=device, dtype=torch.float32) + grid = (B,) _topp_kernel_sorted[grid]( logits, probs, - probs_2, + probs_2, p, B, - SIGMA=sigma, + SIGMA=sigma, VOCAB_SIZE=VOCAB_SIZE, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, ) - + return logits + def apply_top_k_top_p_test( logits: torch.Tensor, k: torch.Tensor | None, p: torch.Tensor | None, ) -> torch.Tensor: - """Optimized implementation combining torch.topk and binary search kernel. - """ + """Optimized implementation combining torch.topk and binary search kernel.""" if p is None: if k is None: return logits @@ -1231,24 +1303,25 @@ def apply_top_k_top_p_test( # Apply top-k filter first if needed if k is not None: logits = apply_top_k_only(logits, k) - + # Apply top-p using binary search (no sort!) 
return apply_top_p_sorted_equivalent(logits, p) + # -------------------------------------------------------------------------- -@triton.jit +@triton.jit def top_p_pivot_filter( - LOGITS, - PROBS, - PROBS_IDX, - K_FILTER: tl.int32, - OUTPUT_LOGITS, - OUTPUT_INDICES, - B, # --> batch size + LOGITS, + PROBS, + PROBS_IDX, + K_FILTER: tl.int32, + OUTPUT_LOGITS, + OUTPUT_INDICES, + B, # --> batch size SIGMA: tl.constexpr, VOCAB_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr, - ): +): NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE pid = tl.program_id(0) @@ -1256,41 +1329,40 @@ def top_p_pivot_filter( for row_id in tl.range(pid, B, num_programs): k = K_FILTER - if k <= VOCAB_SIZE: - k_pivot = -float('inf') + if k <= VOCAB_SIZE: + k_pivot = -float("inf") - LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE + LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE PROBS_ROW = PROBS + row_id * VOCAB_SIZE PROBS_IDX_ROW = PROBS_IDX + row_id * VOCAB_SIZE OUTPUT_LOGITS_ROW = OUTPUT_LOGITS + row_id * K_FILTER OUTPUT_INDICES_ROW = OUTPUT_INDICES + row_id * K_FILTER - search_addr = LOGITS_ROW - search_range = VOCAB_SIZE + search_addr = LOGITS_ROW + search_range = VOCAB_SIZE search_iters = NUM_TILES - max_logit = -float('inf') - min_logit = float('inf') + max_logit = -float("inf") + min_logit = float("inf") - # Zeroth pass: Compute avg and std from a sample block + # Zeroth pass: Compute avg and std from a sample block offs = tl.arange(0, BLOCK_SIZE) - mask_n = offs < VOCAB_SIZE + mask_n = offs < VOCAB_SIZE + num_mask = tl.sum(mask_n) logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) probs_blk = tl.load(PROBS_ROW + offs, mask=mask_n, other=0.0) valid_count = tl.sum(mask_n, dtype=tl.float32) - avg_logit = tl.sum(logits_blk) / VOCAB_SIZE # re-check - sq_avg_logit = tl.sum(logits_blk * logits_blk) / VOCAB_SIZE # re-check + avg_logit = tl.sum(logits_blk) / num_mask # re-check + sq_avg_logit = tl.sum(logits_blk * logits_blk) / num_mask # re-check std_logit = 
tl.sqrt(sq_avg_logit - avg_logit * avg_logit) outlier_pivot = avg_logit + SIGMA * std_logit num_outliers = tl.zeros((), dtype=tl.uint32) - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - logits_blk = tl.load(search_addr + offs_n, - mask=mask_n, - other=avg_logit) + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=avg_logit) probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) max_logit = tl.maximum(max_logit, tl.max(logits_blk)) @@ -1299,7 +1371,8 @@ def top_p_pivot_filter( outlier_mask = (logits_blk > outlier_pivot) & mask_n num_blk_outliers = tl.sum(outlier_mask) cumulative_pos = tl.cast( - tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32 + ) num_outliers += num_blk_outliers write_idx = tl.where(outlier_mask, cumulative_pos, 0) @@ -1314,11 +1387,12 @@ def top_p_pivot_filter( search_addr = PROBS_ROW search_range = tl.cast(num_outliers, tl.int32) search_iters = tl.cast( - (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) - + (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32 + ) + # Second passes: Quaternary search for pivots (nlog_4(n)) num_iters = 0 - while k_pivot == -float('inf') and num_iters < 18: + while k_pivot == -float("inf") and num_iters < 18: k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range @@ -1329,9 +1403,9 @@ def top_p_pivot_filter( for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - logits_blk = tl.load(search_addr + offs_n, - mask=mask_n, - other=-float('inf')) + logits_blk = tl.load( + search_addr + offs_n, mask=mask_n, other=-float("inf") + ) k_pivots_num_0 += 
tl.sum(logits_blk > k_pivot_0) k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) @@ -1363,74 +1437,95 @@ def top_p_pivot_filter( k_pivot = k_pivot_0 # ============== Third pass : Apply top-k mask ================ write_pos = tl.zeros((), dtype=tl.int32) - # if k_pivot != -float('inf'): - for i in range(0, NUM_TILES): + # if k_pivot != -float('inf'): + for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) - keep_mask = (logits_blk >= k_pivot) & mask_n + mask_n = offs_n < VOCAB_SIZE + logits_blk = tl.load( + LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") + ) + keep_mask = (logits_blk >= k_pivot) & mask_n n_kept = tl.sum(keep_mask, dtype=tl.int32) - cpos = tl.cumsum(keep_mask) -1 + write_pos - final_mask = keep_mask & (cpos < k) + cpos = tl.cumsum(keep_mask) - 1 + write_pos + final_mask = keep_mask & (cpos < k) write_idx = tl.where(final_mask, cpos, 0) tl.store(OUTPUT_LOGITS_ROW + write_idx, logits_blk, mask=final_mask) tl.store(OUTPUT_INDICES_ROW + write_idx, offs_n, mask=final_mask) write_pos += tl.sum(final_mask, dtype=tl.int32) - -def apply_top_p_filtered ( - logits: torch.Tensor, +def apply_top_k_top_p_filtered( + logits: torch.Tensor, + k: torch.Tensor, p: torch.Tensor, -) -> torch.Tensor: +) -> torch.Tensor: """ Applies top p using pivot based filtering """ batch_size, vocab_size = logits.shape BLOCK_SIZE = 2048 - SIGMA = 2.15 + SIGMA = 2.15 NUM_WARPS = 16 NUM_STAGES = 3 - if not torch.any(p < 1.0): - return logits - # ================= to find the k filter value ==================== - k_filter = int(vocab_size * 1/32) + k_filter = max(int(vocab_size * 1 / 32), k.max() * 2) + if k_filter > vocab_size / 2: + print(f"k_filter too big: {k_filter}") + return apply_top_k_top_p(logits, k, p) - filtered_logits = torch.full((batch_size, k_filter), -float('inf'), device=logits.device) - filtered_indices = torch.full((batch_size, 
k_filter), 0, dtype=torch.int32, device=logits.device) + filtered_logits = torch.full( + (batch_size, k_filter), -float("inf"), device=logits.device + ) + filtered_indices = torch.full( + (batch_size, k_filter), 0, dtype=torch.int32, device=logits.device + ) - probs = torch.empty((batch_size, vocab_size), device=logits.device, dtype=torch.float32) + probs = torch.empty( + (batch_size, vocab_size), device=logits.device, dtype=torch.float32 + ) probs_idx = torch.empty_like(probs, dtype=torch.int32) grid = (batch_size,) top_p_pivot_filter[grid]( - logits, # --> input - probs, # initial filtered - probs_idx, # initial filtered index - k_filter, # --> scalar - filtered_logits, # --> output, filtered - filtered_indices, # --> filtered logits indices - batch_size, - SIGMA=SIGMA, + logits, # --> input + probs, # initial filtered + probs_idx, # initial filtered index + k_filter, # --> scalar + filtered_logits, # --> output, filtered + filtered_indices, # --> filtered logits indices + batch_size, + SIGMA=SIGMA, VOCAB_SIZE=vocab_size, BLOCK_SIZE=BLOCK_SIZE, - ) + ) + + target_logit = filtered_logits + + if p is None and k is not None: + filtered_logits = apply_top_k_only(filtered_logits, k) + logits.fill_(-float("inf")) + logits.scatter_(dim=1, index=filtered_indices, src=filtered_logits) + return logits + assert False - # this kernel outputs filtered_logits and filtered_indices of shape (batch_size, k_filter) - logits_sort, sort_indices = filtered_logits.sort(dim=-1, descending=False) - logits_sort_indices = torch.gather(filtered_indices, -1, sort_indices) + if p is not None: + # this kernel outputs filtered_logits and filtered_indices of shape (batch_size, k_filter) + logits_sort, sort_indices = filtered_logits.sort(dim=-1, descending=False) + logits_sort_indices = torch.gather(filtered_indices, -1, sort_indices) + + logits_softmax = logits.softmax(dim=-1) + sorted_probs = torch.gather(logits_softmax, -1, logits_sort_indices) - logits_softmax = logits.softmax(dim=-1) - 
sorted_probs = torch.gather(logits_softmax, -1, logits_sort_indices) + sum_probs = sorted_probs.sum(dim=-1) - sum_probs = sorted_probs.sum(dim=-1) + if torch.any(sum_probs < p): + return apply_top_k_top_p(logits, k, p) - if torch.any(sum_probs < p): - print(f"edge case --> fall back !") - assert False + if torch.any(sum_probs < p): + print("edge case --> fall back !") + # assert False return apply_top_k_top_p(logits, k=None, p=p) probs_sum = torch.cumsum(sorted_probs, dim=-1) @@ -1443,387 +1538,23 @@ def apply_top_p_filtered ( logits.fill_(-float("inf")) logits.scatter_(dim=1, index=logits_sort_indices, src=logits_sort) return logits - - - - # @triton.jit -# def top_p_pivot_filter(LOGITS, L, MIN_IDX, PROBS, PROBS_2, idx_tensor, P, B, SIGMA:tl.constexpr, VOCAB_SIZE:tl.constexpr, BLOCK_SIZE:tl.constexpr): -# NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE -# pid = tl.program_id(0) -# num_programs = tl.num_programs(0) - -# for row_id in tl.range(pid, B, num_programs): -# p = tl.load(P + row_id) # fetches the p value of the row it is working on -# # if p != 1.0: # if p == 1, this becomes pointless ! 
-# if True: -# p_pivot = -float('inf') - -# LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE -# PROBS_ROW = PROBS + row_id * VOCAB_SIZE -# PROBS_2_ROW = PROBS_2 + row_id * VOCAB_SIZE -# IDX_ROW = idx_tensor + row_id * VOCAB_SIZE - -# search_address = PROBS_ROW -# search_range = VOCAB_SIZE -# search_iters = NUM_TILES - -# max_logit = -float('inf') -# min_logit = float('inf') -# min_logit_idx = -1 - -# force_remove_logit = -float('inf') # for handling duplicate cases (edge case) -# num_force_remove = tl.zeros((), dtype=tl.uint32) - -# # First Pass: Compute avg and std from a sample block -# offs = tl.arange(0, BLOCK_SIZE) -# mask_n = offs < VOCAB_SIZE -# logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) -# avg_logit = tl.sum(logits_blk) / VOCAB_SIZE -# sq_avg_logit = tl.sum(logits_blk * logits_blk) / VOCAB_SIZE -# std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - -# outlier_pivot = avg_logit + SIGMA * std_logit -# num_outliers = tl.zeros((), dtype=tl.uint32) -# sum_outlier_probs = 0.0 -# sum_exp_logits = 0.0 - -# # ====== Second Pass: compute max and min logits ====== -# for i in range(0, search_iters): -# offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) -# mask_n = offs_n < search_range -# logits_blk = tl.load(LOGITS_ROW + offs_n, -# mask=mask_n, -# other=avg_logit) -# max_logit = tl.maximum(max_logit, tl.max(logits_blk)) - -# for i in range(0, search_iters): -# offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) -# mask_n = offs_n < search_range -# logits_blk = tl.load(LOGITS_ROW + offs_n, -# mask=mask_n, -# other=float('inf')) -# local_min_logit = tl.min(logits_blk) -# local_min_logit_idx = tl.argmin(logits_blk, axis=0) -# if local_min_logit < min_logit: -# min_logit = local_min_logit -# min_logit_idx = local_min_logit_idx + i * BLOCK_SIZE - -# # ====== Third pass: Calculate exp logits and sum ====== -# for i in range(0, search_iters): -# offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) -# mask_n = offs_n < search_range - -# probs_blk = 
tl.load(LOGITS_ROW + offs_n, -# mask=mask_n, -# other=-float('inf')) -# probs_blk = probs_blk - max_logit -# probs_blk = tl.exp(probs_blk) -# sum_exp_logits += tl.sum(probs_blk) -# tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) -# outlier_prob = tl.exp(outlier_pivot - max_logit) / sum_exp_logits - -# # ====== Fourth pass: Calculate probs and get outliers ====== -# for i in range(0, search_iters): -# offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) -# mask_n = offs_n < search_range - -# probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) -# probs_blk = probs_blk / sum_exp_logits -# tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) - -# outlier_mask = (probs_blk > outlier_prob) & mask_n -# sum_outlier_probs += tl.sum(outlier_mask * probs_blk) -# num_blk_outliers = tl.sum(outlier_mask) -# cumulative_pos = tl.cast( -# tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) -# num_outliers += num_blk_outliers -# write_pos = tl.where(outlier_mask, cumulative_pos, -1) -# tl.store(PROBS_2_ROW + write_pos, probs_blk, mask=outlier_mask) # stores the final probs after masking to PROBS_2 - -# max_range = tl.exp(max_logit - max_logit) / sum_exp_logits -# min_range = tl.exp(min_logit - max_logit) / sum_exp_logits - -# if sum_outlier_probs > p: -# min_range = outlier_prob -# search_address = PROBS_2_ROW -# search_range = tl.cast(num_outliers, tl.int32) -# search_iters = tl.cast( -# (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32) - -# second_max_logit = -float('inf') - -# num_iters = 0 -# p_pivots_sum_0 = 0.0 -# min_larger_0 = 1.0 -# num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - -# # ====== Fifth Passes: Search for p_pivot(2log_2(n)) ====== -# while p_pivot == -float('inf') and num_iters < 32: -# p_pivot_0 = (max_range - min_range) * 1.0 / 2.0 + min_range -# p_pivots_sum_0 = 0.0 - -# min_larger_0 = 1.0 -# num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - -# for i in range(0, search_iters): -# offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) -# 
mask_n = offs_n < search_range -# probs_blk = tl.load(search_address + offs_n, -# mask=mask_n, -# other=0.0) - -# masked_larger_0 = tl.where(probs_blk > p_pivot_0, -# probs_blk, 1.0) -# min_larger_0 = tl.minimum(min_larger_0, -# tl.min(masked_larger_0)) - -# p_pivots_sum_0 += tl.sum(probs_blk * -# (probs_blk > p_pivot_0)) - -# for i in range(0, search_iters): -# offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) -# mask_n = offs_n < search_range -# probs_blk = tl.load(search_address + offs_n, -# mask=mask_n, -# other=0.0) - -# num_min_larger_0 += tl.sum( -# tl.abs(probs_blk - min_larger_0) < 1e-7) - -# # Check if any of the pivots are equal to k -# if p_pivots_sum_0 >= p: -# if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: -# p_pivot = p_pivot_0 -# else: -# min_range = p_pivot_0 -# else: -# max_range = p_pivot_0 - -# num_iters += 1 -# if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-8: -# p_pivot = p_pivot_0 - -# # Force remove duplicates (p_pivot is made to include all -# # duplicates if it falls on the duplicates) -# num_force_remove = tl.cast((p_pivots_sum_0 - p) / min_larger_0, -# tl.uint32) -# force_remove_logit = tl.log( -# min_larger_0 * sum_exp_logits) + max_logit - -# # p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit -# p_pivot_logit = -float('inf') - -# # -------- widen cutoff by 10% ---------------- -# if p_pivot != -float('inf'): -# p_pivot_logit = tl.log(p_pivot * sum_exp_logits) + max_logit -# kept_write_pos = 0 - -# for i in range(0, NUM_TILES): -# offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) -# mask_n = offs_n < VOCAB_SIZE -# # padding all idx with min_logit_idx to avoid -1 or 0 being the padding value -# tl.store(IDX_ROW + offs_n, min_logit_idx, mask=mask_n) - -# logits_blk = tl.load(LOGITS_ROW + offs_n, -# mask=mask_n, -# other=-float('inf')) -# probs_blk = tl.load(PROBS_ROW + offs_n, -# mask=mask_n, -# other=0.0) - -# keep_mask = logits_blk > p_pivot_logit -# out_vals = tl.where(keep_mask, logits_blk, -float('inf')) 
-# tl.store(LOGITS_ROW + offs_n, out_vals, mask=mask_n) - -# # ====== keeping track of L and idx_tensor ====== -# n_kept = tl.sum(keep_mask, dtype=tl.int32) -# if n_kept > 0: -# cpos = tl.cast(tl.cumsum(keep_mask) - 1 + kept_write_pos, tl.int32) -# write_idx = tl.where(keep_mask, cpos, 0) - -# tl.store(IDX_ROW + write_idx, offs_n, mask=keep_mask) -# tl.store(PROBS_2_ROW + write_idx, probs_blk, mask=keep_mask) - -# kept_write_pos += n_kept -# tl.store(L + row_id, tl.cast(kept_write_pos, tl.int32)) -# tl.store(MIN_IDX + row_id, tl.cast(min_logit_idx, tl.int32)) - -# @triton.jit -# def apply_mask( -# filtered_logits_sort, -# temp_logits, -# filtered_probs_sort, -# temp_probs, -# sort_indices, -# outliers, -# outliers_sorted, -# B, -# MAX_L, -# BLOCK_SIZE: tl.constexpr, -# ): -# pid = tl.program_id(0) -# num_programs = tl.num_programs(0) # total number of programs launched - -# for row_id in tl.range(pid, B, num_programs): # each row -# SORTED_LOGITS_ROW = filtered_logits_sort + row_id * MAX_L -# SORTED_PROBS_ROW = filtered_probs_sort + row_id * MAX_L -# SORTED_IDX_ROW = sort_indices + row_id * MAX_L -# OUTLIERS_ROW = outliers + row_id * MAX_L - -# OUTPUT_LOGITS_ROW = temp_logits + row_id * MAX_L -# OUTPUT_PROBS_ROW = temp_probs + row_id * MAX_L -# OUTPUT_OUTLIERS_SORTED_ROW = outliers_sorted + row_id * MAX_L - -# num_tiles = (MAX_L + BLOCK_SIZE - 1) // BLOCK_SIZE # for copying - -# for i in range(0, num_tiles): # partitioning the row -# offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) -# mask_n = offs_n < MAX_L - -# sort_idx_blk = tl.load(SORTED_IDX_ROW + offs_n, mask=mask_n, other=0) -# probs_blk = tl.load(SORTED_PROBS_ROW + sort_idx_blk, mask=mask_n, other=0.0) -# logits_blk = tl.load(SORTED_LOGITS_ROW + sort_idx_blk, mask=mask_n, other=-float('inf')) -# outliers_blk = tl.load(OUTLIERS_ROW + sort_idx_blk, mask=mask_n, other=0) - -# tl.store(OUTPUT_PROBS_ROW + offs_n, probs_blk, mask=mask_n) -# tl.store(OUTPUT_LOGITS_ROW + offs_n, logits_blk, mask=mask_n) -# 
tl.store(OUTPUT_OUTLIERS_SORTED_ROW + offs_n, outliers_blk, mask=mask_n) - -# def apply_top_p_filtered ( -# logits: torch.Tensor, -# p: torch.Tensor, -# debug: bool = False -# ) -> torch.Tensor: -# """ -# Applies top-p using pivot-based filtering -# """ -# logits_copy = logits.clone().detach() -# batch_size, vocab_size = logits.shape - -# probs = torch.full((batch_size, vocab_size), -float('inf'), device=logits.device) -# probs_2 = torch.zeros((batch_size, vocab_size), device=logits.device) - -# l = torch.zeros((batch_size,), device=logits.device, dtype=torch.int32) -# min_idx = torch.zeros_like(l, dtype=torch.int32) -# idx_tensor = torch.full_like(logits, -1, dtype=torch.int32) - -# BLOCK_SIZE = 2048 -# SIGMA = 2.15 -# NUM_WARPS = 16 -# NUM_STAGES = 3 - -# if not torch.any(p < 1.0): -# return logits - -# p_widened = torch.clamp(p*1.01, max=0.999) - -# grid = lambda meta: ((batch_size + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'], ) -# top_p_pivot_filter[grid]( -# logits, -# l, -# min_idx, -# probs, -# probs_2, -# idx_tensor, -# p_widened, -# batch_size, -# SIGMA=SIGMA, -# VOCAB_SIZE=vocab_size, -# BLOCK_SIZE=BLOCK_SIZE, -# ) - -# max_l = torch.max(l).item() - -# if max_l == 0: -# return torch.full((batch_size, vocab_size), -float('inf'), device=logits.device) - -# outliers = idx_tensor[:, :max_l] - -# # this is for gather, which puts in garbage value for those not included. I wanted to mask it out. -# valid_mask = torch.arange(max_l, device=logits.device).unsqueeze(0) < l.unsqueeze(1) - -# filtered_full_probs = probs_2[:, :max_l] -# filtered_logits = torch.gather(logits, 1, outliers) - -# filtered_probs_sort = torch.where(valid_mask, filtered_full_probs, 0.0) -# filtered_logits = torch.where(valid_mask, filtered_logits, torch.tensor(-float('inf'), device=logits.device)) -# sum_outliers = (filtered_full_probs * valid_mask.float()).sum(dim=1) -# sum_non_outliers = 1.0 - sum_outliers - -# # we must do sort in python, unfortunately..... 
-# filtered_logits_sort, sort_indices = torch.sort( -# filtered_logits, dim=-1, descending=False -# ) - -# # outliers_sorted = torch.gather(outliers, 1, sort_indices) -# # filtered_probs_sort = torch.gather(filtered_probs_sort, 1, sort_indices) - -# # =================== calling a new triton kernel =============================== -# temp_logits = torch.empty_like(filtered_logits_sort) -# temp_probs = torch.empty_like(filtered_probs_sort) -# outliers_sorted = torch.empty_like(outliers, dtype=torch.int32) - -# grid = (batch_size, ) -# apply_mask[grid]( -# filtered_logits_sort, -# temp_logits, -# filtered_probs_sort, -# temp_probs, -# sort_indices, -# outliers, -# outliers_sorted, -# batch_size, -# max_l, -# BLOCK_SIZE=BLOCK_SIZE, -# ) -# valid_mask_sorted = torch.gather(valid_mask, 1, sort_indices) - -# probs_sum = sum_non_outliers.unsqueeze(1) + torch.cumsum(temp_probs, dim=-1) -# top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) -# top_p_mask = top_p_mask | ~valid_mask_sorted -# top_p_mask[:, -1] = False - -# temp_logits.masked_fill_(top_p_mask, -float("inf")) - -# # outliers_sorted = torch.gather(outliers, 1, sort_indices) - -# logits.fill_(-float("inf")) -# logits.scatter_(dim=1, index=outliers_sorted, src=temp_logits) - -# return logits - -# # valid_mask_sorted = torch.gather(valid_mask, 1, sort_indices) - -# # probs_sum = sum_non_outliers.unsqueeze(1) + torch.cumsum(filtered_probs_sort, dim=-1) -# # top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) -# # top_p_mask = top_p_mask | ~valid_mask_sorted -# # top_p_mask[:, -1] = False -# # filtered_logits_sort.masked_fill_(top_p_mask, -float("inf")) - -# # outliers_sorted = torch.gather(outliers, 1, sort_indices) -# # logits.fill_(-float("inf")) -# # logits.scatter_(dim=1, index=outliers_sorted, src=filtered_logits_sort) - -# # return logits + def apply_top_k_top_p_test2( logits: torch.Tensor, k: torch.Tensor | None, p: torch.Tensor | None, - debug: bool = False + debug: bool = False, ) -> torch.Tensor: """ - Uses 
pivot-based algorithm to filter --> sort + Uses pivot-based algorithm to filter --> sort """ - if p.max() > 0.99: - return apply_top_k_top_p(logits, None, p) - if k is None and p is None: - return logits - elif p is None and k is not None: - return apply_top_k_only(logits, k) + if k is None and p is None: + return logits + elif p is None and k is not None: + return apply_top_k_top_p_triton(logits, k, p) elif k is None and p is not None: - return apply_top_p_filtered(logits, p) + return apply_top_k_top_p_filtered(logits, k, p) else: - logits_k = apply_top_k_only(logits, k) - return apply_top_p_filtered(logits_k, p, debug) - \ No newline at end of file + logits_k = apply_top_k_top_p_triton(logits, k, p=None) + return apply_top_k_top_p_filtered(logits, k, p) From a2f6ae61720bfd7fa61e70bd65e21817ca60a656 Mon Sep 17 00:00:00 2001 From: js_park Date: Fri, 14 Nov 2025 18:38:48 -0800 Subject: [PATCH 49/99] Errors on top p Signed-off-by: js_park --- compare.py | 242 +--- vllm/v1/sample/ops/topk_topp_sampler.py | 1467 ++++++----------------- 2 files changed, 371 insertions(+), 1338 deletions(-) diff --git a/compare.py b/compare.py index 2ab256a4b9ad..4fae145ee54c 100644 --- a/compare.py +++ b/compare.py @@ -17,7 +17,7 @@ from vllm.v1.sample.ops.topk_topp_sampler import ( apply_top_k_top_p, - apply_top_k_top_p_test2, + apply_top_k_top_p_triton, ) @@ -56,15 +56,18 @@ def test_accuracy(logits, k, p, func_list): original_logits_bin = original_logits.view(torch.int32) output_logits_bin = output_logits.view(torch.int32) is_correct = torch.all(original_logits_bin == output_logits_bin) + is_correct = is_correct and torch.allclose( + output_logits, original_logits, atol=1e-16 + ) output_correct_list.append(is_correct) func_name = func_list[i].__name__ if not is_correct: print_to_log( - r_str(f"Error: logits are not close on {i} - " + f"{func_name}"), + r_str("Error: logits are not close on " + f"{func_name}"), log_file, ) - output_logits = apply_top_k_top_p_test2(logits, k, p, 
debug=True) + output_logits = func_list[i](logits, k, p, debug=True) error_mask = torch.abs(output_logits - original_logits) > 1e-16 error_rows = torch.where(error_mask)[0] error_rows = torch.unique(error_rows) @@ -117,10 +120,11 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): batch_size_list = [64, 128, 1024] vocab_size_list = [4096, 16384, 65536, 128000, 262144] - # p_list = [None, "RAND", 0.4, 0.7, 0.9, 0.95, 0.99] - p_list = [None] - k_list = [None, "RAND", 5, 10, 50, 100, 200, 300, 3000] - func_list = [apply_top_k_top_p, apply_top_k_top_p_test2] + p_list = [None, "RAND", 0.4, 0.7, 0.9, 0.95, 0.99] + # p_list = [None] + # k_list = [None, "RAND", 5, 10, 50, 100, 200, 300, 3000] + k_list = [None] + func_list = [apply_top_k_top_p, apply_top_k_top_p_triton] log_file = f"triton_topk_topp_test_{date_str}.log" csv_file = f"triton_topk_topp_test_{date_str}.csv" @@ -213,227 +217,3 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): f"{time_list[0] / time_list[1]:.8f}\n" ) print_to_log(y_str("--------------------------------\n"), log_file) - -"""# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from datetime import datetime -from itertools import product -import regex as re -import torch - -print("Torch version:", torch.__version__) -print("CUDA available:", torch.cuda.is_available()) -print("Default device:", torch.cuda.current_device()) - -# --- MODIFIED IMPORTS --- -# We need all the component kernels to time them individually -from vllm.v1.sample.ops.topk_topp_sampler import ( - apply_top_k_top_p, - apply_top_k_with_pivot_filter, # Used for accuracy check - apply_top_k_only, # This is the baseline AND our Kernel 2 - top_k_pivot_and_sort, # Kernel 1 - scatter_topk_kernel # Kernel 3 -) -print("All kernels imported successfully") - -x = torch.randn(2, 5, device="cuda") -y = apply_top_k_only(x, k=torch.tensor([2,2], device="cuda")) -print("apply_top_k_only ran 
successfully, output:", y) - - -def g_str(s): return "\033[32m" + s + "\033[0m" -def r_str(s): return "\033[31m" + s + "\033[0m" -def y_str(s): return "\033[33m" + s + "\033[0m" -def b_str(s): return "\033[34m" + s + "\033[0m" - -def print_to_log(s, log_file): - print(s) - s = re.sub(r"\033[[0-9;]*m", "", s) - with open(log_file, "a") as f: - f.write(s + "\n") - -# --- UNCHANGED --- -# test_accuracy still runs the *full* pipeline to check for correctness -def test_accuracy(logits, k, log_file): - input_logits_torch = logits.clone().detach() - input_logits_triton = logits.clone().detach() - - original_logits = apply_top_k_only(input_logits_torch, k) - triton_pivot_logits = apply_top_k_with_pivot_filter(input_logits_triton, k) - - torch.cuda.synchronize() - is_correct = torch.allclose(original_logits, triton_pivot_logits) - - if not is_correct: - print_to_log(r_str("Error: logits are not close"), log_file) - - return is_correct - -# --- REWRITTEN test_time FUNCTION --- -def test_time(logits, k, num_runs=30, num_warmup=5): - - batch_size, vocab_size = logits.shape - - # --- Warmup --- - for _ in range(num_warmup): - warmup_tensor_torch = logits.clone().detach() - apply_top_k_only(warmup_tensor_torch, k) - - warmup_tensor_triton = logits.clone().detach() - apply_top_k_with_pivot_filter(warmup_tensor_triton, k) - torch.cuda.synchronize() - - # --- 1. Baseline `apply_top_k_only` timing --- - start_torch = torch.cuda.Event(enable_timing=True) - end_torch = torch.cuda.Event(enable_timing=True) - - start_torch.record() - for i in range(num_runs): - input_tensor = logits.clone().detach() - apply_top_k_only(input_tensor, k) - end_torch.record() - torch.cuda.synchronize() - apply_top_k_time = start_torch.elapsed_time(end_torch) / num_runs - - # --- 2. 
Triton Kernel 2 (Sort) Timing --- - - # Events for Kernel 2 - start_k2 = torch.cuda.Event(enable_timing=True) - end_k2 = torch.cuda.Event(enable_timing=True) - - # Kernel 2 time accumulator - triton_k2_time_acc = 0.0 - - for i in range(num_runs): - if (k == vocab_size).all(): - continue - - # 1. Setup - input_tensor = logits.clone().detach() - probs = torch.full_like(input_tensor, -float('inf')) - l = torch.empty((batch_size,), device=input_tensor.device, dtype=torch.int32) - idx_tensor = torch.full_like(input_tensor, -1, dtype=torch.int) - - BLOCK_SIZE = 1024 - SIGMA = 2.0 - grid_pivot = (batch_size,) - - # 2. Run Kernel 1 (Pivot) - *No timer* - top_k_pivot_and_sort[grid_pivot]( - input_tensor, probs, l, idx_tensor, k, batch_size, - SIGMA=SIGMA, VOCAB_SIZE=vocab_size, BLOCK_SIZE=BLOCK_SIZE, - ) - - torch.cuda.synchronize() - max_l = torch.max(l).item() - outliers = probs[:, :max_l] - outliers_idx = idx_tensor[:, :max_l] - k_pinned = torch.minimum(k, l) - - # 4. Time Kernel 2 (Sort) - start_k2.record() - apply_top_k_only(outliers, k_pinned) - end_k2.record() - - torch.cuda.synchronize() - triton_k2_time_acc += start_k2.elapsed_time(end_k2) - - triton_sort_only_time = triton_k2_time_acc / num_runs - - return apply_top_k_time, triton_sort_only_time - - -def main(): - print("Starting compare.py...") - date_str = datetime.now().strftime("%Y%m%d_%H%M%S") - - #batch_size_list = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] # Up to 512 - #vocab_size_list = [4096, 16384, 65536, 262144, 102400] - #k_list = [None, "RAND", 5, 10, 50, 100, 200, 300, 3000] - - batch_size_list = [1, 2, 4, 8] - vocab_size_list = [4096, 16384] - k_list = [None, "RAND", 5, 10, 50, 100, 200, 300, 3000] - - - log_file = f"triton_topk_topp_test_{date_str}.log" - csv_file = f"triton_topk_topp_test_{date_str}.csv" - - print_to_log(y_str("Testing TopKTopPSampler with Triton"), log_file) - print_to_log(y_str("batch_size_list:") + f"{batch_size_list}", log_file) - print_to_log(y_str("vocab_size_list:") + 
f"{vocab_size_list}", log_file) - print_to_log(y_str("k_list:") + f"{k_list}", log_file) - print_to_log(y_str("log_file:") + f"{log_file}", log_file) - print_to_log(y_str("csv_file:") + f"{csv_file}", log_file) - - # --- MODIFIED CSV HEADER --- - with open(csv_file, "w") as f: - f.write("dist_generator,batch_size,vocab_size,k,is_correct," - "apply_top_k_time,triton_sort_only_time,speedup_vs_baseline\n") - - for batch_size, vocab_size, k in product(batch_size_list, - vocab_size_list, - k_list): - - logits_randn = torch.randn(batch_size, vocab_size, device="cuda") * 10 - logits_list = [("RANDN", logits_randn)] - - if k == "RAND": - k_tensor = torch.randint(1, - vocab_size, (batch_size,), - device="cuda") - elif k is not None: - k_val = min(k, vocab_size) # Ensure k is not > vocab_size - k_tensor = torch.full((batch_size,), k_val, device="cuda") - else: - k_tensor = torch.full((batch_size,), vocab_size, device="cuda") - - for dist_generator, logits in logits_list: - print_to_log(y_str("--------------------------------"), log_file) - print_to_log( - g_str("Testing ") + f"{dist_generator}" + - y_str(" with batch_size: ") + f"{batch_size}" + - y_str(" vocab_size: ") + f"{vocab_size}" + - y_str(" k: ") + f"{k}", log_file) - - is_correct = test_accuracy(logits, k_tensor, log_file) - if not is_correct: - print_to_log( - r_str(f"Error: logits are not close for batch_size: {batch_size}, " - f"vocab_size: {vocab_size}, dist_generator: {dist_generator}, k: {k}"), - log_file) - - # --- MODIFIED TIMING CALL --- - apply_top_k_time, triton_sort_only_time = test_time(logits, k_tensor) - - print_to_log( - b_str("apply_top_k_time (Baseline): ") + f"{apply_top_k_time}", log_file) - print_to_log( - b_str("triton_sort_only_time (Kernel 2): ") + f"{triton_sort_only_time}", - log_file) - - # --- THIS IS THE FIX --- - # Handle the k: None case where triton_sort_only_time is 0.0 - if triton_sort_only_time > 0: - speedup = apply_top_k_time / triton_sort_only_time - speedup_str = 
f"{speedup:.8f}x" - else: - # 'k: None' case, speedup is not applicable (N/A) - speedup = 0.0 - speedup_str = "N/A (passthrough)" - # --- END FIX --- - - print_to_log( - g_str("Triton Sort Speedup vs. Full Baseline: ") + - speedup_str, log_file) - - # Write to CSV - with open(csv_file, "a") as f: - f.write(f"{dist_generator},{batch_size},{vocab_size},{k}," - f"{is_correct},{apply_top_k_time},{triton_sort_only_time}," - f"{speedup:.8f}\n") # Still write the float for CSV - print_to_log(y_str("--------------------------------\n"), log_file) - -if __name__ == "__main__": - main()""" diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index a48186982d7a..14e222ed0a7a 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -197,49 +197,6 @@ def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor: return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return -def apply_top_k_top_p_with_index( - logits: torch.Tensor, - k: torch.Tensor | None, - p: torch.Tensor | None, -) -> torch.Tensor: - """Apply top-k and top-p masks to the logits. - - If a top-p is used, this function will sort the logits tensor, - which can be slow for large batches. - - The logits tensor may be updated in-place. - """ - if p is None: - if k is None: - return logits - - # Avoid sorting vocab for top-k only case. - return apply_top_k_only(logits, k) - - logits_sort, logits_idx = logits.sort(dim=-1, descending=False) - - if k is not None: - # Apply top-k. - top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B - # Get all the top_k values. - top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) - top_k_mask = logits_sort < top_k_mask - logits_sort.masked_fill_(top_k_mask, -float("inf")) - - if p is not None: - # Apply top-p. 
- probs_sort = logits_sort.softmax(dim=-1) - probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort) - top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) - # at least one - top_p_mask[:, -1] = False - logits_sort.masked_fill_(top_p_mask, -float("inf")) - - # Re-sort the probabilities. - logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) - return logits, logits_idx - - def apply_top_k_top_p( logits: torch.Tensor, k: torch.Tensor | None, @@ -283,677 +240,6 @@ def apply_top_k_top_p( return logits -def apply_top_k_top_p_triton( - logits: torch.Tensor, - k: torch.Tensor | None, - p: torch.Tensor | None, -) -> torch.Tensor: - batch_size, vocab_size = logits.shape - - device_prop = torch.cuda.get_device_properties(logits.device) - # VOCAB_SIZEUM_PROGRAMS = device_prop.multi_processor_count - VOCAB_SIZEUM_PROGRAMS = batch_size - BLOCK_SIZE = 8192 - SIGMA = 2 # Top 0.03 outliers - Maybe dynamically adjust based on K? - VOCAB_SIZEUM_WARPS = 16 - VOCAB_SIZEUM_STAGES = 3 - probs = torch.full( - (VOCAB_SIZEUM_PROGRAMS, vocab_size), -float("inf"), device=logits.device - ) - debug_tensor = torch.full( - (VOCAB_SIZEUM_PROGRAMS, BLOCK_SIZE), -12.0, device=logits.device - ) - if k is not None and p is None: - _topk_kernel[(VOCAB_SIZEUM_PROGRAMS,)]( - logits, - probs, - debug_tensor, - k, - batch_size, - SIGMA, - vocab_size, - BLOCK_SIZE, - num_warps=VOCAB_SIZEUM_WARPS, - num_stages=VOCAB_SIZEUM_STAGES, - ) - - # print(f"debug_tensor: {debug_tensor[:, :13]}") - elif k is None and p is not None: - probs_2 = torch.full_like(probs, -float("inf"), device=logits.device) - _topp_kernel[(VOCAB_SIZEUM_PROGRAMS,)]( - logits, - probs, - probs_2, - p, - batch_size, - SIGMA, - vocab_size, - BLOCK_SIZE, - num_warps=VOCAB_SIZEUM_WARPS, - num_stages=VOCAB_SIZEUM_STAGES, - ) - elif k is not None and p is not None: - _topk_topp_kernel[(VOCAB_SIZEUM_PROGRAMS,)]( - logits, - probs, - k, - p, - batch_size, - SIGMA, - vocab_size, - BLOCK_SIZE, - num_warps=VOCAB_SIZEUM_WARPS, - 
num_stages=VOCAB_SIZEUM_STAGES, - ) - return logits - - -@triton.jit -def _topk_kernel( - LOGITS, - PROBS, - DEBUG_TENSOR, - K, - B, - SIGMA: tl.constexpr, - VOCAB_SIZE: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) - num_programs = tl.num_programs(0) - NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE - for row_id in tl.range(pid, B, num_programs): - k = tl.load(K + row_id) - if k != VOCAB_SIZE: # All tokens are valid - # THERE IS VOCAB_SIZEO DUPLICATE LOGIT MAVOCAB_SIZEAGEMEVOCAB_SIZET FOR THIS TOP-K - # CURREVOCAB_SIZET IMPLEMEVOCAB_SIZETATIOVOCAB_SIZE IVOCAB_SIZECLUDES ALL DUPLICATE LOGITS, - # WHICH MAY RETURVOCAB_SIZE MORE THAVOCAB_SIZE K LOGITS, - # FOLLOWIVOCAB_SIZEG THE IMPLEMEVOCAB_SIZETATIOVOCAB_SIZE in apply_top_k_only(). - # IF YOU VOCAB_SIZEEED EXACTLY K LOGITS, PLEASE REFER TO THE TOP-P - # IMPLEMEVOCAB_SIZETATIOVOCAB_SIZE AVOCAB_SIZED IMPLEMEVOCAB_SIZET THE DUPLICATE LOGIT MAVOCAB_SIZEAGEMEVOCAB_SIZET - # USIVOCAB_SIZEG THE FORCE_REMOVE_LOGIT VARIABLE - - LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE - PROBS_ROW = PROBS + row_id * VOCAB_SIZE - DEBUG_TENSOR_ROW = DEBUG_TENSOR + row_id * BLOCK_SIZE - search_addr = LOGITS_ROW - search_range = VOCAB_SIZE - search_iters = NUM_TILES - - k_pivot = -float("inf") - max_logit = -float("inf") - min_logit = float("inf") - - # Zeroth pass: Compute avg and std from a sample block - # May produce incorrect results if VOCAB_SIZE < BLOCK_SIZE - offs = tl.arange(0, BLOCK_SIZE) - mask_n = offs < VOCAB_SIZE - num_valid = tl.sum(mask_n) - logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) - avg_logit = tl.sum(logits_blk) / num_valid - sq_avg_logit = tl.sum(logits_blk * logits_blk) / num_valid - std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - - outlier_pivot = avg_logit + SIGMA * std_logit - num_outliers = tl.zeros((), dtype=tl.uint32) - - sum_logit = tl.sum(logits_blk) - min_logit_value = tl.min(logits_blk) - max_logit_value = tl.max(logits_blk) - - # 
First pass: compute max and min logits and gather outliers - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=avg_logit) - - max_logit = tl.maximum(max_logit, tl.max(logits_blk)) - min_logit = tl.minimum(min_logit, tl.min(logits_blk)) - outlier_mask = (logits_blk > outlier_pivot) & mask_n - num_blk_outliers = tl.sum(outlier_mask) - cumulative_pos = tl.cast( - tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32 - ) - num_outliers += num_blk_outliers - write_pos = tl.where(outlier_mask, cumulative_pos, -1) - tl.store(PROBS_ROW + write_pos, logits_blk, mask=outlier_mask) - - tl.store(DEBUG_TENSOR_ROW + 6, num_outliers) - tl.store(DEBUG_TENSOR_ROW + 7, max_logit) - tl.store(DEBUG_TENSOR_ROW + 8, min_logit) - - max_range = max_logit - min_range = min_logit - if num_outliers > k: - max_range = max_logit - min_range = outlier_pivot - search_addr = PROBS_ROW - search_range = tl.cast(num_outliers, tl.int32) - search_iters = tl.cast( - (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32 - ) - - tl.store(DEBUG_TENSOR_ROW + 9, max_range) - tl.store(DEBUG_TENSOR_ROW + 10, min_range) - - # Second passes: Quaternary search for pivots (nlog_4(n)) - num_iters = 0 - while k_pivot == -float("inf") and num_iters < 32: - k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range - k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range - k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range - k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) - k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) - k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - logits_blk = tl.load( - search_addr + offs_n, mask=mask_n, other=-float("inf") - ) - - k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) - k_pivots_num_1 += tl.sum(logits_blk > 
k_pivot_1) - k_pivots_num_2 += tl.sum(logits_blk > k_pivot_2) - - # Check if any of the pivots are equal to k - if k_pivots_num_0 == k: - k_pivot = k_pivot_0 - elif k_pivots_num_1 == k: - k_pivot = k_pivot_1 - elif k_pivots_num_2 == k: - k_pivot = k_pivot_2 - # If none of the pivots are equal to k, we update the range - elif k_pivots_num_2 > k: - min_range = k_pivot_2 - elif k_pivots_num_1 > k: - min_range = k_pivot_1 - elif k_pivots_num_0 > k: - min_range = k_pivot_0 - if k_pivots_num_0 < k: - max_range = k_pivot_0 - elif k_pivots_num_1 < k: - max_range = k_pivot_1 - elif k_pivots_num_2 < k: - max_range = k_pivot_2 - - num_iters += 1 - if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-16: - k_pivot = k_pivot_0 - - tl.store(DEBUG_TENSOR_ROW + 11, num_iters) - tl.store(DEBUG_TENSOR_ROW + 12, k_pivot) - num_masked = tl.zeros((), dtype=tl.uint32) - - # Third pass: Apply top-k mask - if k_pivot != -float("inf"): - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n) - mask = logits_blk > k_pivot - logits_blk = tl.where(mask, logits_blk, -float("inf")) - tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) - num_masked += tl.sum(mask) - - tl.store(DEBUG_TENSOR_ROW + 13, num_masked) - offs_n = tl.arange(0, BLOCK_SIZE) - - -@triton.jit -def _topp_kernel( - LOGITS, - PROBS, - PROBS_2, - P, - B, - SIGMA: tl.constexpr, - VOCAB_SIZE: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE - pid = tl.program_id(0) - num_programs = tl.num_programs(0) - for row_id in tl.range(pid, B, num_programs): - p = tl.load(P + row_id) - if p != 1.0: # All tokens are valid - p_pivot = -float("inf") - - LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE - PROBS_ROW = PROBS + pid * VOCAB_SIZE - PROBS_2_ROW = PROBS_2 + pid * VOCAB_SIZE - - search_addr = PROBS_ROW - search_range = VOCAB_SIZE - search_iters = NUM_TILES 
- - max_logit = -float("inf") - min_logit = float("inf") - - # The Pytorch version removes the earlier duplicates - # if there are more than one duplicates - force_remove_logit = -float("inf") - num_force_remove = tl.zeros((), dtype=tl.uint32) - - # Zeroth pass: Compute avg and std from a sample block - # May produce incorrect results if VOCAB_SIZE < BLOCK_SIZE - # OR all logits are the same - offs = tl.arange(0, BLOCK_SIZE) - mask_n = offs < VOCAB_SIZE - logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) - avg_logit = tl.sum(logits_blk) / VOCAB_SIZE - sq_avg_logit = tl.sum(logits_blk * logits_blk) / VOCAB_SIZE - std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - - outlier_pivot = avg_logit + SIGMA * std_logit - num_outliers = tl.zeros((), dtype=tl.uint32) - sum_outlier_probs = 0.0 - - sum_exp_logits = 0.0 - - # First pass: compute max and min logits - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=avg_logit) - max_logit = tl.maximum(max_logit, tl.max(logits_blk)) - min_logit = tl.minimum(min_logit, tl.min(logits_blk)) - - # Second pass: Calculate exp logits and sum - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - - probs_blk = tl.load( - LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") - ) - probs_blk = probs_blk - max_logit - probs_blk = tl.exp(probs_blk) - sum_exp_logits += tl.sum(probs_blk) - tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) - - outlier_prob = tl.exp(outlier_pivot - max_logit) / sum_exp_logits - - # Third pass: Calculate probs and get outliers - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - - probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) - probs_blk = probs_blk / sum_exp_logits - tl.store(PROBS_ROW + offs_n, probs_blk, 
mask=mask_n) - - outlier_mask = (probs_blk > outlier_prob) & mask_n - sum_outlier_probs += tl.sum(outlier_mask * probs_blk) - num_blk_outliers = tl.sum(outlier_mask) - cumulative_pos = tl.cast( - tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32 - ) - num_outliers += num_blk_outliers - write_pos = tl.where(outlier_mask, cumulative_pos, -1) - tl.store(PROBS_2_ROW + write_pos, probs_blk, mask=outlier_mask) - - max_range = tl.exp(max_logit - max_logit) / sum_exp_logits - min_range = tl.exp(min_logit - max_logit) / sum_exp_logits - - if sum_outlier_probs > p: - min_range = outlier_prob - search_addr = PROBS_2_ROW - search_range = tl.cast(num_outliers, tl.int32) - search_iters = tl.cast( - (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32 - ) - - second_max_logit = -float("inf") - - num_iters = 0 - p_pivots_sum_0 = 0.0 - min_larger_0 = 1.0 - num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - - # Fifth passes: Search for p_pivot (2log_2(n)) - while p_pivot == -float("inf") and num_iters < 32: - p_pivot_0 = (max_range - min_range) * 1.0 / 2.0 + min_range - p_pivots_sum_0 = 0.0 - - min_larger_0 = 1.0 - num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - probs_blk = tl.load(search_addr + offs_n, mask=mask_n, other=0.0) - - masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) - min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) - - p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) - - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - probs_blk = tl.load(search_addr + offs_n, mask=mask_n, other=0.0) - - num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-7) - - # Check if any of the pivots are equal to k - if p_pivots_sum_0 >= p: - if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: - p_pivot = p_pivot_0 - else: - 
min_range = p_pivot_0 - else: - max_range = p_pivot_0 - - num_iters += 1 - if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-8: - p_pivot = p_pivot_0 - - # At least one value should be greater than p_pivot - if p_pivot >= max_logit: - p_pivot = second_max_logit - elif num_min_larger_0 > 1: - # Force remove duplicates (p_pivot is made to include all - # duplicates if it falls on the duplicates) - num_force_remove = tl.cast( - (p_pivots_sum_0 - p) / min_larger_0, tl.uint32 - ) - force_remove_logit = tl.log(min_larger_0 * sum_exp_logits) + max_logit - - p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit - - # Sixth pass: Apply mask - current_num_force_remove = tl.zeros((), dtype=tl.uint32) - if p_pivot != -float("inf"): - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load( - LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") - ) - - if force_remove_logit != -float("inf"): - # Force remove duplicates - tolerance = 1e-5 * tl.maximum(1.0, tl.abs(force_remove_logit)) - force_remove_mask = ( - tl.abs(logits_blk - force_remove_logit) < tolerance - ) - force_remove_count = ( - tl.cumsum(force_remove_mask) + current_num_force_remove - ) - force_remove_count_mask = force_remove_count <= num_force_remove - force_remove_mask = force_remove_count_mask & force_remove_mask - logits_blk = tl.where( - force_remove_mask, -float("inf"), logits_blk - ) - current_num_force_remove = tl.max(force_remove_count) - - logits_blk = tl.where( - logits_blk > p_pivot, logits_blk, -float("inf") - ) - tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) - - -@triton.jit -def _topk_topp_kernel( - LOGITS, - PROBS, - K, - P, - B, - SIGMA: tl.constexpr, - VOCAB_SIZE: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE - pid = tl.program_id(0) - num_programs = tl.num_programs(0) - for row_id in tl.range(pid, B, num_programs): - k_pivot = 
-float("inf") - p_pivot = -float("inf") - - LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE - PROBS_ROW = PROBS + pid * VOCAB_SIZE - - search_addr = LOGITS_ROW - search_range = VOCAB_SIZE - search_iters = NUM_TILES - - max_logit = -float("inf") - min_logit = float("inf") - avg_logit = -float("inf") - - # The Pytorch version removes the earlier duplicates - # if there are more than one duplicates - force_remove_logit = -float("inf") - num_force_remove = tl.zeros((), dtype=tl.uint32) - - # Zeroth pass: Compute avg and std from a sample block - # May produce incorrect results if VOCAB_SIZE < BLOCK_SIZE - offs = tl.arange(0, BLOCK_SIZE) - mask_n = offs < VOCAB_SIZE - logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) - avg_logit = tl.sum(logits_blk) / VOCAB_SIZE - sq_avg_logit = tl.sum(logits_blk * logits_blk) / VOCAB_SIZE - std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - - outlier_pivot = avg_logit + SIGMA * std_logit - num_outliers = tl.zeros((), dtype=tl.uint32) - # First pass: compute max and min logits and gather outliers - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=avg_logit) - - max_logit = tl.maximum(max_logit, tl.max(logits_blk)) - min_logit = tl.minimum(min_logit, tl.min(logits_blk)) - outlier_mask = (logits_blk > outlier_pivot) & mask_n - num_blk_outliers = tl.sum(outlier_mask) - cumulative_pos = tl.cast( - tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32 - ) - num_outliers += num_blk_outliers - write_pos = tl.where(outlier_mask, cumulative_pos, -1) - tl.store(PROBS_ROW + write_pos, logits_blk, mask=outlier_mask) - - ############### START OF TOP-K CODE ############### - k = tl.load(K + row_id) - max_range = max_logit - min_range = min_logit - if num_outliers > k: - max_range = max_logit - min_range = outlier_pivot - search_addr = PROBS_ROW - search_range = tl.cast(num_outliers, tl.int32) - 
search_iters = tl.cast( - (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32 - ) - - if k != VOCAB_SIZE: # All tokens are valid - # THERE IS VOCAB_SIZEO DUPLICATE LOGIT MAVOCAB_SIZEAGEMEVOCAB_SIZET FOR THIS TOP-K - # CURREVOCAB_SIZET IMPLEMEVOCAB_SIZETATIOVOCAB_SIZE IVOCAB_SIZECLUDES ALL DUPLICATE LOGITS, - # WHICH MAY RETURVOCAB_SIZE MORE THAVOCAB_SIZE K LOGITS, - # FOLLOWIVOCAB_SIZEG THE IMPLEMEVOCAB_SIZETATIOVOCAB_SIZE in apply_top_k_only(). - # IF YOU VOCAB_SIZEEED EXACTLY K LOGITS, PLEASE REFER TO THE TOP-P - # IMPLEMEVOCAB_SIZETATIOVOCAB_SIZE AVOCAB_SIZED IMPLEMEVOCAB_SIZET THE DUPLICATE LOGIT MAVOCAB_SIZEAGEMEVOCAB_SIZET - # USIVOCAB_SIZEG THE FORCE_REMOVE_LOGIT VARIABLE. - - # Second passes: Quaternary search for pivots (nlog_4(n)) - num_iters = 0 - while k_pivot == -float("inf") and num_iters < 18: - k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range - k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range - k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range - k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) - k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) - k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - logits_blk = tl.load( - search_addr + offs_n, mask=mask_n, other=-float("inf") - ) - - k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) - k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) - k_pivots_num_2 += tl.sum(logits_blk > k_pivot_2) - - # Check if any of the pivots are equal to k - if k_pivots_num_0 == k: - k_pivot = k_pivot_0 - elif k_pivots_num_1 == k: - k_pivot = k_pivot_1 - elif k_pivots_num_2 == k: - k_pivot = k_pivot_2 - # If none of the pivots are equal to k, we updatae the range - elif k_pivots_num_2 > k: - min_range = k_pivot_2 - elif k_pivots_num_1 > k: - min_range = k_pivot_1 - elif k_pivots_num_0 > k: - min_range = k_pivot_0 - if k_pivots_num_0 < k: - max_range = k_pivot_0 - elif 
k_pivots_num_1 < k: - max_range = k_pivot_1 - elif k_pivots_num_2 < k: - max_range = k_pivot_2 - - num_iters += 1 - if num_iters >= 18 or tl.abs(min_range - max_range) < 1e-8: - k_pivot = k_pivot_0 - - ############### EVOCAB_SIZED OF TOP-K CODE ############### - - ############### START OF TOP-P CODE ############### - - p = tl.load(P + row_id) - if p != 1.0: # All tokens are valid - second_max_logit = -float("inf") - max_probs = 0.0 - min_probs = 1.0 - sum_exp_logits = 0.0 - - # Third pass: Compute exp logits and sum - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - probs_blk = tl.load( - search_addr + offs_n, mask=mask_n, other=-float("inf") - ) - probs_blk = tl.where(probs_blk > k_pivot, probs_blk, -float("inf")) - probs_blk = probs_blk - max_logit - probs_blk = tl.exp(probs_blk) - sum_exp_logits += tl.sum(probs_blk) - tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) - - second_max_mask = probs_blk * (probs_blk < max_probs) - second_max_logit = tl.maximum(second_max_logit, tl.max(second_max_mask)) - - # Fourth pass: Compute probs (softmax) - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n) - probs_blk = probs_blk / sum_exp_logits - min_blk = tl.where(mask_n, probs_blk, 1.0) - min_probs = tl.minimum(min_probs, tl.min(min_blk)) - max_blk = tl.where(mask_n, probs_blk, 0.0) - max_probs = tl.maximum(max_probs, tl.max(max_blk)) - tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) - - max_range = max_probs - min_range = min_probs - - num_iters = 0 - p_pivots_sum_0 = 0.0 - min_larger_0 = 1.0 - num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - - # Fifth passes: Search for p_pivot (2log_2(n)) - while p_pivot == -float("inf") and num_iters < 32: - p_pivot_0 = (max_range - min_range) * 1.0 / 2.0 + min_range - p_pivots_sum_0 = 0.0 - - min_larger_0 = 1.0 - num_min_larger_0 = 
tl.zeros((), dtype=tl.uint32) - - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) - - masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) - min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) - - p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) - - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) - - num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-7) - - # Check if any of the pivots are equal to k - if p_pivots_sum_0 >= p: - if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: - p_pivot = p_pivot_0 - else: - min_range = p_pivot_0 - else: - max_range = p_pivot_0 - - num_iters += 1 - if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-8: - p_pivot = p_pivot_0 - - # At least one value should be greater than p_pivot - if p_pivot >= max_logit: - p_pivot = second_max_logit - elif num_min_larger_0 > 1: - # Force remove duplicates (p_pivot is made to include all - # duplicates if it falls on the duplicates) - num_force_remove = tl.cast( - (p_pivots_sum_0 - p) / min_larger_0, tl.uint32 - ) - force_remove_logit = tl.log(min_larger_0 * sum_exp_logits) + max_logit - - p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit - - ############### EVOCAB_SIZED OF TOP-P CODE ############### - - # Sixth pass: Apply mask - pivot = tl.maximum(k_pivot, p_pivot) - current_num_force_remove = tl.zeros((), dtype=tl.uint32) - if pivot != -float("inf"): - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load( - LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") - ) - - if force_remove_logit != -float("inf"): - # Force remove duplicates - tolerance = 1e-5 * tl.maximum(1.0, 
tl.abs(force_remove_logit)) - force_remove_mask = ( - tl.abs(logits_blk - force_remove_logit) < tolerance - ) - force_remove_count = ( - tl.cumsum(force_remove_mask) + current_num_force_remove - ) - force_remove_count_mask = force_remove_count <= num_force_remove - force_remove_mask = force_remove_count_mask & force_remove_mask - logits_blk = tl.where(force_remove_mask, -float("inf"), logits_blk) - current_num_force_remove = tl.max(force_remove_count) - - logits_blk = tl.where(logits_blk > pivot, logits_blk, -float("inf")) - tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) - - def apply_top_k_only( logits: torch.Tensor, k: torch.Tensor, @@ -1055,481 +341,446 @@ def flashinfer_sample( return next_token_ids.view(-1) +# fmt: off +_PERCENTILE_TO_STD_TABLE = [ + 2.576, 2.326, 2.054, 1.881, 1.751, + 1.645, 1.555, 1.476, 1.405, 1.341, + 1.282, 1.227, 1.175, 1.126, 1.080, + 1.036, 0.994, 0.954, 0.915, 0.878, + 0.842, 0.806, 0.772, 0.739, 0.706, + 0.674, 0.643, 0.613, 0.583, 0.553, + 0.524, 0.496, 0.468, 0.440, 0.412, + 0.385, 0.358, 0.332, 0.305, 0.279, + 0.253, 0.228, 0.202, 0.176, 0.151, + 0.126, 0.100, 0.075, 0.050, 0.025, + 0.000, -0.025, -0.050, -0.075, -0.100, + -0.126, -0.151, -0.176, -0.202, -0.228, + -0.253, -0.279, -0.305, -0.332, -0.358, + -0.385, -0.412, -0.440, -0.468, -0.496, + -0.524, -0.553, -0.583, -0.613, -0.643, + -0.674, -0.706, -0.739, -0.772, -0.806, + -0.842, -0.878, -0.915, -0.954, -0.994, + -1.036, -1.080, -1.126, -1.175, -1.227, + -1.282, -1.341, -1.405, -1.476, -1.555, + -1.645, -1.751, -1.881, -2.054, -2.326 +] +# fmt: on + + @triton.jit -def _topp_kernel_sorted( +def _topk_triton_kernel( LOGITS, - PROBS, - PROBS_2, - P, - B, - SIGMA: tl.constexpr, + OUTPUT, + PERCENTILE_TO_STD_TABLE, + K, VOCAB_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): - """Modified top-p kernel with sort-equivalent tie-breaking - and re-enabled outlier optimization. 
- """ + row_id = tl.program_id(0) NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE - pid = tl.program_id(0) - num_programs = tl.num_programs(0) - - for row_id in tl.range(pid, B, num_programs): - p = tl.load(P + row_id) - if p != 1.0: # All tokens are valid - p_pivot = -float("inf") - - LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE - PROBS_ROW = PROBS + pid * VOCAB_SIZE - PROBS_2_ROW = PROBS_2 + pid * VOCAB_SIZE # <-- RE-ADDED + k = tl.load(K + row_id) - # Default search params - search_addr = PROBS_ROW - search_range = VOCAB_SIZE - search_iters = NUM_TILES + if k != VOCAB_SIZE: + # THERE IS NO DUPLICATE LOGIT MANAGEMENT FOR THIS TOP-K KERNEL + # CURRENT IMPLEMENTATION INCLUDES ALL DUPLICATE LOGITS, + # WHICH MAY RETURN MORE THAN K LOGITS. + # THIS FOLLOWS THE IMPLEMENTATION IN apply_top_k_only(). - max_logit = -float("inf") - min_logit = float("inf") + LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE + OUTPUT_ROW = OUTPUT + row_id * VOCAB_SIZE + search_addr = LOGITS_ROW + search_range = VOCAB_SIZE + search_iters = NUM_TILES - force_remove_logit = -float("inf") - num_force_remove = tl.zeros((), dtype=tl.uint32) + k_pivot = -float("inf") + max_logit = -float("inf") + min_logit = float("inf") - # --- ZEROTH PASS (RE-ADDED) --- - # Compute *exact* avg and std - sum_logits = 0.0 - sum_sq_logits = 0.0 - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=0.0) - sum_logits += tl.sum(tl.where(mask_n, logits_blk, 0.0)) - sum_sq_logits += tl.sum(tl.where(mask_n, logits_blk * logits_blk, 0.0)) + # Zeroth pass: Compute avg and std from a sample block + # May produce incorrect results if VOCAB_SIZE < BLOCK_SIZE + offs = tl.arange(0, BLOCK_SIZE) + mask_n = offs < VOCAB_SIZE + num_valid = tl.sum(mask_n) + logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) + avg_logit = tl.sum(logits_blk) / num_valid + sq_avg_logit = 
tl.sum(logits_blk * logits_blk) / num_valid + std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - avg_logit = sum_logits / VOCAB_SIZE - sq_avg_logit = sum_sq_logits / VOCAB_SIZE - std_logit = tl.sqrt(tl.maximum(0.0, sq_avg_logit - avg_logit * avg_logit)) - outlier_pivot = avg_logit + SIGMA * std_logit # <-- RE-ADDED - num_outliers = tl.zeros((), dtype=tl.uint32) # <-- RE-ADDED - sum_outlier_probs = 0.0 # <-- RE-ADDED + percentile = tl.cast(k * 1.2 / VOCAB_SIZE * 100, tl.uint32) + 1 + percentile = tl.minimum(percentile, 99) + sigma = tl.load(PERCENTILE_TO_STD_TABLE + percentile) + outlier_pivot = avg_logit + sigma * std_logit + num_outliers = tl.zeros((), dtype=tl.uint32) - sum_exp_logits = 0.0 + sum_logit = tl.sum(logits_blk) + min_logit_value = tl.min(logits_blk) + max_logit_value = tl.max(logits_blk) - # First pass: compute max and min logits - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - logits_blk = tl.load( - LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") - ) # Use -inf - max_logit = tl.maximum(max_logit, tl.max(logits_blk)) - min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + # First pass: compute max and min logits and gather outliers + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=avg_logit) - # Second pass: Calculate exp logits and sum - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range + max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + outlier_mask = (logits_blk > outlier_pivot) & mask_n + num_blk_outliers = tl.sum(outlier_mask) + cumulative_pos = tl.cast( + tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32 + ) + num_outliers += num_blk_outliers + write_pos = tl.where(outlier_mask, 
cumulative_pos, -1) + tl.store(OUTPUT_ROW + write_pos, logits_blk, mask=outlier_mask) - probs_blk = tl.load( - LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") - ) - probs_blk = probs_blk - max_logit - probs_blk = tl.exp(probs_blk) - sum_exp_logits += tl.sum(probs_blk) - tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) + max_range = max_logit + min_range = min_logit + if num_outliers > k: + max_range = max_logit + min_range = outlier_pivot + search_addr = OUTPUT_ROW + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32 + ) - # --- OUTLIER_PROB (RE-ADDED) --- - outlier_prob = tl.exp(outlier_pivot - max_logit) / sum_exp_logits + # Second passes: Quaternary search for pivots (nlog_4(n)) + num_iters = 0 + while k_pivot == -float("inf") and num_iters < 32: + k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range + k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) + k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) + k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - # Third pass: Calculate final probs AVOCAB_SIZED get outliers for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - - probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) - probs_blk = probs_blk / sum_exp_logits - tl.store(PROBS_ROW + offs_n, probs_blk, mask=mask_n) - - # --- OUTLIER MASKIVOCAB_SIZEG LOGIC (RE-ADDED) --- - outlier_mask = (probs_blk > outlier_prob) & mask_n - sum_outlier_probs += tl.sum(outlier_mask * probs_blk) - num_blk_outliers = tl.sum(outlier_mask) - cumulative_pos = tl.cast( - tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32 - ) - num_outliers += num_blk_outliers - write_pos = tl.where(outlier_mask, cumulative_pos, -1) - tl.store(PROBS_2_ROW + write_pos, probs_blk, mask=outlier_mask) - - max_range = 
tl.exp(max_logit - max_logit) / sum_exp_logits - min_range = tl.exp(min_logit - max_logit) / sum_exp_logits - - if sum_outlier_probs > p: - min_range = outlier_prob - search_addr = PROBS_2_ROW - search_range = tl.cast(num_outliers, tl.int32) - search_iters = tl.cast( - (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32 + logits_blk = tl.load( + search_addr + offs_n, mask=mask_n, other=-float("inf") ) - second_max_logit = -float("inf") - num_iters = 0 - p_pivots_sum_0 = 0.0 # --> total prob including all equivalent min - min_larger_0 = 1.0 # --> prob of tie-breaking min - num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - - # Binary search for p_pivot - while p_pivot == -float("inf") and num_iters < 32: - p_pivot_0 = (max_range - min_range) * 1.0 / 2.0 + min_range - p_pivots_sum_0 = 0.0 - - min_larger_0 = 1.0 - num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - probs_blk = tl.load(search_addr + offs_n, mask=mask_n, other=0.0) - - masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) - min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) - - p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) - - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - probs_blk = tl.load(search_addr + offs_n, mask=mask_n, other=0.0) - - num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-7) - - if p_pivots_sum_0 >= p: - if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: - p_pivot = p_pivot_0 - else: - min_range = p_pivot_0 - else: - max_range = p_pivot_0 - - num_iters += 1 - if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-8: - p_pivot = p_pivot_0 - - if p_pivot >= max_logit: - p_pivot = second_max_logit - elif num_min_larger_0 > 1: - num_force_remove = tl.cast( - (p_pivots_sum_0 - p) / min_larger_0, tl.uint32 - ) # --> number of 
probs to be removed - force_remove_logit = tl.log(min_larger_0 * sum_exp_logits) + max_logit - - p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit - - # Apply mask with (non-sort-equivalent) tie-breaking - current_num_removed = tl.zeros((), dtype=tl.uint32) - if p_pivot != -float("inf"): - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load( - LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") - ) - - if force_remove_logit != -float("inf"): - # Match PyTorch's non-sort-equivalent tie-breaking - tolerance = 1e-5 * tl.maximum(1.0, tl.abs(force_remove_logit)) - is_tie = tl.abs(logits_blk - force_remove_logit) < tolerance - tie_position = tl.cumsum(is_tie) - 1 + current_num_removed - should_remove = is_tie & (tie_position < num_force_remove) - logits_blk = tl.where(should_remove, -float("inf"), logits_blk) - current_num_removed += tl.sum(is_tie) - - # Standard threshold masking - tolerance = 1e-6 * tl.maximum(1.0, tl.abs(p_pivot)) - logits_blk = tl.where( - logits_blk >= (p_pivot - tolerance), logits_blk, -float("inf") - ) - - tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) + k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) + k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) + k_pivots_num_2 += tl.sum(logits_blk > k_pivot_2) + + # Check if any of the pivots are equal to k + if k_pivots_num_0 == k: + k_pivot = k_pivot_0 + elif k_pivots_num_1 == k: + k_pivot = k_pivot_1 + elif k_pivots_num_2 == k: + k_pivot = k_pivot_2 + # If none of the pivots are equal to k, we update the range + elif k_pivots_num_2 > k: + min_range = k_pivot_2 + elif k_pivots_num_1 > k: + min_range = k_pivot_1 + elif k_pivots_num_0 > k: + min_range = k_pivot_0 + if k_pivots_num_0 < k: + max_range = k_pivot_0 + elif k_pivots_num_1 < k: + max_range = k_pivot_1 + elif k_pivots_num_2 < k: + max_range = k_pivot_2 + + num_iters += 1 + if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-16: + k_pivot = 
k_pivot_0 + + # Third pass: Apply top-k mask + if k_pivot != -float("inf"): + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n) + mask = logits_blk > k_pivot + logits_blk = tl.where(mask, logits_blk, -float("inf")) + tl.store(OUTPUT_ROW + offs_n, logits_blk, mask=mask_n) -def apply_top_p_sorted_equivalent( +def apply_top_k_only_triton( logits: torch.Tensor, - p: torch.Tensor, - sigma: float = 3.0, + k: torch.Tensor, ) -> torch.Tensor: - """Apply top-p using binary search (no sort!) with sort-equivalent results. - - Args: - logits: [B, VOCAB_SIZE] logits tensor - p: [B] top-p thresholds - sigma: Standard deviation multiplier for outlier detection - Returns: - Modified logits, equivalent to sorted top-p version """ - B, VOCAB_SIZE = logits.shape - device = logits.device + Apply top-k mask to the logits using Triton. - BLOCK_SIZE = triton.next_power_of_2(min(VOCAB_SIZE, 1024)) - num_warps = 4 if BLOCK_SIZE < 2048 else 8 + The logits tensor will be updated out-of-place. 
+ """ - probs = torch.empty((B, VOCAB_SIZE), device=device, dtype=torch.float32) - probs_2 = torch.empty((B, VOCAB_SIZE), device=device, dtype=torch.float32) + batch_size, vocab_size = logits.shape + NUM_PROGRAMS = batch_size # Non-persistent kernel + BLOCK_SIZE = 8192 + NUM_WARPS = 16 + NUM_STAGES = 3 + output = torch.full(logits.shape, -float("inf"), device=logits.device) + PERCENTILE_TO_STD_TABLE = torch.tensor( + _PERCENTILE_TO_STD_TABLE, device=logits.device + ) - grid = (B,) - _topp_kernel_sorted[grid]( + _topk_triton_kernel[(NUM_PROGRAMS,)]( logits, - probs, - probs_2, - p, - B, - SIGMA=sigma, - VOCAB_SIZE=VOCAB_SIZE, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps, + output, + PERCENTILE_TO_STD_TABLE, + k, + vocab_size, + BLOCK_SIZE, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, ) - return logits - + return output -def apply_top_k_top_p_test( - logits: torch.Tensor, - k: torch.Tensor | None, - p: torch.Tensor | None, -) -> torch.Tensor: - """Optimized implementation combining torch.topk and binary search kernel.""" - if p is None: - if k is None: - return logits - return apply_top_k_only(logits, k) - # Apply top-k filter first if needed - if k is not None: - logits = apply_top_k_only(logits, k) - # Apply top-p using binary search (no sort!) 
- return apply_top_p_sorted_equivalent(logits, p) - - -# -------------------------------------------------------------------------- @triton.jit def top_p_pivot_filter( - LOGITS, - PROBS, - PROBS_IDX, + DEBUG_TENSOR, + LOGITS_MAYBE_K_MASKED, K_FILTER: tl.int32, - OUTPUT_LOGITS, - OUTPUT_INDICES, - B, # --> batch size - SIGMA: tl.constexpr, + BUFFER, + BATCH_SIZE, + SUM_FILTERED_PROBS, + FILTERED_LOGITS, + FILTERED_INDICES, + FILTERED_PROBS, + PERCENTILE_TO_STD_TABLE, VOCAB_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE - pid = tl.program_id(0) num_programs = tl.num_programs(0) + k = K_FILTER - for row_id in tl.range(pid, B, num_programs): - k = K_FILTER - if k <= VOCAB_SIZE: - k_pivot = -float("inf") - - LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE - PROBS_ROW = PROBS + row_id * VOCAB_SIZE - PROBS_IDX_ROW = PROBS_IDX + row_id * VOCAB_SIZE - OUTPUT_LOGITS_ROW = OUTPUT_LOGITS + row_id * K_FILTER - OUTPUT_INDICES_ROW = OUTPUT_INDICES + row_id * K_FILTER - - search_addr = LOGITS_ROW - search_range = VOCAB_SIZE - search_iters = NUM_TILES - - max_logit = -float("inf") - min_logit = float("inf") - - # Zeroth pass: Compute avg and std from a sample block - offs = tl.arange(0, BLOCK_SIZE) - mask_n = offs < VOCAB_SIZE - num_mask = tl.sum(mask_n) - logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) - probs_blk = tl.load(PROBS_ROW + offs, mask=mask_n, other=0.0) - valid_count = tl.sum(mask_n, dtype=tl.float32) - avg_logit = tl.sum(logits_blk) / num_mask # re-check - sq_avg_logit = tl.sum(logits_blk * logits_blk) / num_mask # re-check - std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - - outlier_pivot = avg_logit + SIGMA * std_logit - num_outliers = tl.zeros((), dtype=tl.uint32) + for row_id in tl.range(pid, BATCH_SIZE, num_programs): + k_pivot = -float("inf") - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - 
logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=avg_logit) - probs_blk = tl.load(PROBS_ROW + offs_n, mask=mask_n, other=0.0) + LOGITS_ROW = LOGITS_MAYBE_K_MASKED + row_id * VOCAB_SIZE + BUFFER_ROW = BUFFER + pid * VOCAB_SIZE + FILTERED_LOGITS_ROW = FILTERED_LOGITS + row_id * K_FILTER + FILTERED_INDICES_ROW = FILTERED_INDICES + row_id * K_FILTER + FILTERED_PROBS_ROW = FILTERED_PROBS + row_id * K_FILTER - max_logit = tl.maximum(max_logit, tl.max(logits_blk)) - min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + search_addr = LOGITS_ROW + search_range = VOCAB_SIZE + search_iters = NUM_TILES - outlier_mask = (logits_blk > outlier_pivot) & mask_n - num_blk_outliers = tl.sum(outlier_mask) - cumulative_pos = tl.cast( - tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32 - ) - num_outliers += num_blk_outliers + max_logit = -float("inf") + min_logit = float("inf") + + # Zeroth pass: Compute avg and std from a sample block + offs = tl.arange(0, BLOCK_SIZE) + mask_n = offs < VOCAB_SIZE + num_mask = tl.sum(mask_n) + logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) + avg_logit = tl.sum(logits_blk) / num_mask + sq_avg_logit = tl.sum(logits_blk * logits_blk) / num_mask + std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) + + percentile = tl.cast(k * 1.2 / VOCAB_SIZE * 100, tl.uint32) + 1 + percentile = tl.minimum(percentile, 99) + sigma = tl.load(PERCENTILE_TO_STD_TABLE + percentile) + outlier_pivot = avg_logit + sigma * std_logit + num_outliers = tl.zeros((), dtype=tl.uint32) + + # First pass: compute max and min logits and gather outliers + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=avg_logit) + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) - write_idx = tl.where(outlier_mask, cumulative_pos, 0) - tl.store(PROBS_ROW + write_idx, logits_blk, mask=outlier_mask) - 
tl.store(PROBS_IDX_ROW + write_idx, offs_n, mask=outlier_mask) + max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + outlier_mask = (logits_blk > outlier_pivot) & mask_n + num_blk_outliers = tl.sum(outlier_mask) + cumulative_pos = tl.cast( + tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32 + ) + num_outliers += num_blk_outliers + + write_idx = tl.where(outlier_mask, cumulative_pos, 0) + tl.store(BUFFER_ROW + write_idx, logits_blk, mask=outlier_mask) + + max_range = max_logit + min_range = min_logit + if num_outliers > k: max_range = max_logit - min_range = min_logit - if num_outliers > k: - max_range = max_logit - min_range = outlier_pivot - search_addr = PROBS_ROW - search_range = tl.cast(num_outliers, tl.int32) - search_iters = tl.cast( - (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32 - ) + min_range = outlier_pivot + search_addr = BUFFER_ROW + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32 + ) - # Second passes: Quaternary search for pivots (nlog_4(n)) - num_iters = 0 - while k_pivot == -float("inf") and num_iters < 18: - k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range - k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range - k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range - k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) - k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) - k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - logits_blk = tl.load( - search_addr + offs_n, mask=mask_n, other=-float("inf") - ) + # Second passes: Quaternary search for pivots (nlog_4(n)) + num_iters = 0 + while k_pivot == -float("inf") and num_iters < 32: + k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + 
k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range + k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) + k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) + k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) - k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) - k_pivots_num_2 += tl.sum(logits_blk > k_pivot_2) - - # Check if any of the pivots are equal to k - if k_pivots_num_0 == k: - k_pivot = k_pivot_0 - elif k_pivots_num_1 == k: - k_pivot = k_pivot_1 - elif k_pivots_num_2 == k: - k_pivot = k_pivot_2 - # If none of the pivots are equal to k, we update the range - elif k_pivots_num_2 > k: - min_range = k_pivot_2 - elif k_pivots_num_1 > k: - min_range = k_pivot_1 - elif k_pivots_num_0 > k: - min_range = k_pivot_0 - if k_pivots_num_0 < k: - max_range = k_pivot_0 - elif k_pivots_num_1 < k: - max_range = k_pivot_1 - elif k_pivots_num_2 < k: - max_range = k_pivot_2 - - num_iters += 1 - if num_iters >= 18 or tl.abs(min_range - max_range) < 1e-8: - k_pivot = k_pivot_0 - # ============== Third pass : Apply top-k mask ================ - write_pos = tl.zeros((), dtype=tl.int32) - # if k_pivot != -float('inf'): - for i in range(0, NUM_TILES): + for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE + mask_n = offs_n < search_range logits_blk = tl.load( - LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") + search_addr + offs_n, mask=mask_n, other=-float("inf") ) - keep_mask = (logits_blk >= k_pivot) & mask_n - n_kept = tl.sum(keep_mask, dtype=tl.int32) - cpos = tl.cumsum(keep_mask) - 1 + write_pos - final_mask = keep_mask & (cpos < k) - write_idx = tl.where(final_mask, cpos, 0) - tl.store(OUTPUT_LOGITS_ROW + write_idx, logits_blk, mask=final_mask) - tl.store(OUTPUT_INDICES_ROW + write_idx, offs_n, mask=final_mask) - write_pos += tl.sum(final_mask, dtype=tl.int32) + k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) + k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) + 
k_pivots_num_2 += tl.sum(logits_blk > k_pivot_2) + + # Check if any of the pivots are equal to k + if k_pivots_num_0 == k: + k_pivot = k_pivot_0 + elif k_pivots_num_1 == k: + k_pivot = k_pivot_1 + elif k_pivots_num_2 == k: + k_pivot = k_pivot_2 + # If none of the pivots are equal to k, we update the range + elif k_pivots_num_2 > k: + min_range = k_pivot_2 + elif k_pivots_num_1 > k: + min_range = k_pivot_1 + elif k_pivots_num_0 > k: + min_range = k_pivot_0 + if k_pivots_num_0 < k: + max_range = k_pivot_0 + elif k_pivots_num_1 < k: + max_range = k_pivot_1 + elif k_pivots_num_2 < k: + max_range = k_pivot_2 + + num_iters += 1 + if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-16: + k_pivot = k_pivot_0 + + # Third pass: Calculate exp logits and sum + sum_exp_logits = tl.zeros((), dtype=tl.float32) + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + probs_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + probs_blk = probs_blk - max_logit + probs_blk = tl.exp(probs_blk) + sum_exp_logits += tl.sum(probs_blk) + tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) + + # Fourth pass: Calculate softmax + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) + probs_blk = probs_blk / sum_exp_logits + tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) + + # Fifth pass : Gather filtered values + write_pos = tl.zeros((), dtype=tl.int32) + sum_probs = tl.zeros((), dtype=tl.float32) + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) + + keep_mask = (logits_blk >= k_pivot) & mask_n + n_kept = tl.sum(keep_mask, dtype=tl.int32) + cpos = 
tl.cumsum(keep_mask) - 1 + write_pos + final_mask = keep_mask & (cpos < k) + write_idx = tl.where(final_mask, cpos, 0) -def apply_top_k_top_p_filtered( + # Gather filtered values + tl.store(FILTERED_LOGITS_ROW + write_idx, logits_blk, mask=final_mask) + tl.store(FILTERED_INDICES_ROW + write_idx, offs_n, mask=final_mask) + tl.store(FILTERED_PROBS_ROW + write_idx, probs_blk, mask=final_mask) + + sum_probs += tl.sum(probs_blk * final_mask) + write_pos += tl.sum(final_mask, dtype=tl.int32) + tl.store(SUM_FILTERED_PROBS + row_id, sum_probs) + + +def apply_top_p_filtered( logits: torch.Tensor, + logits_maybe_k_masked: torch.Tensor, k: torch.Tensor, p: torch.Tensor, ) -> torch.Tensor: """ Applies top p using pivot based filtering """ + batch_size, vocab_size = logits.shape - BLOCK_SIZE = 2048 - SIGMA = 2.15 + BLOCK_SIZE = 4096 + device_prop = torch.cuda.get_device_properties(logits.device) + NUM_PROGRAMS = device_prop.multi_processor_count # Persistent kernel + buffer = torch.empty( + (NUM_PROGRAMS, vocab_size), device=logits.device, dtype=torch.float32 + ) NUM_WARPS = 16 NUM_STAGES = 3 - # ================= to find the k filter value ==================== - k_filter = max(int(vocab_size * 1 / 32), k.max() * 2) - if k_filter > vocab_size / 2: - print(f"k_filter too big: {k_filter}") - return apply_top_k_top_p(logits, k, p) - + k_filter = int(vocab_size * 1 / 32) filtered_logits = torch.full( (batch_size, k_filter), -float("inf"), device=logits.device ) filtered_indices = torch.full( (batch_size, k_filter), 0, dtype=torch.int32, device=logits.device ) + filtered_probs = torch.full( + (batch_size, k_filter), -float("inf"), device=logits.device + ) + sum_filtered_probs = torch.zeros( + (batch_size,), device=logits.device, dtype=torch.float32 + ) + + PERCENTILE_TO_STD_TABLE = torch.tensor( + _PERCENTILE_TO_STD_TABLE, device=logits.device + ) - probs = torch.empty( - (batch_size, vocab_size), device=logits.device, dtype=torch.float32 + debug_tensor = torch.full( + (batch_size, 
vocab_size), -float("inf"), device=logits.device ) - probs_idx = torch.empty_like(probs, dtype=torch.int32) - - grid = (batch_size,) - top_p_pivot_filter[grid]( - logits, # --> input - probs, # initial filtered - probs_idx, # initial filtered index - k_filter, # --> scalar - filtered_logits, # --> output, filtered - filtered_indices, # --> filtered logits indices + + top_p_pivot_filter[(NUM_PROGRAMS,)]( + debug_tensor, + logits_maybe_k_masked, + k_filter, + buffer, batch_size, - SIGMA=SIGMA, + sum_filtered_probs, + filtered_logits, + filtered_indices, + filtered_probs, + PERCENTILE_TO_STD_TABLE, VOCAB_SIZE=vocab_size, BLOCK_SIZE=BLOCK_SIZE, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, ) - target_logit = filtered_logits - - if p is None and k is not None: - filtered_logits = apply_top_k_only(filtered_logits, k) - logits.fill_(-float("inf")) - logits.scatter_(dim=1, index=filtered_indices, src=filtered_logits) - return logits - assert False - - if p is not None: - # this kernel outputs filtered_logits and filtered_indices of shape (batch_size, k_filter) - logits_sort, sort_indices = filtered_logits.sort(dim=-1, descending=False) - logits_sort_indices = torch.gather(filtered_indices, -1, sort_indices) - - logits_softmax = logits.softmax(dim=-1) - sorted_probs = torch.gather(logits_softmax, -1, logits_sort_indices) + # filtered_logits = logits.clone().detach() + # filtered_indices = torch.arange(0, vocab_size, device=logits.device).unsqueeze(0).expand(batch_size, vocab_size) + # filtered_probs = torch.softmax(filtered_logits, dim=-1) + # sum_filtered_probs = torch.sum(filtered_probs, dim=-1) - sum_probs = sorted_probs.sum(dim=-1) - - if torch.any(sum_probs < p): - return apply_top_k_top_p(logits, k, p) + if torch.any(sum_filtered_probs < p): + return apply_top_k_top_p(logits, k, p) - if torch.any(sum_probs < p): - print("edge case --> fall back !") - # assert False - return apply_top_k_top_p(logits, k=None, p=p) + logits_sort, sort_indices = 
filtered_logits.sort(dim=-1, descending=False) + logits_sort_indices = torch.gather(filtered_indices, -1, sort_indices) + sorted_probs = torch.gather(filtered_probs, -1, sort_indices) probs_sum = torch.cumsum(sorted_probs, dim=-1) - sum_non_outliers = (1.0 - sum_probs).unsqueeze(-1) + sum_non_outliers = (1.0 - sum_filtered_probs).unsqueeze(-1) probs_sum = probs_sum + sum_non_outliers top_p_mask = probs_sum <= (1 - p.unsqueeze(dim=-1)) top_p_mask[:, -1] = False @@ -1540,7 +791,7 @@ def apply_top_k_top_p_filtered( return logits -def apply_top_k_top_p_test2( +def apply_top_k_top_p_triton( logits: torch.Tensor, k: torch.Tensor | None, p: torch.Tensor | None, @@ -1552,9 +803,11 @@ def apply_top_k_top_p_test2( if k is None and p is None: return logits elif p is None and k is not None: - return apply_top_k_top_p_triton(logits, k, p) + return apply_top_k_only_triton(logits, k) elif k is None and p is not None: - return apply_top_k_top_p_filtered(logits, k, p) + # k must be supplied for the fallback path + return apply_top_p_filtered(logits, logits, k, p) else: - logits_k = apply_top_k_top_p_triton(logits, k, p=None) - return apply_top_k_top_p_filtered(logits, k, p) + logits_k_masked = apply_top_k_only_triton(logits, k) + # Original logits and k must be supplied for the fallback path + return apply_top_p_filtered(logits, logits_k_masked, k, p) From 2893ed5428d073b01a9a843eab35c8f579a1b7a4 Mon Sep 17 00:00:00 2001 From: js_park Date: Fri, 14 Nov 2025 23:16:57 -0800 Subject: [PATCH 50/99] Everything correct but slow Signed-off-by: js_park --- compare.py | 27 +++++++++++++------------ vllm/v1/sample/ops/topk_topp_sampler.py | 16 ++++++--------- 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/compare.py b/compare.py index 4fae145ee54c..e1e90ed4f96c 100644 --- a/compare.py +++ b/compare.py @@ -3,17 +3,18 @@ from datetime import datetime from itertools import product -# torch.manual_seed(42) -# torch.cuda.manual_seed(42) -# torch.cuda.manual_seed_all(42) -# import 
random -# import numpy as np -# random.seed(42) -# np.random.seed(42) -# torch.backends.cudnn.deterministic = True -# torch.backends.cudnn.benchmark = False -import regex as re import torch +torch.manual_seed(42) +torch.cuda.manual_seed(42) +torch.cuda.manual_seed_all(42) +import random +import numpy as np +random.seed(42) +np.random.seed(42) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False +import regex as re + from vllm.v1.sample.ops.topk_topp_sampler import ( apply_top_k_top_p, @@ -76,7 +77,7 @@ def test_accuracy(logits, k, p, func_list): error_cols = torch.unique(error_cols) num_error_cols = error_cols.shape[0] print_to_log(f"num_error_rows: {num_error_rows} - {error_rows}", log_file) - print_to_log(f"num_error_cols: {num_error_cols}", log_file) + print_to_log(f"num_error_cols: {num_error_cols} - {error_cols}", log_file) row_to_show = 5 if num_error_rows > 5 else num_error_rows logits_to_show = torch.sort( output_logits[error_rows], descending=True @@ -122,8 +123,8 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): vocab_size_list = [4096, 16384, 65536, 128000, 262144] p_list = [None, "RAND", 0.4, 0.7, 0.9, 0.95, 0.99] # p_list = [None] - # k_list = [None, "RAND", 5, 10, 50, 100, 200, 300, 3000] - k_list = [None] + k_list = [None, "RAND", 5, 10, 50, 100, 200, 300, 3000] + # k_list = [None] func_list = [apply_top_k_top_p, apply_top_k_top_p_triton] log_file = f"triton_topk_topp_test_{date_str}.log" diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 14e222ed0a7a..168c84241eb3 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -667,7 +667,7 @@ def top_p_pivot_filter( sum_exp_logits = tl.zeros((), dtype=tl.float32) for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range + mask_n = offs_n < VOCAB_SIZE probs_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, 
other=-float("inf")) probs_blk = probs_blk - max_logit probs_blk = tl.exp(probs_blk) @@ -677,7 +677,7 @@ def top_p_pivot_filter( # Fourth pass: Calculate softmax for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range + mask_n = offs_n < VOCAB_SIZE probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) probs_blk = probs_blk / sum_exp_logits tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) @@ -767,18 +767,13 @@ def apply_top_p_filtered( num_stages=NUM_STAGES, ) - # filtered_logits = logits.clone().detach() - # filtered_indices = torch.arange(0, vocab_size, device=logits.device).unsqueeze(0).expand(batch_size, vocab_size) - # filtered_probs = torch.softmax(filtered_logits, dim=-1) - # sum_filtered_probs = torch.sum(filtered_probs, dim=-1) - - if torch.any(sum_filtered_probs < p): - return apply_top_k_top_p(logits, k, p) - logits_sort, sort_indices = filtered_logits.sort(dim=-1, descending=False) logits_sort_indices = torch.gather(filtered_indices, -1, sort_indices) sorted_probs = torch.gather(filtered_probs, -1, sort_indices) + if torch.any(sum_filtered_probs < p): + return apply_top_k_top_p(logits, k, p) + probs_sum = torch.cumsum(sorted_probs, dim=-1) sum_non_outliers = (1.0 - sum_filtered_probs).unsqueeze(-1) probs_sum = probs_sum + sum_non_outliers @@ -789,6 +784,7 @@ def apply_top_p_filtered( logits.fill_(-float("inf")) logits.scatter_(dim=1, index=logits_sort_indices, src=logits_sort) return logits + def apply_top_k_top_p_triton( From 6e3c8744d81c6c119d9ed0c8abb6cd296cb0c6ae Mon Sep 17 00:00:00 2001 From: js_park Date: Sat, 15 Nov 2025 13:53:13 -0800 Subject: [PATCH 51/99] Everything correct but slow Signed-off-by: js_park --- compare.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compare.py b/compare.py index e1e90ed4f96c..a4d624e007ee 100644 --- a/compare.py +++ b/compare.py @@ -121,9 +121,9 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): 
batch_size_list = [64, 128, 1024] vocab_size_list = [4096, 16384, 65536, 128000, 262144] - p_list = [None, "RAND", 0.4, 0.7, 0.9, 0.95, 0.99] + p_list = [None, "RAND"] # p_list = [None] - k_list = [None, "RAND", 5, 10, 50, 100, 200, 300, 3000] + k_list = [None, "RAND"] # k_list = [None] func_list = [apply_top_k_top_p, apply_top_k_top_p_triton] From d0f491ee61d223f60e954f9e1e863939e057874a Mon Sep 17 00:00:00 2001 From: js_park Date: Sat, 15 Nov 2025 16:04:22 -0800 Subject: [PATCH 52/99] Fast and working correctly Signed-off-by: js_park --- compare.py | 7 +- vllm/v1/sample/ops/topk_topp_sampler.py | 162 ++++++++++++++++-------- 2 files changed, 112 insertions(+), 57 deletions(-) diff --git a/compare.py b/compare.py index a4d624e007ee..b35fa92e3464 100644 --- a/compare.py +++ b/compare.py @@ -4,18 +4,20 @@ from itertools import product import torch + torch.manual_seed(42) torch.cuda.manual_seed(42) torch.cuda.manual_seed_all(42) import random + import numpy as np + random.seed(42) np.random.seed(42) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False import regex as re - from vllm.v1.sample.ops.topk_topp_sampler import ( apply_top_k_top_p, apply_top_k_top_p_triton, @@ -155,7 +157,7 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): logits_list = [("RANDN", logits_randn)] if p == "RAND": - p_tensor = torch.rand((batch_size,), device="cuda") * 0.95 + 0.05 + p_tensor = torch.rand((batch_size,), device="cuda") * 0.98 + 0.01 elif p is not None: p_tensor = torch.full((batch_size,), p, device="cuda") else: @@ -194,7 +196,6 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): f"{dist_generator}, p: {p}, k: {k}", log_file, ) - print_to_log("Test accuracy passed! 
Now testing speedup...", log_file) time_list = [] for func in func_list: time_taken = test_time(logits, k_tensor, p_tensor, test_func=func) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 168c84241eb3..f4a02d100dc9 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -392,7 +392,6 @@ def _topk_triton_kernel( search_range = VOCAB_SIZE search_iters = NUM_TILES - k_pivot = -float("inf") max_logit = -float("inf") min_logit = float("inf") @@ -446,7 +445,8 @@ def _topk_triton_kernel( # Second passes: Quaternary search for pivots (nlog_4(n)) num_iters = 0 - while k_pivot == -float("inf") and num_iters < 32: + k_pivot = -float("inf") + while k_pivot == -float("inf"): k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range @@ -536,10 +536,11 @@ def apply_top_k_only_triton( @triton.jit -def top_p_pivot_filter( - DEBUG_TENSOR, - LOGITS_MAYBE_K_MASKED, - K_FILTER: tl.int32, +def top_k_top_p_filter( + LOGITS, + DO_TOP_K, + K, + P_FIL, BUFFER, BATCH_SIZE, SUM_FILTERED_PROBS, @@ -553,16 +554,10 @@ def top_p_pivot_filter( NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE pid = tl.program_id(0) num_programs = tl.num_programs(0) - k = K_FILTER for row_id in tl.range(pid, BATCH_SIZE, num_programs): - k_pivot = -float("inf") - - LOGITS_ROW = LOGITS_MAYBE_K_MASKED + row_id * VOCAB_SIZE + LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE BUFFER_ROW = BUFFER + pid * VOCAB_SIZE - FILTERED_LOGITS_ROW = FILTERED_LOGITS + row_id * K_FILTER - FILTERED_INDICES_ROW = FILTERED_INDICES + row_id * K_FILTER - FILTERED_PROBS_ROW = FILTERED_PROBS + row_id * K_FILTER search_addr = LOGITS_ROW search_range = VOCAB_SIZE @@ -580,7 +575,7 @@ def top_p_pivot_filter( sq_avg_logit = tl.sum(logits_blk * logits_blk) / num_mask std_logit = tl.sqrt(sq_avg_logit - avg_logit * 
avg_logit) - percentile = tl.cast(k * 1.2 / VOCAB_SIZE * 100, tl.uint32) + 1 + percentile = tl.cast(P_FIL * 1.2 / VOCAB_SIZE * 100, tl.uint32) + 1 percentile = tl.minimum(percentile, 99) sigma = tl.load(PERCENTILE_TO_STD_TABLE + percentile) outlier_pivot = avg_logit + sigma * std_logit @@ -608,7 +603,7 @@ def top_p_pivot_filter( max_range = max_logit min_range = min_logit - if num_outliers > k: + if num_outliers > P_FIL: max_range = max_logit min_range = outlier_pivot search_addr = BUFFER_ROW @@ -617,16 +612,37 @@ def top_p_pivot_filter( (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32 ) + k = tl.load(K + row_id) + k_max_range = max_range + k_min_range = min_range + p_fil_max_range = max_range + p_fil_min_range = min_range + # Second passes: Quaternary search for pivots (nlog_4(n)) + k_pivot = -float("inf") + p_fil_pivot = -float("inf") num_iters = 0 - while k_pivot == -float("inf") and num_iters < 32: - k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range - k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range - k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range + while k_pivot == -float("inf") or p_fil_pivot == -float("inf"): + k_pivot_0 = (k_max_range - k_min_range) * 1.0 / 4.0 + k_min_range + k_pivot_1 = (k_max_range - k_min_range) * 2.0 / 4.0 + k_min_range + k_pivot_2 = (k_max_range - k_min_range) * 3.0 / 4.0 + k_min_range k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) + p_fil_pivot_0 = ( + p_fil_max_range - p_fil_min_range + ) * 1.0 / 4.0 + p_fil_min_range + p_fil_pivot_1 = ( + p_fil_max_range - p_fil_min_range + ) * 2.0 / 4.0 + p_fil_min_range + p_fil_pivot_2 = ( + p_fil_max_range - p_fil_min_range + ) * 3.0 / 4.0 + p_fil_min_range + p_fil_pivots_num_0 = tl.zeros((), dtype=tl.uint32) + p_fil_pivots_num_1 = tl.zeros((), dtype=tl.uint32) + p_fil_pivots_num_2 = tl.zeros((), dtype=tl.uint32) + for i in range(0, search_iters): offs_n = i * 
BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range @@ -638,6 +654,10 @@ def top_p_pivot_filter( k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) k_pivots_num_2 += tl.sum(logits_blk > k_pivot_2) + p_fil_pivots_num_0 += tl.sum(logits_blk > p_fil_pivot_0) + p_fil_pivots_num_1 += tl.sum(logits_blk > p_fil_pivot_1) + p_fil_pivots_num_2 += tl.sum(logits_blk > p_fil_pivot_2) + # Check if any of the pivots are equal to k if k_pivots_num_0 == k: k_pivot = k_pivot_0 @@ -647,29 +667,63 @@ def top_p_pivot_filter( k_pivot = k_pivot_2 # If none of the pivots are equal to k, we update the range elif k_pivots_num_2 > k: - min_range = k_pivot_2 + k_min_range = k_pivot_2 elif k_pivots_num_1 > k: - min_range = k_pivot_1 + k_min_range = k_pivot_1 elif k_pivots_num_0 > k: - min_range = k_pivot_0 + k_min_range = k_pivot_0 if k_pivots_num_0 < k: - max_range = k_pivot_0 + k_max_range = k_pivot_0 elif k_pivots_num_1 < k: - max_range = k_pivot_1 + k_max_range = k_pivot_1 elif k_pivots_num_2 < k: - max_range = k_pivot_2 + k_max_range = k_pivot_2 + + # Check if any of the pivots are equal to P_FIL + if p_fil_pivots_num_0 == P_FIL: + p_fil_pivot = p_fil_pivot_0 + elif p_fil_pivots_num_1 == P_FIL: + p_fil_pivot = p_fil_pivot_1 + elif p_fil_pivots_num_2 == P_FIL: + p_fil_pivot = p_fil_pivot_2 + # If none of the pivots are equal to P_FIL, we update the range + elif p_fil_pivots_num_2 > P_FIL: + p_fil_min_range = p_fil_pivot_2 + elif p_fil_pivots_num_1 > P_FIL: + p_fil_min_range = p_fil_pivot_1 + elif p_fil_pivots_num_0 > P_FIL: + p_fil_min_range = p_fil_pivot_0 + if p_fil_pivots_num_0 < P_FIL: + p_fil_max_range = p_fil_pivot_0 + elif p_fil_pivots_num_1 < P_FIL: + p_fil_max_range = p_fil_pivot_1 + elif p_fil_pivots_num_2 < P_FIL: + p_fil_max_range = p_fil_pivot_2 num_iters += 1 - if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-16: - k_pivot = k_pivot_0 + if num_iters >= 32 or ( + tl.abs(k_min_range - k_max_range) < 1e-16 + and tl.abs(p_fil_min_range - p_fil_max_range) < 
1e-16 + ): + if k_pivot == -float("inf"): + k_pivot = k_pivot_0 + if p_fil_pivot == -float("inf"): + p_fil_pivot = p_fil_pivot_0 + + # Third pass: Mask top-k, calculate exp logits and sum + if not DO_TOP_K: + k_pivot = -float("inf") - # Third pass: Calculate exp logits and sum sum_exp_logits = tl.zeros((), dtype=tl.float32) for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE - probs_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) - probs_blk = probs_blk - max_logit + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + + top_k_mask = logits_blk > k_pivot + logits_blk = tl.where(top_k_mask, logits_blk, -float("inf")) + + probs_blk = logits_blk - max_logit probs_blk = tl.exp(probs_blk) sum_exp_logits += tl.sum(probs_blk) tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) @@ -685,17 +739,23 @@ def top_p_pivot_filter( # Fifth pass : Gather filtered values write_pos = tl.zeros((), dtype=tl.int32) sum_probs = tl.zeros((), dtype=tl.float32) + FILTERED_LOGITS_ROW = FILTERED_LOGITS + row_id * (P_FIL + 1) + FILTERED_INDICES_ROW = FILTERED_INDICES + row_id * (P_FIL + 1) + FILTERED_PROBS_ROW = FILTERED_PROBS + row_id * (P_FIL + 1) for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) - keep_mask = (logits_blk >= k_pivot) & mask_n + keep_mask = (logits_blk >= p_fil_pivot) & mask_n n_kept = tl.sum(keep_mask, dtype=tl.int32) cpos = tl.cumsum(keep_mask) - 1 + write_pos - final_mask = keep_mask & (cpos < k) - write_idx = tl.where(final_mask, cpos, 0) + final_mask = keep_mask & (cpos < P_FIL) + write_idx = tl.where(final_mask, cpos, P_FIL) + + top_k_mask = logits_blk > k_pivot + logits_blk = tl.where(top_k_mask, logits_blk, -float("inf")) # Gather filtered values 
tl.store(FILTERED_LOGITS_ROW + write_idx, logits_blk, mask=final_mask) @@ -709,7 +769,6 @@ def top_p_pivot_filter( def apply_top_p_filtered( logits: torch.Tensor, - logits_maybe_k_masked: torch.Tensor, k: torch.Tensor, p: torch.Tensor, ) -> torch.Tensor: @@ -719,7 +778,7 @@ def apply_top_p_filtered( batch_size, vocab_size = logits.shape - BLOCK_SIZE = 4096 + BLOCK_SIZE = 8192 device_prop = torch.cuda.get_device_properties(logits.device) NUM_PROGRAMS = device_prop.multi_processor_count # Persistent kernel buffer = torch.empty( @@ -728,15 +787,15 @@ def apply_top_p_filtered( NUM_WARPS = 16 NUM_STAGES = 3 - k_filter = int(vocab_size * 1 / 32) + p_filter = k.max().item() if k is not None else int(vocab_size * 1 / 32) filtered_logits = torch.full( - (batch_size, k_filter), -float("inf"), device=logits.device + (batch_size, p_filter + 1), -float("inf"), device=logits.device ) filtered_indices = torch.full( - (batch_size, k_filter), 0, dtype=torch.int32, device=logits.device + (batch_size, p_filter + 1), p_filter, dtype=torch.int64, device=logits.device ) filtered_probs = torch.full( - (batch_size, k_filter), -float("inf"), device=logits.device + (batch_size, p_filter + 1), -float("inf"), device=logits.device ) sum_filtered_probs = torch.zeros( (batch_size,), device=logits.device, dtype=torch.float32 @@ -746,14 +805,11 @@ def apply_top_p_filtered( _PERCENTILE_TO_STD_TABLE, device=logits.device ) - debug_tensor = torch.full( - (batch_size, vocab_size), -float("inf"), device=logits.device - ) - - top_p_pivot_filter[(NUM_PROGRAMS,)]( - debug_tensor, - logits_maybe_k_masked, - k_filter, + top_k_top_p_filter[(NUM_PROGRAMS,)]( + logits, + (k is not None), + k if k is not None else filtered_indices, + p_filter, buffer, batch_size, sum_filtered_probs, @@ -767,6 +823,10 @@ def apply_top_p_filtered( num_stages=NUM_STAGES, ) + filtered_indices = filtered_indices[:, :p_filter] + filtered_logits = filtered_logits[:, :p_filter] + filtered_probs = filtered_probs[:, :p_filter] + 
logits_sort, sort_indices = filtered_logits.sort(dim=-1, descending=False) logits_sort_indices = torch.gather(filtered_indices, -1, sort_indices) sorted_probs = torch.gather(filtered_probs, -1, sort_indices) @@ -784,7 +844,6 @@ def apply_top_p_filtered( logits.fill_(-float("inf")) logits.scatter_(dim=1, index=logits_sort_indices, src=logits_sort) return logits - def apply_top_k_top_p_triton( @@ -800,10 +859,5 @@ def apply_top_k_top_p_triton( return logits elif p is None and k is not None: return apply_top_k_only_triton(logits, k) - elif k is None and p is not None: - # k must be supplied for the fallback path - return apply_top_p_filtered(logits, logits, k, p) else: - logits_k_masked = apply_top_k_only_triton(logits, k) - # Original logits and k must be supplied for the fallback path - return apply_top_p_filtered(logits, logits_k_masked, k, p) + return apply_top_p_filtered(logits, k, p) From 6743e12d43f39f421e24b07a7e7f21d780e8bf50 Mon Sep 17 00:00:00 2001 From: js_park Date: Sat, 15 Nov 2025 16:11:36 -0800 Subject: [PATCH 53/99] Fast and working correctly Signed-off-by: js_park --- compare.py | 31 +++++++----------- vllm/v1/sample/ops/topk_topp_sampler.py | 43 +++++++++++-------------- 2 files changed, 31 insertions(+), 43 deletions(-) diff --git a/compare.py b/compare.py index b35fa92e3464..aaabf27ab237 100644 --- a/compare.py +++ b/compare.py @@ -3,20 +3,8 @@ from datetime import datetime from itertools import product -import torch - -torch.manual_seed(42) -torch.cuda.manual_seed(42) -torch.cuda.manual_seed_all(42) -import random - -import numpy as np - -random.seed(42) -np.random.seed(42) -torch.backends.cudnn.deterministic = True -torch.backends.cudnn.benchmark = False import regex as re +import torch from vllm.v1.sample.ops.topk_topp_sampler import ( apply_top_k_top_p, @@ -78,8 +66,11 @@ def test_accuracy(logits, k, p, func_list): error_cols = torch.where(error_mask)[1] error_cols = torch.unique(error_cols) num_error_cols = error_cols.shape[0] - 
print_to_log(f"num_error_rows: {num_error_rows} - {error_rows}", log_file) - print_to_log(f"num_error_cols: {num_error_cols} - {error_cols}", log_file) + print_to_log( + f"num_error_rows: {num_error_rows} - {error_rows}", + f"num_error_cols: {num_error_cols} - {error_cols}", + log_file, + ) row_to_show = 5 if num_error_rows > 5 else num_error_rows logits_to_show = torch.sort( output_logits[error_rows], descending=True @@ -92,7 +83,7 @@ def test_accuracy(logits, k, p, func_list): ).values original_logits_to_show = original_logits_to_show[:row_to_show, :20] print_to_log(f"original_logits: {original_logits_to_show}", log_file) - assert False + raise ValueError("Logits are not close") return output_correct_list @@ -143,8 +134,9 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): with open(csv_file, "w") as f: f.write( - "dist_generator,batch_size,vocab_size,p,k,triton_correct,test_correct" - "torch_time_taken,triton_time_taken,test_time_taken,triton_speedup,test_speedup\n" + "dist_generator,batch_size,vocab_size,p,k,triton_correct,\n" + "test_correct,torch_time_taken,triton_time_taken,test_time_taken,\n" + "triton_speedup,test_speedup\n" ) for batch_size, vocab_size, p, k in product( @@ -190,7 +182,8 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): is_correct = correct_list[i] if not is_correct: print_to_log( - f"Error: logits are not close for function {func_list[i + 1].__name__}," + f"Error: logits are not close for " + f"function {func_list[i + 1].__name__}, " f" batch_size: {batch_size}," f" vocab_size: {vocab_size}, dist_generator: " f"{dist_generator}, p: {p}, k: {k}", diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index f4a02d100dc9..ebbf1c39f1c5 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -367,6 +367,23 @@ def flashinfer_sample( # fmt: on +def apply_top_k_top_p_triton( + logits: torch.Tensor, + k: torch.Tensor | None, 
+ p: torch.Tensor | None, + debug: bool = False, +) -> torch.Tensor: + """ + Uses pivot-based algorithm to filter --> sort + """ + if k is None and p is None: + return logits + elif p is None and k is not None: + return apply_top_k_only_triton(logits, k) + else: + return apply_top_p_filtered(logits, k, p) + + @triton.jit def _topk_triton_kernel( LOGITS, @@ -411,10 +428,6 @@ def _topk_triton_kernel( outlier_pivot = avg_logit + sigma * std_logit num_outliers = tl.zeros((), dtype=tl.uint32) - sum_logit = tl.sum(logits_blk) - min_logit_value = tl.min(logits_blk) - max_logit_value = tl.max(logits_blk) - # First pass: compute max and min logits and gather outliers for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) @@ -710,7 +723,7 @@ def top_k_top_p_filter( if p_fil_pivot == -float("inf"): p_fil_pivot = p_fil_pivot_0 - # Third pass: Mask top-k, calculate exp logits and sum + # Third pass: Calculate exp logits and sum with top-k mask if not DO_TOP_K: k_pivot = -float("inf") @@ -736,7 +749,7 @@ def top_k_top_p_filter( probs_blk = probs_blk / sum_exp_logits tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) - # Fifth pass : Gather filtered values + # Fifth pass : Gather filtered values with top-k mask write_pos = tl.zeros((), dtype=tl.int32) sum_probs = tl.zeros((), dtype=tl.float32) FILTERED_LOGITS_ROW = FILTERED_LOGITS + row_id * (P_FIL + 1) @@ -749,7 +762,6 @@ def top_k_top_p_filter( probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) keep_mask = (logits_blk >= p_fil_pivot) & mask_n - n_kept = tl.sum(keep_mask, dtype=tl.int32) cpos = tl.cumsum(keep_mask) - 1 + write_pos final_mask = keep_mask & (cpos < P_FIL) write_idx = tl.where(final_mask, cpos, P_FIL) @@ -844,20 +856,3 @@ def apply_top_p_filtered( logits.fill_(-float("inf")) logits.scatter_(dim=1, index=logits_sort_indices, src=logits_sort) return logits - - -def apply_top_k_top_p_triton( - logits: torch.Tensor, - k: torch.Tensor | None, - p: torch.Tensor | None, - 
debug: bool = False, -) -> torch.Tensor: - """ - Uses pivot-based algorithm to filter --> sort - """ - if k is None and p is None: - return logits - elif p is None and k is not None: - return apply_top_k_only_triton(logits, k) - else: - return apply_top_p_filtered(logits, k, p) From 60b6515aff1993883c5bab99b5e632a13dccc11e Mon Sep 17 00:00:00 2001 From: js_park Date: Sat, 15 Nov 2025 18:48:40 -0800 Subject: [PATCH 54/99] Errors Signed-off-by: js_park --- compare.py | 72 ++--- vllm/v1/sample/ops/topk_topp_sampler.py | 342 ++++++++++++------------ 2 files changed, 203 insertions(+), 211 deletions(-) diff --git a/compare.py b/compare.py index aaabf27ab237..cde7e1efce84 100644 --- a/compare.py +++ b/compare.py @@ -36,7 +36,7 @@ def print_to_log(s, log_file): f.write(s + "\n") -def test_accuracy(logits, k, p, func_list): +def test_accuracy(logits, k, p, func_list, log_file): input_logit_list = [logits.clone().detach() for i in range(len(func_list))] original_logits = func_list[0](input_logit_list[0], k, p) output_correct_list = [] @@ -67,7 +67,7 @@ def test_accuracy(logits, k, p, func_list): error_cols = torch.unique(error_cols) num_error_cols = error_cols.shape[0] print_to_log( - f"num_error_rows: {num_error_rows} - {error_rows}", + f"num_error_rows: {num_error_rows} - {error_rows}\n" + \ f"num_error_cols: {num_error_cols} - {error_cols}", log_file, ) @@ -112,12 +112,10 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): if __name__ == "__main__": date_str = datetime.now().strftime("%Y%m%d_%H%M%S") - batch_size_list = [64, 128, 1024] - vocab_size_list = [4096, 16384, 65536, 128000, 262144] - p_list = [None, "RAND"] - # p_list = [None] - k_list = [None, "RAND"] - # k_list = [None] + batch_size_list = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] + vocab_size_list = [16384, 65536, 102400, 128256] + p_list = [None, "RAND", 0.1, 0.4, 0.7, 0.9, 0.99] + k_list = [None, "RAND", 5, 50, 200, 500, 3000] func_list = [apply_top_k_top_p, apply_top_k_top_p_triton] 
log_file = f"triton_topk_topp_test_{date_str}.log" @@ -134,9 +132,8 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): with open(csv_file, "w") as f: f.write( - "dist_generator,batch_size,vocab_size,p,k,triton_correct,\n" - "test_correct,torch_time_taken,triton_time_taken,test_time_taken,\n" - "triton_speedup,test_speedup\n" + "dist_generator,batch_size,vocab_size,p,k,triton_correct," + "torch_time_taken,triton_time_taken,triton_speedup\n" ) for batch_size, vocab_size, p, k in product( @@ -177,38 +174,23 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): + f"{k}", log_file, ) - correct_list = test_accuracy(logits, k_tensor, p_tensor, func_list) - for i in range(len(func_list) - 1): - is_correct = correct_list[i] - if not is_correct: - print_to_log( - f"Error: logits are not close for " - f"function {func_list[i + 1].__name__}, " - f" batch_size: {batch_size}," - f" vocab_size: {vocab_size}, dist_generator: " - f"{dist_generator}, p: {p}, k: {k}", - log_file, - ) - time_list = [] - for func in func_list: - time_taken = test_time(logits, k_tensor, p_tensor, test_func=func) - time_list.append(time_taken) - print_to_log(b_str("torch_time_taken: ") + f"{time_list[0]}", log_file) - print_to_log(b_str("test_time_taken: ") + f"{time_list[1]}", log_file) + correct_list = \ + test_accuracy(logits, k_tensor, p_tensor, func_list, log_file) + # time_list = [] + # for func in func_list: + # time_taken = test_time(logits, k_tensor, p_tensor, test_func=func) + # time_list.append(time_taken) + # print_to_log(b_str("torch_time_taken: ") + f"{time_list[0]}", log_file) + # print_to_log(b_str("triton_time_taken: ") + f"{time_list[1]}", log_file) # print_to_log( - # b_str("test_time_taken: ") + f"{time_list[2]}", log_file) - print_to_log( - g_str("test Speedup over Torch: ") - + f"{time_list[0] / time_list[1]:.8f}x", - log_file, - ) - # print_to_log( - # y_str("Test Speedup over Torch: ") + - # f"{time_list[0] / time_list[2]:.8f}x", log_file) - with 
open(csv_file, "a") as f: - f.write( - f"{dist_generator},{batch_size},{vocab_size},{p},{k}," - f"{correct_list[0]},{time_list[0]}," - f"{time_list[0] / time_list[1]:.8f}\n" - ) - print_to_log(y_str("--------------------------------\n"), log_file) + # g_str("test Speedup over Torch: ") + # + f"{time_list[0] / time_list[1]:.8f}x", + # log_file, + # ) + # with open(csv_file, "a") as f: + # f.write( + # f"{dist_generator},{batch_size},{vocab_size},{p},{k}," + # f"{correct_list[0]},{time_list[0]},{time_list[1]}," + # f"{time_list[0] / time_list[1]:.8f}\n" + # ) + # print_to_log(y_str("--------------------------------\n"), log_file) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index ebbf1c39f1c5..6ad8cf1942ed 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from transformers import video_processing_utils import torch import torch.nn as nn import triton @@ -376,6 +377,10 @@ def apply_top_k_top_p_triton( """ Uses pivot-based algorithm to filter --> sort """ + # Fallback to torch for small batch sizes + if logits.shape[0] < 16: + return apply_top_k_top_p(logits, k, p) + if k is None and p is None: return logits elif p is None and k is not None: @@ -422,7 +427,7 @@ def _topk_triton_kernel( sq_avg_logit = tl.sum(logits_blk * logits_blk) / num_valid std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - percentile = tl.cast(k * 1.2 / VOCAB_SIZE * 100, tl.uint32) + 1 + percentile = tl.cast(k * 2.0 / VOCAB_SIZE * 100, tl.uint32) + 1 percentile = tl.minimum(percentile, 99) sigma = tl.load(PERCENTILE_TO_STD_TABLE + percentile) outlier_pivot = avg_logit + sigma * std_logit @@ -432,7 +437,8 @@ def _topk_triton_kernel( for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - logits_blk = tl.load(search_addr + offs_n, mask=mask_n, 
other=avg_logit) + logits_blk = \ + tl.load(search_addr + offs_n, mask=mask_n, other=avg_logit) max_logit = tl.maximum(max_logit, tl.max(logits_blk)) min_logit = tl.minimum(min_logit, tl.min(logits_blk)) @@ -588,7 +594,7 @@ def top_k_top_p_filter( sq_avg_logit = tl.sum(logits_blk * logits_blk) / num_mask std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - percentile = tl.cast(P_FIL * 1.2 / VOCAB_SIZE * 100, tl.uint32) + 1 + percentile = tl.cast(P_FIL * 2.0 / VOCAB_SIZE * 100 + 4, tl.uint32) percentile = tl.minimum(percentile, 99) sigma = tl.load(PERCENTILE_TO_STD_TABLE + percentile) outlier_pivot = avg_logit + sigma * std_logit @@ -598,7 +604,8 @@ def top_k_top_p_filter( for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range - logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=avg_logit) + logits_blk = \ + tl.load(search_addr + offs_n, mask=mask_n, other=avg_logit) probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) max_logit = tl.maximum(max_logit, tl.max(logits_blk)) @@ -624,159 +631,166 @@ def top_k_top_p_filter( search_iters = tl.cast( (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32 ) + + k = tl.load(K + row_id) + k_max_range = max_range + k_min_range = min_range + p_fil_max_range = max_range + p_fil_min_range = min_range + + # Second passes: Quaternary search for pivots (nlog_4(n)) + k_pivot = -float("inf") + p_fil_pivot = -float("inf") + num_iters = 0 + while k_pivot == -float("inf") or p_fil_pivot == -float("inf"): + k_pivot_0 = (k_max_range - k_min_range) * 1.0 / 4.0 + k_min_range + k_pivot_1 = (k_max_range - k_min_range) * 2.0 / 4.0 + k_min_range + k_pivot_2 = (k_max_range - k_min_range) * 3.0 / 4.0 + k_min_range + k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) + k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) + k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) + + p_fil_pivot_0 = ( + p_fil_max_range - p_fil_min_range + ) * 1.0 / 4.0 + p_fil_min_range + 
p_fil_pivot_1 = ( + p_fil_max_range - p_fil_min_range + ) * 2.0 / 4.0 + p_fil_min_range + p_fil_pivot_2 = ( + p_fil_max_range - p_fil_min_range + ) * 3.0 / 4.0 + p_fil_min_range + p_fil_pivots_num_0 = tl.zeros((), dtype=tl.uint32) + p_fil_pivots_num_1 = tl.zeros((), dtype=tl.uint32) + p_fil_pivots_num_2 = tl.zeros((), dtype=tl.uint32) + + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + logits_blk = tl.load( + search_addr + offs_n, mask=mask_n, other=-float("inf") + ) - k = tl.load(K + row_id) - k_max_range = max_range - k_min_range = min_range - p_fil_max_range = max_range - p_fil_min_range = min_range - - # Second passes: Quaternary search for pivots (nlog_4(n)) - k_pivot = -float("inf") - p_fil_pivot = -float("inf") - num_iters = 0 - while k_pivot == -float("inf") or p_fil_pivot == -float("inf"): - k_pivot_0 = (k_max_range - k_min_range) * 1.0 / 4.0 + k_min_range - k_pivot_1 = (k_max_range - k_min_range) * 2.0 / 4.0 + k_min_range - k_pivot_2 = (k_max_range - k_min_range) * 3.0 / 4.0 + k_min_range - k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) - k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) - k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - - p_fil_pivot_0 = ( - p_fil_max_range - p_fil_min_range - ) * 1.0 / 4.0 + p_fil_min_range - p_fil_pivot_1 = ( - p_fil_max_range - p_fil_min_range - ) * 2.0 / 4.0 + p_fil_min_range - p_fil_pivot_2 = ( - p_fil_max_range - p_fil_min_range - ) * 3.0 / 4.0 + p_fil_min_range - p_fil_pivots_num_0 = tl.zeros((), dtype=tl.uint32) - p_fil_pivots_num_1 = tl.zeros((), dtype=tl.uint32) - p_fil_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - logits_blk = tl.load( - search_addr + offs_n, mask=mask_n, other=-float("inf") - ) - - k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) - k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) - k_pivots_num_2 += 
tl.sum(logits_blk > k_pivot_2) - - p_fil_pivots_num_0 += tl.sum(logits_blk > p_fil_pivot_0) - p_fil_pivots_num_1 += tl.sum(logits_blk > p_fil_pivot_1) - p_fil_pivots_num_2 += tl.sum(logits_blk > p_fil_pivot_2) + k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) + k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) + k_pivots_num_2 += tl.sum(logits_blk > k_pivot_2) - # Check if any of the pivots are equal to k - if k_pivots_num_0 == k: - k_pivot = k_pivot_0 - elif k_pivots_num_1 == k: - k_pivot = k_pivot_1 - elif k_pivots_num_2 == k: - k_pivot = k_pivot_2 - # If none of the pivots are equal to k, we update the range - elif k_pivots_num_2 > k: - k_min_range = k_pivot_2 - elif k_pivots_num_1 > k: - k_min_range = k_pivot_1 - elif k_pivots_num_0 > k: - k_min_range = k_pivot_0 - if k_pivots_num_0 < k: - k_max_range = k_pivot_0 - elif k_pivots_num_1 < k: - k_max_range = k_pivot_1 - elif k_pivots_num_2 < k: - k_max_range = k_pivot_2 - - # Check if any of the pivots are equal to P_FIL - if p_fil_pivots_num_0 == P_FIL: - p_fil_pivot = p_fil_pivot_0 - elif p_fil_pivots_num_1 == P_FIL: - p_fil_pivot = p_fil_pivot_1 - elif p_fil_pivots_num_2 == P_FIL: - p_fil_pivot = p_fil_pivot_2 - # If none of the pivots are equal to P_FIL, we update the range - elif p_fil_pivots_num_2 > P_FIL: - p_fil_min_range = p_fil_pivot_2 - elif p_fil_pivots_num_1 > P_FIL: - p_fil_min_range = p_fil_pivot_1 - elif p_fil_pivots_num_0 > P_FIL: - p_fil_min_range = p_fil_pivot_0 - if p_fil_pivots_num_0 < P_FIL: - p_fil_max_range = p_fil_pivot_0 - elif p_fil_pivots_num_1 < P_FIL: - p_fil_max_range = p_fil_pivot_1 - elif p_fil_pivots_num_2 < P_FIL: - p_fil_max_range = p_fil_pivot_2 + p_fil_pivots_num_0 += tl.sum(logits_blk > p_fil_pivot_0) + p_fil_pivots_num_1 += tl.sum(logits_blk > p_fil_pivot_1) + p_fil_pivots_num_2 += tl.sum(logits_blk > p_fil_pivot_2) - num_iters += 1 - if num_iters >= 32 or ( - tl.abs(k_min_range - k_max_range) < 1e-16 - and tl.abs(p_fil_min_range - p_fil_max_range) < 1e-16 - ): + # Check if 
any of the pivots are equal to k if k_pivot == -float("inf"): - k_pivot = k_pivot_0 + if k_pivots_num_0 == k: + k_pivot = k_pivot_0 + elif k_pivots_num_1 == k: + k_pivot = k_pivot_1 + elif k_pivots_num_2 == k: + k_pivot = k_pivot_2 + # If none of the pivots are equal to k, we update the range + elif k_pivots_num_2 > k: + k_min_range = k_pivot_2 + elif k_pivots_num_1 > k: + k_min_range = k_pivot_1 + elif k_pivots_num_0 > k: + k_min_range = k_pivot_0 + if k_pivots_num_0 < k: + k_max_range = k_pivot_0 + elif k_pivots_num_1 < k: + k_max_range = k_pivot_1 + elif k_pivots_num_2 < k: + k_max_range = k_pivot_2 + + # Check if any of the pivots are equal to P_FIL if p_fil_pivot == -float("inf"): - p_fil_pivot = p_fil_pivot_0 - - # Third pass: Calculate exp logits and sum with top-k mask - if not DO_TOP_K: - k_pivot = -float("inf") - - sum_exp_logits = tl.zeros((), dtype=tl.float32) - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) - - top_k_mask = logits_blk > k_pivot - logits_blk = tl.where(top_k_mask, logits_blk, -float("inf")) - - probs_blk = logits_blk - max_logit - probs_blk = tl.exp(probs_blk) - sum_exp_logits += tl.sum(probs_blk) - tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) - - # Fourth pass: Calculate softmax - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) - probs_blk = probs_blk / sum_exp_logits - tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) - - # Fifth pass : Gather filtered values with top-k mask - write_pos = tl.zeros((), dtype=tl.int32) - sum_probs = tl.zeros((), dtype=tl.float32) - FILTERED_LOGITS_ROW = FILTERED_LOGITS + row_id * (P_FIL + 1) - FILTERED_INDICES_ROW = FILTERED_INDICES + row_id * (P_FIL + 1) - FILTERED_PROBS_ROW = FILTERED_PROBS + row_id * (P_FIL + 1) - 
for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) + if p_fil_pivots_num_0 == P_FIL: + p_fil_pivot = p_fil_pivot_0 + elif p_fil_pivots_num_1 == P_FIL: + p_fil_pivot = p_fil_pivot_1 + elif p_fil_pivots_num_2 == P_FIL: + p_fil_pivot = p_fil_pivot_2 + # If none of the pivots are equal to P_FIL, we update the range + elif p_fil_pivots_num_2 > P_FIL: + p_fil_min_range = p_fil_pivot_2 + elif p_fil_pivots_num_1 > P_FIL: + p_fil_min_range = p_fil_pivot_1 + elif p_fil_pivots_num_0 > P_FIL: + p_fil_min_range = p_fil_pivot_0 + if p_fil_pivots_num_0 < P_FIL: + p_fil_max_range = p_fil_pivot_0 + elif p_fil_pivots_num_1 < P_FIL: + p_fil_max_range = p_fil_pivot_1 + elif p_fil_pivots_num_2 < P_FIL: + p_fil_max_range = p_fil_pivot_2 + + num_iters += 1 + if num_iters >= 32 or ( + tl.abs(k_min_range - k_max_range) < 1e-16 + and tl.abs(p_fil_min_range - p_fil_max_range) < 1e-16 + ): + if k_pivot == -float("inf"): + k_pivot = k_pivot_0 + if p_fil_pivot == -float("inf"): + p_fil_pivot = p_fil_pivot_0 + + # Third pass: Calculate exp logits and sum with top-k mask + if not DO_TOP_K or k == VOCAB_SIZE: + k_pivot = -float("inf") + + sum_exp_logits = tl.zeros((), dtype=tl.float32) + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk = \ + tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) - keep_mask = (logits_blk >= p_fil_pivot) & mask_n - cpos = tl.cumsum(keep_mask) - 1 + write_pos - final_mask = keep_mask & (cpos < P_FIL) - write_idx = tl.where(final_mask, cpos, P_FIL) + top_k_mask = logits_blk > k_pivot + logits_blk = tl.where(top_k_mask, logits_blk, -float("inf")) - top_k_mask = logits_blk > k_pivot - logits_blk = tl.where(top_k_mask, logits_blk, -float("inf")) + probs_blk = logits_blk - max_logit 
+ probs_blk = tl.exp(probs_blk) + sum_exp_logits += tl.sum(probs_blk) + tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) - # Gather filtered values - tl.store(FILTERED_LOGITS_ROW + write_idx, logits_blk, mask=final_mask) - tl.store(FILTERED_INDICES_ROW + write_idx, offs_n, mask=final_mask) - tl.store(FILTERED_PROBS_ROW + write_idx, probs_blk, mask=final_mask) + # Fourth pass: Calculate softmax + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) + probs_blk = probs_blk / sum_exp_logits + tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) + + # Fifth pass : Gather filtered values with top-k mask + write_pos = tl.zeros((), dtype=tl.int32) + sum_probs = tl.zeros((), dtype=tl.float32) + FILTERED_LOGITS_ROW = FILTERED_LOGITS + row_id * P_FIL + FILTERED_INDICES_ROW = FILTERED_INDICES + row_id * P_FIL + FILTERED_PROBS_ROW = FILTERED_PROBS + row_id * P_FIL + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk = \ + tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) + + keep_mask = (logits_blk > p_fil_pivot) & mask_n + cpos = tl.cumsum(keep_mask) - 1 + write_pos + f_mask = keep_mask & (cpos < P_FIL) + write_idx = tl.where(f_mask, cpos, 0) + + top_k_mask = logits_blk > k_pivot + logits_blk = tl.where(top_k_mask, logits_blk, -float("inf")) + + # Gather filtered values + tl.store(FILTERED_LOGITS_ROW + write_idx, logits_blk, mask=f_mask) + tl.store(FILTERED_INDICES_ROW + write_idx, offs_n, mask=f_mask) + tl.store(FILTERED_PROBS_ROW + write_idx, probs_blk, mask=f_mask) + + sum_probs += tl.sum(probs_blk * f_mask) + write_pos += tl.sum(f_mask, dtype=tl.int32) + tl.store(SUM_FILTERED_PROBS + row_id, sum_probs) + else: + tl.store(SUM_FILTERED_PROBS + row_id, 0.0) - sum_probs += tl.sum(probs_blk * 
final_mask) - write_pos += tl.sum(final_mask, dtype=tl.int32) - tl.store(SUM_FILTERED_PROBS + row_id, sum_probs) def apply_top_p_filtered( @@ -787,32 +801,33 @@ def apply_top_p_filtered( """ Applies top p using pivot based filtering """ - batch_size, vocab_size = logits.shape + max_k = k.max().item() if k is not None else 0 + if max_k > vocab_size / 10: + return apply_top_k_top_p(logits, k, p) + BLOCK_SIZE = 8192 device_prop = torch.cuda.get_device_properties(logits.device) NUM_PROGRAMS = device_prop.multi_processor_count # Persistent kernel + NUM_WARPS = 16 + NUM_STAGES = 3 buffer = torch.empty( (NUM_PROGRAMS, vocab_size), device=logits.device, dtype=torch.float32 ) - NUM_WARPS = 16 - NUM_STAGES = 3 - - p_filter = k.max().item() if k is not None else int(vocab_size * 1 / 32) + p_filter = int(max_k * 1.2) if k is not None else int(vocab_size / 32) filtered_logits = torch.full( - (batch_size, p_filter + 1), -float("inf"), device=logits.device + (batch_size, p_filter), -float("inf"), device=logits.device ) filtered_indices = torch.full( - (batch_size, p_filter + 1), p_filter, dtype=torch.int64, device=logits.device + (batch_size, p_filter), p_filter, dtype=torch.int64, device=logits.device ) filtered_probs = torch.full( - (batch_size, p_filter + 1), -float("inf"), device=logits.device + (batch_size, p_filter), -float("inf"), device=logits.device ) sum_filtered_probs = torch.zeros( (batch_size,), device=logits.device, dtype=torch.float32 ) - PERCENTILE_TO_STD_TABLE = torch.tensor( _PERCENTILE_TO_STD_TABLE, device=logits.device ) @@ -835,21 +850,16 @@ def apply_top_p_filtered( num_stages=NUM_STAGES, ) - filtered_indices = filtered_indices[:, :p_filter] - filtered_logits = filtered_logits[:, :p_filter] - filtered_probs = filtered_probs[:, :p_filter] + if torch.any(sum_filtered_probs < p): + return apply_top_k_top_p(logits, k, p) logits_sort, sort_indices = filtered_logits.sort(dim=-1, descending=False) logits_sort_indices = torch.gather(filtered_indices, -1, 
sort_indices) sorted_probs = torch.gather(filtered_probs, -1, sort_indices) - if torch.any(sum_filtered_probs < p): - return apply_top_k_top_p(logits, k, p) - - probs_sum = torch.cumsum(sorted_probs, dim=-1) sum_non_outliers = (1.0 - sum_filtered_probs).unsqueeze(-1) - probs_sum = probs_sum + sum_non_outliers - top_p_mask = probs_sum <= (1 - p.unsqueeze(dim=-1)) + probs_sum = torch.cumsum(sorted_probs, dim=-1) + sum_non_outliers + top_p_mask = probs_sum < (1 - p.unsqueeze(dim=-1)) top_p_mask[:, -1] = False logits_sort.masked_fill_(top_p_mask, -float("inf")) From 71cbb9efce9714c9b0a2286f0d2e4947bb43d326 Mon Sep 17 00:00:00 2001 From: js_park Date: Sat, 15 Nov 2025 19:26:24 -0800 Subject: [PATCH 55/99] Filtered logits are wrongs Signed-off-by: js_park --- compare.py | 2 +- vllm/v1/sample/ops/topk_topp_sampler.py | 87 +++++++++++++++---------- 2 files changed, 53 insertions(+), 36 deletions(-) diff --git a/compare.py b/compare.py index cde7e1efce84..fe16570a9f96 100644 --- a/compare.py +++ b/compare.py @@ -112,7 +112,7 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): if __name__ == "__main__": date_str = datetime.now().strftime("%Y%m%d_%H%M%S") - batch_size_list = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] + batch_size_list = [16, 32, 64, 128, 256, 512, 1024] vocab_size_list = [16384, 65536, 102400, 128256] p_list = [None, "RAND", 0.1, 0.4, 0.7, 0.9, 0.99] k_list = [None, "RAND", 5, 50, 200, 500, 3000] diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 6ad8cf1942ed..99fd9de08423 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -386,7 +386,8 @@ def apply_top_k_top_p_triton( elif p is None and k is not None: return apply_top_k_only_triton(logits, k) else: - return apply_top_p_filtered(logits, k, p) + logits_top_k = apply_top_k_only_triton(logits, k) + return apply_top_p_filtered(logits_top_k, logits, k, p) @triton.jit @@ -442,6 +443,7 @@ def 
_topk_triton_kernel( max_logit = tl.maximum(max_logit, tl.max(logits_blk)) min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + outlier_mask = (logits_blk > outlier_pivot) & mask_n num_blk_outliers = tl.sum(outlier_mask) cumulative_pos = tl.cast( @@ -529,6 +531,8 @@ def apply_top_k_only_triton( The logits tensor will be updated out-of-place. """ + if k is None: + return logits batch_size, vocab_size = logits.shape NUM_PROGRAMS = batch_size # Non-persistent kernel @@ -618,7 +622,7 @@ def top_k_top_p_filter( ) num_outliers += num_blk_outliers - write_idx = tl.where(outlier_mask, cumulative_pos, 0) + write_idx = tl.where(outlier_mask, cumulative_pos, -1) tl.store(BUFFER_ROW + write_idx, logits_blk, mask=outlier_mask) max_range = max_logit @@ -639,9 +643,9 @@ def top_k_top_p_filter( p_fil_min_range = min_range # Second passes: Quaternary search for pivots (nlog_4(n)) + num_iters = 0 k_pivot = -float("inf") p_fil_pivot = -float("inf") - num_iters = 0 while k_pivot == -float("inf") or p_fil_pivot == -float("inf"): k_pivot_0 = (k_max_range - k_min_range) * 1.0 / 4.0 + k_min_range k_pivot_1 = (k_max_range - k_min_range) * 2.0 / 4.0 + k_min_range @@ -686,19 +690,19 @@ def top_k_top_p_filter( k_pivot = k_pivot_1 elif k_pivots_num_2 == k: k_pivot = k_pivot_2 - # If none of the pivots are equal to k, we update the range - elif k_pivots_num_2 > k: - k_min_range = k_pivot_2 - elif k_pivots_num_1 > k: - k_min_range = k_pivot_1 - elif k_pivots_num_0 > k: - k_min_range = k_pivot_0 - if k_pivots_num_0 < k: - k_max_range = k_pivot_0 - elif k_pivots_num_1 < k: - k_max_range = k_pivot_1 - elif k_pivots_num_2 < k: - k_max_range = k_pivot_2 + # If none of the pivots are equal to k, we update the range + elif k_pivots_num_2 > k: + k_min_range = k_pivot_2 + elif k_pivots_num_1 > k: + k_min_range = k_pivot_1 + elif k_pivots_num_0 > k: + k_min_range = k_pivot_0 + if k_pivots_num_0 < k: + k_max_range = k_pivot_0 + elif k_pivots_num_1 < k: + k_max_range = k_pivot_1 + elif k_pivots_num_2 < 
k: + k_max_range = k_pivot_2 # Check if any of the pivots are equal to P_FIL if p_fil_pivot == -float("inf"): @@ -708,24 +712,24 @@ def top_k_top_p_filter( p_fil_pivot = p_fil_pivot_1 elif p_fil_pivots_num_2 == P_FIL: p_fil_pivot = p_fil_pivot_2 - # If none of the pivots are equal to P_FIL, we update the range - elif p_fil_pivots_num_2 > P_FIL: - p_fil_min_range = p_fil_pivot_2 - elif p_fil_pivots_num_1 > P_FIL: - p_fil_min_range = p_fil_pivot_1 - elif p_fil_pivots_num_0 > P_FIL: - p_fil_min_range = p_fil_pivot_0 - if p_fil_pivots_num_0 < P_FIL: - p_fil_max_range = p_fil_pivot_0 - elif p_fil_pivots_num_1 < P_FIL: - p_fil_max_range = p_fil_pivot_1 - elif p_fil_pivots_num_2 < P_FIL: - p_fil_max_range = p_fil_pivot_2 + # If none of the pivots are equal to P_FIL, we update the range + elif p_fil_pivots_num_2 > P_FIL: + p_fil_min_range = p_fil_pivot_2 + elif p_fil_pivots_num_1 > P_FIL: + p_fil_min_range = p_fil_pivot_1 + elif p_fil_pivots_num_0 > P_FIL: + p_fil_min_range = p_fil_pivot_0 + if p_fil_pivots_num_0 < P_FIL: + p_fil_max_range = p_fil_pivot_0 + elif p_fil_pivots_num_1 < P_FIL: + p_fil_max_range = p_fil_pivot_1 + elif p_fil_pivots_num_2 < P_FIL: + p_fil_max_range = p_fil_pivot_2 num_iters += 1 if num_iters >= 32 or ( - tl.abs(k_min_range - k_max_range) < 1e-16 - and tl.abs(p_fil_min_range - p_fil_max_range) < 1e-16 + (tl.abs(k_min_range - k_max_range) < 1e-16 and k_pivot != -float("inf")) + and (tl.abs(p_fil_min_range - p_fil_max_range) < 1e-16 and p_fil_pivot != -float("inf")) ): if k_pivot == -float("inf"): k_pivot = k_pivot_0 @@ -781,6 +785,7 @@ def top_k_top_p_filter( logits_blk = tl.where(top_k_mask, logits_blk, -float("inf")) # Gather filtered values + tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) tl.store(FILTERED_LOGITS_ROW + write_idx, logits_blk, mask=f_mask) tl.store(FILTERED_INDICES_ROW + write_idx, offs_n, mask=f_mask) tl.store(FILTERED_PROBS_ROW + write_idx, probs_blk, mask=f_mask) @@ -794,6 +799,7 @@ def top_k_top_p_filter( def 
apply_top_p_filtered( + logits_top_k: torch.Tensor, logits: torch.Tensor, k: torch.Tensor, p: torch.Tensor, @@ -805,6 +811,7 @@ def apply_top_p_filtered( max_k = k.max().item() if k is not None else 0 if max_k > vocab_size / 10: + print("Max k too large, falling back to top-k-top-p") return apply_top_k_top_p(logits, k, p) BLOCK_SIZE = 8192 @@ -849,13 +856,23 @@ def apply_top_p_filtered( num_warps=NUM_WARPS, num_stages=NUM_STAGES, ) - - if torch.any(sum_filtered_probs < p): - return apply_top_k_top_p(logits, k, p) + + print(f"filtered_logits: {filtered_logits.shape}") logits_sort, sort_indices = filtered_logits.sort(dim=-1, descending=False) logits_sort_indices = torch.gather(filtered_indices, -1, sort_indices) - sorted_probs = torch.gather(filtered_probs, -1, sort_indices) + + assert torch.allclose(logits_top_k, logits), "Top k Logits are not close" + + + logit_probs = torch.softmax(logits, dim=-1) + sorted_probs = torch.gather(logit_probs, -1, logits_sort_indices) + sum_filtered_probs = torch.sum(sorted_probs, dim=-1) + # sorted_probs = torch.gather(filtered_probs, -1, sort_indices) + + if torch.any(sum_filtered_probs < p): + print("Falling back to top-k-top-p") + return apply_top_k_top_p(logits, k, p) sum_non_outliers = (1.0 - sum_filtered_probs).unsqueeze(-1) probs_sum = torch.cumsum(sorted_probs, dim=-1) + sum_non_outliers From 8b0771ceb74e9ecf21d9f739c3e99178334e0c70 Mon Sep 17 00:00:00 2001 From: js_park Date: Sat, 15 Nov 2025 23:08:42 -0800 Subject: [PATCH 56/99] Filtered logits are wrongs Signed-off-by: js_park --- compare.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/compare.py b/compare.py index fe16570a9f96..37cd4d406723 100644 --- a/compare.py +++ b/compare.py @@ -58,7 +58,7 @@ def test_accuracy(logits, k, p, func_list, log_file): r_str("Error: logits are not close on " + f"{func_name}"), log_file, ) - output_logits = func_list[i](logits, k, p, debug=True) + output_logits = func_list[i](logits, k, p) error_mask = 
torch.abs(output_logits - original_logits) > 1e-16 error_rows = torch.where(error_mask)[0] error_rows = torch.unique(error_rows) @@ -114,7 +114,8 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): batch_size_list = [16, 32, 64, 128, 256, 512, 1024] vocab_size_list = [16384, 65536, 102400, 128256] - p_list = [None, "RAND", 0.1, 0.4, 0.7, 0.9, 0.99] + # p_list = [None, "RAND", 0.1, 0.4, 0.7, 0.9, 0.99] + p_list = [None for _ in range(10)] k_list = [None, "RAND", 5, 50, 200, 500, 3000] func_list = [apply_top_k_top_p, apply_top_k_top_p_triton] From 20806a2796ff8c4407c47361215591074360eda9 Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 16 Nov 2025 01:29:46 -0800 Subject: [PATCH 57/99] Floating point associativity errors remain Signed-off-by: js_park --- compare.py | 51 ++-- vllm/v1/sample/ops/topk_topp_sampler.py | 357 ++++++++++++------------ 2 files changed, 200 insertions(+), 208 deletions(-) diff --git a/compare.py b/compare.py index 37cd4d406723..62981dfe19ff 100644 --- a/compare.py +++ b/compare.py @@ -44,11 +44,12 @@ def test_accuracy(logits, k, p, func_list, log_file): output_logits = func_list[i](input_logit_list[i], k, p) torch.cuda.synchronize() - original_logits_bin = original_logits.view(torch.int32) - output_logits_bin = output_logits.view(torch.int32) - is_correct = torch.all(original_logits_bin == output_logits_bin) + is_correct = True + # original_logits_bin = original_logits.view(torch.int32) + # output_logits_bin = output_logits.view(torch.int32) + # is_correct = torch.all(original_logits_bin == output_logits_bin) is_correct = is_correct and torch.allclose( - output_logits, original_logits, atol=1e-16 + output_logits, original_logits ) output_correct_list.append(is_correct) func_name = func_list[i].__name__ @@ -59,7 +60,7 @@ def test_accuracy(logits, k, p, func_list, log_file): log_file, ) output_logits = func_list[i](logits, k, p) - error_mask = torch.abs(output_logits - original_logits) > 1e-16 + error_mask = 
torch.abs(output_logits - original_logits) > 1e-12 error_rows = torch.where(error_mask)[0] error_rows = torch.unique(error_rows) num_error_rows = error_rows.shape[0] @@ -114,8 +115,8 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): batch_size_list = [16, 32, 64, 128, 256, 512, 1024] vocab_size_list = [16384, 65536, 102400, 128256] - # p_list = [None, "RAND", 0.1, 0.4, 0.7, 0.9, 0.99] - p_list = [None for _ in range(10)] + p_list = [None, "RAND", 0.1, 0.4, 0.7, 0.9, 0.95, 0.99] + # k_list = [None for _ in range(100)] k_list = [None, "RAND", 5, 50, 200, 500, 3000] func_list = [apply_top_k_top_p, apply_top_k_top_p_triton] @@ -177,21 +178,21 @@ def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): ) correct_list = \ test_accuracy(logits, k_tensor, p_tensor, func_list, log_file) - # time_list = [] - # for func in func_list: - # time_taken = test_time(logits, k_tensor, p_tensor, test_func=func) - # time_list.append(time_taken) - # print_to_log(b_str("torch_time_taken: ") + f"{time_list[0]}", log_file) - # print_to_log(b_str("triton_time_taken: ") + f"{time_list[1]}", log_file) - # print_to_log( - # g_str("test Speedup over Torch: ") - # + f"{time_list[0] / time_list[1]:.8f}x", - # log_file, - # ) - # with open(csv_file, "a") as f: - # f.write( - # f"{dist_generator},{batch_size},{vocab_size},{p},{k}," - # f"{correct_list[0]},{time_list[0]},{time_list[1]}," - # f"{time_list[0] / time_list[1]:.8f}\n" - # ) - # print_to_log(y_str("--------------------------------\n"), log_file) + time_list = [] + for func in func_list: + time_taken = test_time(logits, k_tensor, p_tensor, test_func=func) + time_list.append(time_taken) + print_to_log(b_str("torch_time_taken: ") + f"{time_list[0]}", log_file) + print_to_log(b_str("triton_time_taken: ") + f"{time_list[1]}", log_file) + print_to_log( + g_str("test Speedup over Torch: ") + + f"{time_list[0] / time_list[1]:.8f}x", + log_file, + ) + with open(csv_file, "a") as f: + f.write( + 
f"{dist_generator},{batch_size},{vocab_size},{p},{k}," + f"{correct_list[0]},{time_list[0]},{time_list[1]}," + f"{time_list[0] / time_list[1]:.8f}\n" + ) + print_to_log(y_str("--------------------------------\n"), log_file) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 99fd9de08423..c086a36718ae 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -216,7 +216,6 @@ def apply_top_k_top_p( # Avoid sorting vocab for top-k only case. return apply_top_k_only(logits, k) - logits_sort, logits_idx = logits.sort(dim=-1, descending=False) if k is not None: @@ -386,8 +385,7 @@ def apply_top_k_top_p_triton( elif p is None and k is not None: return apply_top_k_only_triton(logits, k) else: - logits_top_k = apply_top_k_only_triton(logits, k) - return apply_top_p_filtered(logits_top_k, logits, k, p) + return apply_top_k_top_p_filtered(logits, k, p) @triton.jit @@ -563,10 +561,11 @@ def top_k_top_p_filter( LOGITS, DO_TOP_K, K, + P, P_FIL, BUFFER, BATCH_SIZE, - SUM_FILTERED_PROBS, + SUM_EXCLUDED_PROBS, FILTERED_LOGITS, FILTERED_INDICES, FILTERED_PROBS, @@ -625,181 +624,178 @@ def top_k_top_p_filter( write_idx = tl.where(outlier_mask, cumulative_pos, -1) tl.store(BUFFER_ROW + write_idx, logits_blk, mask=outlier_mask) - max_range = max_logit - min_range = min_logit + k_max_range = max_logit + k_min_range = min_logit + p_fil_max_range = max_logit + p_fil_min_range = min_logit + if num_outliers > P_FIL: - max_range = max_logit - min_range = outlier_pivot search_addr = BUFFER_ROW search_range = tl.cast(num_outliers, tl.int32) search_iters = tl.cast( (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32 ) - - k = tl.load(K + row_id) - k_max_range = max_range - k_min_range = min_range - p_fil_max_range = max_range - p_fil_min_range = min_range - - # Second passes: Quaternary search for pivots (nlog_4(n)) - num_iters = 0 - k_pivot = -float("inf") - p_fil_pivot = -float("inf") - while 
k_pivot == -float("inf") or p_fil_pivot == -float("inf"): - k_pivot_0 = (k_max_range - k_min_range) * 1.0 / 4.0 + k_min_range - k_pivot_1 = (k_max_range - k_min_range) * 2.0 / 4.0 + k_min_range - k_pivot_2 = (k_max_range - k_min_range) * 3.0 / 4.0 + k_min_range - k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) - k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) - k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - - p_fil_pivot_0 = ( - p_fil_max_range - p_fil_min_range - ) * 1.0 / 4.0 + p_fil_min_range - p_fil_pivot_1 = ( - p_fil_max_range - p_fil_min_range - ) * 2.0 / 4.0 + p_fil_min_range - p_fil_pivot_2 = ( - p_fil_max_range - p_fil_min_range - ) * 3.0 / 4.0 + p_fil_min_range - p_fil_pivots_num_0 = tl.zeros((), dtype=tl.uint32) - p_fil_pivots_num_1 = tl.zeros((), dtype=tl.uint32) - p_fil_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - logits_blk = tl.load( - search_addr + offs_n, mask=mask_n, other=-float("inf") - ) + k_min_range = outlier_pivot + p_fil_min_range = outlier_pivot + + k = tl.load(K + row_id) + + # Second passes: Quaternary search for pivots (nlog_4(n)) + num_iters = 0 + k_pivot = -float("inf") + p_fil_pivot = -float("inf") + while k_pivot == -float("inf") or p_fil_pivot == -float("inf"): + k_pivot_0 = (k_max_range - k_min_range) * 1.0 / 4.0 + k_min_range + k_pivot_1 = (k_max_range - k_min_range) * 2.0 / 4.0 + k_min_range + k_pivot_2 = (k_max_range - k_min_range) * 3.0 / 4.0 + k_min_range + k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) + k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) + k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) - k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) - k_pivots_num_2 += tl.sum(logits_blk > k_pivot_2) + p_fil_pivot_0 = ( + p_fil_max_range - p_fil_min_range + ) * 1.0 / 4.0 + p_fil_min_range + p_fil_pivot_1 = ( + p_fil_max_range - p_fil_min_range + ) * 2.0 
/ 4.0 + p_fil_min_range + p_fil_pivot_2 = ( + p_fil_max_range - p_fil_min_range + ) * 3.0 / 4.0 + p_fil_min_range + p_fil_pivots_num_0 = tl.zeros((), dtype=tl.uint32) + p_fil_pivots_num_1 = tl.zeros((), dtype=tl.uint32) + p_fil_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - p_fil_pivots_num_0 += tl.sum(logits_blk > p_fil_pivot_0) - p_fil_pivots_num_1 += tl.sum(logits_blk > p_fil_pivot_1) - p_fil_pivots_num_2 += tl.sum(logits_blk > p_fil_pivot_2) + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + logits_blk = tl.load( + search_addr + offs_n, mask=mask_n, other=-float("inf") + ) + + k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) + k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) + k_pivots_num_2 += tl.sum(logits_blk > k_pivot_2) + + p_fil_pivots_num_0 += tl.sum(logits_blk > p_fil_pivot_0) + p_fil_pivots_num_1 += tl.sum(logits_blk > p_fil_pivot_1) + p_fil_pivots_num_2 += tl.sum(logits_blk > p_fil_pivot_2) + + # Check if any of the pivots are equal to k + if k_pivot == -float("inf"): + if k_pivots_num_0 == k: + k_pivot = k_pivot_0 + elif k_pivots_num_1 == k: + k_pivot = k_pivot_1 + elif k_pivots_num_2 == k: + k_pivot = k_pivot_2 + # If none of the pivots are equal to k, we update the range + elif k_pivots_num_2 > k: + k_min_range = k_pivot_2 + elif k_pivots_num_1 > k: + k_min_range = k_pivot_1 + elif k_pivots_num_0 > k: + k_min_range = k_pivot_0 + if k_pivots_num_0 < k: + k_max_range = k_pivot_0 + elif k_pivots_num_1 < k: + k_max_range = k_pivot_1 + elif k_pivots_num_2 < k: + k_max_range = k_pivot_2 + + # Check if any of the pivots are equal to P_FIL + if p_fil_pivot == -float("inf"): + if p_fil_pivots_num_0 == P_FIL: + p_fil_pivot = p_fil_pivot_0 + elif p_fil_pivots_num_1 == P_FIL: + p_fil_pivot = p_fil_pivot_1 + elif p_fil_pivots_num_2 == P_FIL: + p_fil_pivot = p_fil_pivot_2 + # If none of the pivots are equal to P_FIL, we update the range + elif p_fil_pivots_num_2 > P_FIL: + p_fil_min_range = 
p_fil_pivot_2 + elif p_fil_pivots_num_1 > P_FIL: + p_fil_min_range = p_fil_pivot_1 + elif p_fil_pivots_num_0 > P_FIL: + p_fil_min_range = p_fil_pivot_0 + if p_fil_pivots_num_0 < P_FIL: + p_fil_max_range = p_fil_pivot_0 + elif p_fil_pivots_num_1 < P_FIL: + p_fil_max_range = p_fil_pivot_1 + elif p_fil_pivots_num_2 < P_FIL: + p_fil_max_range = p_fil_pivot_2 - # Check if any of the pivots are equal to k + num_iters += 1 + if num_iters >= 32 or ( + (tl.abs(k_min_range - k_max_range) < 1e-16 and k_pivot != -float("inf")) + and (tl.abs(p_fil_min_range - p_fil_max_range) < 1e-16 and p_fil_pivot != -float("inf")) + ): if k_pivot == -float("inf"): - if k_pivots_num_0 == k: - k_pivot = k_pivot_0 - elif k_pivots_num_1 == k: - k_pivot = k_pivot_1 - elif k_pivots_num_2 == k: - k_pivot = k_pivot_2 - # If none of the pivots are equal to k, we update the range - elif k_pivots_num_2 > k: - k_min_range = k_pivot_2 - elif k_pivots_num_1 > k: - k_min_range = k_pivot_1 - elif k_pivots_num_0 > k: - k_min_range = k_pivot_0 - if k_pivots_num_0 < k: - k_max_range = k_pivot_0 - elif k_pivots_num_1 < k: - k_max_range = k_pivot_1 - elif k_pivots_num_2 < k: - k_max_range = k_pivot_2 - - # Check if any of the pivots are equal to P_FIL + k_pivot = k_pivot_0 if p_fil_pivot == -float("inf"): - if p_fil_pivots_num_0 == P_FIL: - p_fil_pivot = p_fil_pivot_0 - elif p_fil_pivots_num_1 == P_FIL: - p_fil_pivot = p_fil_pivot_1 - elif p_fil_pivots_num_2 == P_FIL: - p_fil_pivot = p_fil_pivot_2 - # If none of the pivots are equal to P_FIL, we update the range - elif p_fil_pivots_num_2 > P_FIL: - p_fil_min_range = p_fil_pivot_2 - elif p_fil_pivots_num_1 > P_FIL: - p_fil_min_range = p_fil_pivot_1 - elif p_fil_pivots_num_0 > P_FIL: - p_fil_min_range = p_fil_pivot_0 - if p_fil_pivots_num_0 < P_FIL: - p_fil_max_range = p_fil_pivot_0 - elif p_fil_pivots_num_1 < P_FIL: - p_fil_max_range = p_fil_pivot_1 - elif p_fil_pivots_num_2 < P_FIL: - p_fil_max_range = p_fil_pivot_2 - - num_iters += 1 - if num_iters >= 32 or ( - 
(tl.abs(k_min_range - k_max_range) < 1e-16 and k_pivot != -float("inf")) - and (tl.abs(p_fil_min_range - p_fil_max_range) < 1e-16 and p_fil_pivot != -float("inf")) - ): - if k_pivot == -float("inf"): - k_pivot = k_pivot_0 - if p_fil_pivot == -float("inf"): - p_fil_pivot = p_fil_pivot_0 - - # Third pass: Calculate exp logits and sum with top-k mask - if not DO_TOP_K or k == VOCAB_SIZE: - k_pivot = -float("inf") - - sum_exp_logits = tl.zeros((), dtype=tl.float32) - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - logits_blk = \ - tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + p_fil_pivot = p_fil_pivot_0 - top_k_mask = logits_blk > k_pivot - logits_blk = tl.where(top_k_mask, logits_blk, -float("inf")) + # Third pass: Calculate exp logits and sum with top-k mask + if not DO_TOP_K or k == VOCAB_SIZE: + k_pivot = -float("inf") - probs_blk = logits_blk - max_logit - probs_blk = tl.exp(probs_blk) - sum_exp_logits += tl.sum(probs_blk) - tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) + sum_exp_logits = tl.zeros((), dtype=tl.float32) + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk = \ + tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) - # Fourth pass: Calculate softmax - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) - probs_blk = probs_blk / sum_exp_logits - tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) - - # Fifth pass : Gather filtered values with top-k mask - write_pos = tl.zeros((), dtype=tl.int32) - sum_probs = tl.zeros((), dtype=tl.float32) - FILTERED_LOGITS_ROW = FILTERED_LOGITS + row_id * P_FIL - FILTERED_INDICES_ROW = FILTERED_INDICES + row_id * P_FIL - FILTERED_PROBS_ROW = FILTERED_PROBS + row_id * P_FIL - for i in range(0, NUM_TILES): - 
offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - logits_blk = \ - tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) - - keep_mask = (logits_blk > p_fil_pivot) & mask_n - cpos = tl.cumsum(keep_mask) - 1 + write_pos - f_mask = keep_mask & (cpos < P_FIL) - write_idx = tl.where(f_mask, cpos, 0) - - top_k_mask = logits_blk > k_pivot - logits_blk = tl.where(top_k_mask, logits_blk, -float("inf")) - - # Gather filtered values - tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) - tl.store(FILTERED_LOGITS_ROW + write_idx, logits_blk, mask=f_mask) - tl.store(FILTERED_INDICES_ROW + write_idx, offs_n, mask=f_mask) - tl.store(FILTERED_PROBS_ROW + write_idx, probs_blk, mask=f_mask) - - sum_probs += tl.sum(probs_blk * f_mask) - write_pos += tl.sum(f_mask, dtype=tl.int32) - tl.store(SUM_FILTERED_PROBS + row_id, sum_probs) - else: - tl.store(SUM_FILTERED_PROBS + row_id, 0.0) + top_k_mask = logits_blk > k_pivot + logits_blk = tl.where(top_k_mask, logits_blk, -float("inf")) + probs_blk = logits_blk - max_logit + probs_blk = tl.exp(probs_blk) + sum_exp_logits += tl.sum(probs_blk) + tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) + # Fourth pass: Calculate softmax + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) + probs_blk = probs_blk / sum_exp_logits + tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) + + # Fifth pass : Gather filtered values with top-k mask + write_pos = tl.zeros((), dtype=tl.int32) + sum_excluded_probs = tl.zeros((), dtype=tl.float32) + FILTERED_LOGITS_ROW = FILTERED_LOGITS + row_id * P_FIL + FILTERED_INDICES_ROW = FILTERED_INDICES + row_id * P_FIL + FILTERED_PROBS_ROW = FILTERED_PROBS + row_id * P_FIL + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < 
VOCAB_SIZE + logits_blk = \ + tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) -def apply_top_p_filtered( - logits_top_k: torch.Tensor, + keep_mask = (logits_blk > p_fil_pivot) & mask_n + cpos = tl.cumsum(keep_mask) - 1 + write_pos + f_mask = keep_mask & (cpos < P_FIL) + write_idx = tl.where(f_mask, cpos, 0) + + top_k_mask = logits_blk > k_pivot + logits_blk = tl.where(top_k_mask, logits_blk, -float("inf")) + + # Gather filtered values + tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) + tl.store(FILTERED_LOGITS_ROW + write_idx, logits_blk, mask=f_mask) + tl.store(FILTERED_INDICES_ROW + write_idx, offs_n, mask=f_mask) + tl.store(FILTERED_PROBS_ROW + write_idx, probs_blk, mask=f_mask) + + sum_excluded_probs += tl.sum(probs_blk * ((~f_mask) & mask_n)) + write_pos += tl.sum(f_mask, dtype=tl.int32) + tl.store(SUM_EXCLUDED_PROBS + row_id, sum_excluded_probs) + + + + +def apply_top_k_top_p_filtered( logits: torch.Tensor, k: torch.Tensor, p: torch.Tensor, @@ -809,9 +805,13 @@ def apply_top_p_filtered( """ batch_size, vocab_size = logits.shape + # If k is too large, speedup is not significant as the filtered set is large. max_k = k.max().item() if k is not None else 0 - if max_k > vocab_size / 10: - print("Max k too large, falling back to top-k-top-p") + # Probabilty value is not guaranteed to be equivalent to the PyTorch implementation + # in the distribution tail due to floating point non-associativity. + # We avoid high p values to avoid this accuracy issue. 
+ max_p = p.max().item() if p is not None else 0 + if max_k > vocab_size / 10 or max_p > 0.95: return apply_top_k_top_p(logits, k, p) BLOCK_SIZE = 8192 @@ -832,7 +832,7 @@ def apply_top_p_filtered( filtered_probs = torch.full( (batch_size, p_filter), -float("inf"), device=logits.device ) - sum_filtered_probs = torch.zeros( + sum_excluded_probs = torch.zeros( (batch_size,), device=logits.device, dtype=torch.float32 ) PERCENTILE_TO_STD_TABLE = torch.tensor( @@ -843,10 +843,11 @@ def apply_top_p_filtered( logits, (k is not None), k if k is not None else filtered_indices, + p, p_filter, buffer, batch_size, - sum_filtered_probs, + sum_excluded_probs, filtered_logits, filtered_indices, filtered_probs, @@ -856,27 +857,17 @@ def apply_top_p_filtered( num_warps=NUM_WARPS, num_stages=NUM_STAGES, ) - - print(f"filtered_logits: {filtered_logits.shape}") + if torch.any(sum_excluded_probs >= p): + return apply_top_k_top_p(logits, k, p) + logits_sort, sort_indices = filtered_logits.sort(dim=-1, descending=False) logits_sort_indices = torch.gather(filtered_indices, -1, sort_indices) + sorted_probs = torch.gather(filtered_probs, -1, sort_indices) - assert torch.allclose(logits_top_k, logits), "Top k Logits are not close" - - - logit_probs = torch.softmax(logits, dim=-1) - sorted_probs = torch.gather(logit_probs, -1, logits_sort_indices) - sum_filtered_probs = torch.sum(sorted_probs, dim=-1) - # sorted_probs = torch.gather(filtered_probs, -1, sort_indices) - - if torch.any(sum_filtered_probs < p): - print("Falling back to top-k-top-p") - return apply_top_k_top_p(logits, k, p) - - sum_non_outliers = (1.0 - sum_filtered_probs).unsqueeze(-1) - probs_sum = torch.cumsum(sorted_probs, dim=-1) + sum_non_outliers - top_p_mask = probs_sum < (1 - p.unsqueeze(dim=-1)) + sorted_probs[:, 0] = sorted_probs[:, 0] + sum_excluded_probs + probs_sum = torch.cumsum(sorted_probs, dim=-1) + top_p_mask = probs_sum <= (1 - p.unsqueeze(dim=-1)) top_p_mask[:, -1] = False logits_sort.masked_fill_(top_p_mask, 
-float("inf")) From 89443c01b230384c6d7287d369797242af65fe3f Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 16 Nov 2025 01:48:14 -0800 Subject: [PATCH 58/99] Remove tester Signed-off-by: js_park --- compare.py | 198 ----------------------------------------------------- 1 file changed, 198 deletions(-) delete mode 100644 compare.py diff --git a/compare.py b/compare.py deleted file mode 100644 index 62981dfe19ff..000000000000 --- a/compare.py +++ /dev/null @@ -1,198 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from datetime import datetime -from itertools import product - -import regex as re -import torch - -from vllm.v1.sample.ops.topk_topp_sampler import ( - apply_top_k_top_p, - apply_top_k_top_p_triton, -) - - -def g_str(s): - return "\033[32m" + s + "\033[0m" - - -def r_str(s): - return "\033[31m" + s + "\033[0m" - - -def y_str(s): - return "\033[33m" + s + "\033[0m" - - -def b_str(s): - return "\033[34m" + s + "\033[0m" - - -def print_to_log(s, log_file): - print(s) - # Remove the color codes - s = re.sub(r"\033\[[0-9;]*m", "", s) - with open(log_file, "a") as f: - f.write(s + "\n") - - -def test_accuracy(logits, k, p, func_list, log_file): - input_logit_list = [logits.clone().detach() for i in range(len(func_list))] - original_logits = func_list[0](input_logit_list[0], k, p) - output_correct_list = [] - for i in range(1, len(func_list)): - output_logits = func_list[i](input_logit_list[i], k, p) - - torch.cuda.synchronize() - is_correct = True - # original_logits_bin = original_logits.view(torch.int32) - # output_logits_bin = output_logits.view(torch.int32) - # is_correct = torch.all(original_logits_bin == output_logits_bin) - is_correct = is_correct and torch.allclose( - output_logits, original_logits - ) - output_correct_list.append(is_correct) - func_name = func_list[i].__name__ - - if not is_correct: - print_to_log( - r_str("Error: logits are not close on " + f"{func_name}"), - 
log_file, - ) - output_logits = func_list[i](logits, k, p) - error_mask = torch.abs(output_logits - original_logits) > 1e-12 - error_rows = torch.where(error_mask)[0] - error_rows = torch.unique(error_rows) - num_error_rows = error_rows.shape[0] - error_cols = torch.where(error_mask)[1] - error_cols = torch.unique(error_cols) - num_error_cols = error_cols.shape[0] - print_to_log( - f"num_error_rows: {num_error_rows} - {error_rows}\n" + \ - f"num_error_cols: {num_error_cols} - {error_cols}", - log_file, - ) - row_to_show = 5 if num_error_rows > 5 else num_error_rows - logits_to_show = torch.sort( - output_logits[error_rows], descending=True - ).values - - logits_to_show = logits_to_show[:row_to_show, :20] - print_to_log(f"logits: {logits_to_show}", log_file) - original_logits_to_show = torch.sort( - original_logits[error_rows], descending=True - ).values - original_logits_to_show = original_logits_to_show[:row_to_show, :20] - print_to_log(f"original_logits: {original_logits_to_show}", log_file) - raise ValueError("Logits are not close") - return output_correct_list - - -def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): - # We must clone the logits for each run to avoid modifying the original - warmup_tensor = [logits.clone().detach() for _ in range(num_warmup)] - for _ in range(num_warmup): - test_func(warmup_tensor[_], k, p) - torch.cuda.synchronize() - - input_logits = [logits.clone().detach() for _ in range(num_runs)] - - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() - start.record() - for _ in range(num_runs): - input_logits[_] = test_func(input_logits[_], k, p) - end.record() - torch.cuda.synchronize() - time_taken = start.elapsed_time(end) / num_runs - - return time_taken - - -if __name__ == "__main__": - date_str = datetime.now().strftime("%Y%m%d_%H%M%S") - - batch_size_list = [16, 32, 64, 128, 256, 512, 1024] - vocab_size_list = [16384, 65536, 102400, 128256] - p_list = 
[None, "RAND", 0.1, 0.4, 0.7, 0.9, 0.95, 0.99] - # k_list = [None for _ in range(100)] - k_list = [None, "RAND", 5, 50, 200, 500, 3000] - func_list = [apply_top_k_top_p, apply_top_k_top_p_triton] - - log_file = f"triton_topk_topp_test_{date_str}.log" - csv_file = f"triton_topk_topp_test_{date_str}.csv" - - print_to_log(y_str("Testing TopKTopPSampler with Triton"), log_file) - print_to_log(y_str("batch_size_list:") + f"{batch_size_list}", log_file) - print_to_log(y_str("vocab_size_list:") + f"{vocab_size_list}", log_file) - print_to_log(y_str("p_list:") + f"{p_list}", log_file) - print_to_log(y_str("k_list:") + f"{k_list}", log_file) - - print_to_log(y_str("log_file:") + f"{log_file}", log_file) - print_to_log(y_str("csv_file:") + f"{csv_file}", log_file) - - with open(csv_file, "w") as f: - f.write( - "dist_generator,batch_size,vocab_size,p,k,triton_correct," - "torch_time_taken,triton_time_taken,triton_speedup\n" - ) - - for batch_size, vocab_size, p, k in product( - batch_size_list, vocab_size_list, p_list, k_list - ): - if p is None and k is None: - continue - - logits_randn = torch.randn(batch_size, vocab_size, device="cuda") * 10 - logits_list = [("RANDN", logits_randn)] - - if p == "RAND": - p_tensor = torch.rand((batch_size,), device="cuda") * 0.98 + 0.01 - elif p is not None: - p_tensor = torch.full((batch_size,), p, device="cuda") - else: - p_tensor = None - - if k == "RAND": - k_tensor = torch.randint(1, vocab_size, (batch_size,), device="cuda") - elif k is not None: - k_tensor = torch.full((batch_size,), k, device="cuda") - else: - k_tensor = None - - for dist_generator, logits in logits_list: - print_to_log(y_str("--------------------------------"), log_file) - print_to_log( - g_str("Testing ") - + f"{dist_generator}" - + y_str(" with batch_size: ") - + f"{batch_size}" - + y_str(" vocab_size: ") - + f"{vocab_size}" - + y_str(" p: ") - + f"{p}" - + y_str(" k: ") - + f"{k}", - log_file, - ) - correct_list = \ - test_accuracy(logits, k_tensor, p_tensor, 
func_list, log_file) - time_list = [] - for func in func_list: - time_taken = test_time(logits, k_tensor, p_tensor, test_func=func) - time_list.append(time_taken) - print_to_log(b_str("torch_time_taken: ") + f"{time_list[0]}", log_file) - print_to_log(b_str("triton_time_taken: ") + f"{time_list[1]}", log_file) - print_to_log( - g_str("test Speedup over Torch: ") - + f"{time_list[0] / time_list[1]:.8f}x", - log_file, - ) - with open(csv_file, "a") as f: - f.write( - f"{dist_generator},{batch_size},{vocab_size},{p},{k}," - f"{correct_list[0]},{time_list[0]},{time_list[1]}," - f"{time_list[0] / time_list[1]:.8f}\n" - ) - print_to_log(y_str("--------------------------------\n"), log_file) From 204c221f7253e770d99fe251016341bc332214fe Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 16 Nov 2025 02:55:32 -0800 Subject: [PATCH 59/99] Bugfix Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index ba0467c321ab..215aeda984fa 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -358,21 +358,19 @@ def apply_top_k_top_p_triton( logits: torch.Tensor, k: torch.Tensor | None, p: torch.Tensor | None, - debug: bool = False, ) -> torch.Tensor: """ Uses pivot-based algorithm to filter --> sort """ - # Fallback to torch for small batch sizes - if logits.shape[0] < 16: - return apply_top_k_top_p(logits, k, p) if k is None and p is None: return logits - elif p is None and k is not None: + if p is None and k is not None: return apply_top_k_only_triton(logits, k) - else: - return apply_top_k_top_p_filtered(logits, k, p) + # Fallback to torch for small batch sizes for top-p + if logits.shape[0] < 16 or logits.shape[1] < 32768: + return apply_top_k_top_p(logits, k, p) + return apply_top_k_top_p_filtered(logits, k, p) @triton.jit @@ -769,7 +767,6 @@ def 
top_k_top_p_filter( logits_blk = tl.where(top_k_mask, logits_blk, -float("inf")) # Gather filtered values - tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) tl.store(FILTERED_LOGITS_ROW + write_idx, logits_blk, mask=f_mask) tl.store(FILTERED_INDICES_ROW + write_idx, offs_n, mask=f_mask) tl.store(FILTERED_PROBS_ROW + write_idx, probs_blk, mask=f_mask) @@ -806,7 +803,9 @@ def apply_top_k_top_p_filtered( buffer = torch.empty( (NUM_PROGRAMS, vocab_size), device=logits.device, dtype=torch.float32 ) - p_filter = int(max_k * 1.2) if k is not None else int(vocab_size / 32) + p_filter = ( + min(int(max_k * 1.2), vocab_size - 1) if k is not None else int(vocab_size / 32) + ) filtered_logits = torch.full( (batch_size, p_filter), -float("inf"), device=logits.device ) From e262fcb56072d5e761a7ce4c142cb86d806c2411 Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 16 Nov 2025 03:02:16 -0800 Subject: [PATCH 60/99] Test file removed. Signed-off-by: js_park --- vllm/v1/sample/ops/test1.py | 15 --------------- 1 file changed, 15 deletions(-) delete mode 100644 vllm/v1/sample/ops/test1.py diff --git a/vllm/v1/sample/ops/test1.py b/vllm/v1/sample/ops/test1.py deleted file mode 100644 index fb78fb676159..000000000000 --- a/vllm/v1/sample/ops/test1.py +++ /dev/null @@ -1,15 +0,0 @@ -import torch - -# Create a randomly initialized 5x3 tensor -x = torch.rand(5, 3) -print("Random Tensor:\n", x) - -# Check if CUDA is available and print the result -cuda_available = torch.cuda.is_available() -print("\nCUDA available:", cuda_available) - -# If CUDA is available, you can also try moving a tensor to the GPU -if cuda_available: - device = torch.device("cuda") - y = torch.ones(2, 2, device=device) - print("\nTensor on GPU:\n", y) \ No newline at end of file From 5e6dc79f40bc7ebc9ea5922063544cdf1b923ab8 Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 16 Nov 2025 03:07:42 -0800 Subject: [PATCH 61/99] Typos Signed-off-by: js_park --- vllm/envs.py | 2 +- vllm/v1/sample/ops/topk_topp_sampler.py 
| 20 +++++++------------- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index d931f0ef2925..cb9f82c02c62 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -46,7 +46,7 @@ VLLM_TRACE_FUNCTION: int = 0 VLLM_ATTENTION_BACKEND: str | None = None VLLM_USE_FLASHINFER_SAMPLER: bool | None = None - VLLM_USE_TRITON_SAMPLER: bool = False + VLLM_USE_TRITON_SAMPLER: bool | None = None VLLM_PP_LAYER_PARTITION: str | None = None VLLM_CPU_KVCACHE_SPACE: int | None = 0 VLLM_ATTENTION_BACKEND: str | None = None diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 215aeda984fa..5ccd35d2be4d 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -50,6 +50,10 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None: if envs.VLLM_USE_TRITON_SAMPLER: logger.info_once("Using Triton for top-p & top-k sampling.") + if envs.VLLM_USE_FLASHINFER_SAMPLER: + logger.info_once( + "Overriding FlashInfer with Triton for top-p & top-k sampling." + ) self.forward = self.forward_triton else: logger.warning_once( @@ -231,16 +235,6 @@ def apply_top_k_only( The logits tensor may be updated in-place. """ - if k is None: - return logits - max_top_k = k.max().item() - - # --- FIX: Handle k=0 edge case --- - # If the max k is 0, all rows are 0. Mask everything and exit. - if max_top_k == 0: - logits.fill_(-float("inf")) - return logits - no_top_k_mask = k == logits.shape[1] # Set non-top-k rows to 1 so that we can gather. k = k.masked_fill(no_top_k_mask, 1) @@ -265,7 +259,7 @@ def random_sample( causes CPU-GPU synchronization. """ q = torch.empty_like(probs) - # VOCAB_SIZEOTE(woosuk): To batch-process the requests without their own seeds, + # NOTE(woosuk): To batch-process the requests without their own seeds, # which is the common case, we first assume that every request does # not have its own seed. 
Then, we overwrite the values for the requests # that have their own seeds. @@ -291,11 +285,11 @@ def flashinfer_sample( However, this function is faster because it avoids sorting the logits tensor via rejection sampling. - VOCAB_SIZEOTE: The outputs of this function do not necessarily match the outputs of + NOTE: The outputs of this function do not necessarily match the outputs of the `random_sample` function. It only guarantees that the outputs are statistically equivalent. - VOCAB_SIZEOTE: This function includes CPU-GPU synchronization, while `random_sample` + NOTE: This function includes CPU-GPU synchronization, while `random_sample` does not. Call this function at the end of the forward pass to minimize the synchronization overhead. """ From 091b5188c403dc14746fbd37341159e81fbd3ff9 Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 16 Nov 2025 03:08:43 -0800 Subject: [PATCH 62/99] Typos Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 5ccd35d2be4d..4d98e9f701e2 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -200,6 +200,7 @@ def apply_top_k_top_p( # Avoid sorting vocab for top-k only case. 
return apply_top_k_only(logits, k) + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) if k is not None: From 5fc986e244c312a04034842d57b8fa867ff26f64 Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 16 Nov 2025 03:09:39 -0800 Subject: [PATCH 63/99] Typos Signed-off-by: js_park --- vllm/envs.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index cb9f82c02c62..892400619181 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -49,9 +49,6 @@ VLLM_USE_TRITON_SAMPLER: bool | None = None VLLM_PP_LAYER_PARTITION: str | None = None VLLM_CPU_KVCACHE_SPACE: int | None = 0 - VLLM_ATTENTION_BACKEND: str | None = None - VLLM_PP_LAYER_PARTITION: str | None = None - VLLM_CPU_KVCACHE_SPACE: int | None = 0 VLLM_CPU_OMP_THREADS_BIND: str = "" VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None VLLM_CPU_MOE_PREPACK: bool = True From 02d446bb558555bc8a2990d0fac659a08379c3e2 Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 16 Nov 2025 03:11:27 -0800 Subject: [PATCH 64/99] Typos Signed-off-by: js_park --- vllm/envs.py | 2 +- vllm/v1/sample/ops/topk_topp_sampler.py | 13 ++++--------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 892400619181..0c5123b57afe 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -650,7 +650,7 @@ def get_vllm_port() -> int | None: ) if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None, - # If set, vllm will use triton sampler, will override flashinfer sampler. + # If set, vllm will use triton sampler. This will override the flashinfer sampler. 
"VLLM_USE_TRITON_SAMPLER": lambda: bool( int(os.environ.get("VLLM_USE_TRITON_SAMPLER", "0")) ) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 4d98e9f701e2..b307a1d7c4a7 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -49,18 +49,13 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None: self.forward = self.forward_native if envs.VLLM_USE_TRITON_SAMPLER: - logger.info_once("Using Triton for top-p & top-k sampling.") if envs.VLLM_USE_FLASHINFER_SAMPLER: logger.info_once( - "Overriding FlashInfer with Triton for top-p & top-k sampling." + "Overriding FlashInfer top-p & top-k sampling with " + "Triton top-p & top-k sampling." ) - self.forward = self.forward_triton - else: - logger.warning_once( - "Triton top-p/top-k sampling is available but disabled " - "by default. Set VLLM_USE_TRITON_SAMPLER=1 to opt in " - "after verifying accuracy for your workloads." - ) + else: + logger.info_once("Using Triton for top-p & top-k sampling.") self.forward = self.forward_native elif current_platform.is_cpu(): arch = current_platform.get_cpu_architecture() From d0f02f6c32b1093364e9e8feb0f05c88a553e64c Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 16 Nov 2025 03:12:17 -0800 Subject: [PATCH 65/99] Typos Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index b307a1d7c4a7..b02a1025c946 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -535,7 +535,6 @@ def top_k_top_p_filter( LOGITS, DO_TOP_K, K, - P, P_FIL, BUFFER, BATCH_SIZE, @@ -816,7 +815,6 @@ def apply_top_k_top_p_filtered( logits, (k is not None), k if k is not None else filtered_indices, - p, p_filter, buffer, batch_size, From 152bc320de7a73d2b373f24e28ad3f0e627cc2a8 Mon Sep 17 00:00:00 2001 From: 
js_park Date: Sun, 16 Nov 2025 16:01:35 -0800 Subject: [PATCH 66/99] Bugfixes Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index b02a1025c946..767886672459 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -401,7 +401,7 @@ def _topk_triton_kernel( sq_avg_logit = tl.sum(logits_blk * logits_blk) / num_valid std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - percentile = tl.cast(k * 2.0 / VOCAB_SIZE * 100, tl.uint32) + 1 + percentile = tl.cast(k * 1.6 / VOCAB_SIZE * 100 + 1, tl.uint32) percentile = tl.minimum(percentile, 99) sigma = tl.load(PERCENTILE_TO_STD_TABLE + percentile) outlier_pivot = avg_logit + sigma * std_logit @@ -489,7 +489,7 @@ def _topk_triton_kernel( offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n) - mask = logits_blk > k_pivot + mask = (logits_blk > k_pivot) & mask_n logits_blk = tl.where(mask, logits_blk, -float("inf")) tl.store(OUTPUT_ROW + offs_n, logits_blk, mask=mask_n) @@ -542,6 +542,7 @@ def top_k_top_p_filter( FILTERED_LOGITS, FILTERED_INDICES, FILTERED_PROBS, + NUM_FILTERED, PERCENTILE_TO_STD_TABLE, VOCAB_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr, @@ -570,7 +571,7 @@ def top_k_top_p_filter( sq_avg_logit = tl.sum(logits_blk * logits_blk) / num_mask std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - percentile = tl.cast(P_FIL * 2.0 / VOCAB_SIZE * 100 + 4, tl.uint32) + percentile = tl.cast(P_FIL * 1.6 / VOCAB_SIZE * 100 + 1, tl.uint32) percentile = tl.minimum(percentile, 99) sigma = tl.load(PERCENTILE_TO_STD_TABLE + percentile) outlier_pivot = avg_logit + sigma * std_logit @@ -752,7 +753,7 @@ def top_k_top_p_filter( f_mask = keep_mask & (cpos < P_FIL) write_idx = tl.where(f_mask, cpos, 0) - 
top_k_mask = logits_blk > k_pivot + top_k_mask = (logits_blk > k_pivot) & mask_n logits_blk = tl.where(top_k_mask, logits_blk, -float("inf")) # Gather filtered values @@ -763,6 +764,7 @@ def top_k_top_p_filter( sum_excluded_probs += tl.sum(probs_blk * ((~f_mask) & mask_n)) write_pos += tl.sum(f_mask, dtype=tl.int32) tl.store(SUM_EXCLUDED_PROBS + row_id, sum_excluded_probs) + tl.store(NUM_FILTERED + row_id, write_pos) def apply_top_k_top_p_filtered( @@ -777,11 +779,7 @@ def apply_top_k_top_p_filtered( # If k is too large, speedup is not significant as the filtered set is large. max_k = k.max().item() if k is not None else 0 - # Probability value is not guaranteed to be equivalent to the PyTorch implementation - # in the distribution tail due to floating point non-associativity. - # We avoid high p values to avoid this accuracy issue. - max_p = p.max().item() if p is not None else 0 - if max_k > vocab_size / 10 or max_p > 0.95: + if max_k > vocab_size / 10: return apply_top_k_top_p(logits, k, p) BLOCK_SIZE = 8192 @@ -807,6 +805,7 @@ def apply_top_k_top_p_filtered( sum_excluded_probs = torch.zeros( (batch_size,), device=logits.device, dtype=torch.float32 ) + num_filtered = torch.zeros((batch_size,), device=logits.device, dtype=torch.int32) PERCENTILE_TO_STD_TABLE = torch.tensor( _PERCENTILE_TO_STD_TABLE, device=logits.device ) @@ -822,6 +821,7 @@ def apply_top_k_top_p_filtered( filtered_logits, filtered_indices, filtered_probs, + num_filtered, PERCENTILE_TO_STD_TABLE, VOCAB_SIZE=vocab_size, BLOCK_SIZE=BLOCK_SIZE, @@ -829,7 +829,7 @@ def apply_top_k_top_p_filtered( num_stages=NUM_STAGES, ) - if torch.any(sum_excluded_probs >= p): + if torch.any(sum_excluded_probs >= p) or torch.any(num_filtered != p_filter): return apply_top_k_top_p(logits, k, p) logits_sort, sort_indices = filtered_logits.sort(dim=-1, descending=False) From db9859f51fb351b5c0e63d48419b1eb759dad31f Mon Sep 17 00:00:00 2001 From: js_park Date: Mon, 17 Nov 2025 16:53:52 -0800 Subject: [PATCH 67/99] 
Deduplication Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 165 ++++++++++++++++++++---- 1 file changed, 142 insertions(+), 23 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 767886672459..60a9380f2242 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -438,8 +438,10 @@ def _topk_triton_kernel( # Second passes: Quaternary search for pivots (nlog_4(n)) num_iters = 0 - k_pivot = -float("inf") - while k_pivot == -float("inf"): + k_pivot = float("inf") + if k == VOCAB_SIZE: + k_pivot = -float("inf") + while k_pivot == float("inf"): k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range @@ -532,6 +534,12 @@ def apply_top_k_only_triton( @triton.jit def top_k_top_p_filter( + DO_DEDUPLICATE, + CURRENT_NUM_DUPLICATES, + NUM_DUPLICATES, + MIN_LARGER_P_FIL_PIVOT, + FILTERED_LOGITS_NO_TOP_K, + PFIL_PIVOT, LOGITS, DO_TOP_K, K, @@ -615,9 +623,17 @@ def top_k_top_p_filter( # Second passes: Quaternary search for pivots (nlog_4(n)) num_iters = 0 - k_pivot = -float("inf") - p_fil_pivot = -float("inf") - while k_pivot == -float("inf") or p_fil_pivot == -float("inf"): + k_pivot = float("inf") + p_fil_pivot = float("inf") + # For duplicate pivot detection + min_larger_p_fil_pivot = float("inf") + num_min_larger_p_fil_pivot = tl.zeros((), dtype=tl.uint32) + do_deduplicate = tl.zeros((), dtype=tl.int32) + if k == VOCAB_SIZE: + k_pivot = -float("inf") + if P_FIL == VOCAB_SIZE: + p_fil_pivot = -float("inf") + while k_pivot == float("inf") or p_fil_pivot == float("inf"): k_pivot_0 = (k_max_range - k_min_range) * 1.0 / 4.0 + k_min_range k_pivot_1 = (k_max_range - k_min_range) * 2.0 / 4.0 + k_min_range k_pivot_2 = (k_max_range - k_min_range) * 3.0 / 4.0 + k_min_range @@ -653,8 +669,26 @@ def top_k_top_p_filter( p_fil_pivots_num_1 += 
tl.sum(logits_blk > p_fil_pivot_1) p_fil_pivots_num_2 += tl.sum(logits_blk > p_fil_pivot_2) + larger_p_fil_pivot = tl.where( + logits_blk > p_fil_pivot, logits_blk, -float("inf") + ) + min_larger_p_fil_pivot = tl.minimum( + min_larger_p_fil_pivot, tl.min(larger_p_fil_pivot) + ) + + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + logits_blk = tl.load( + search_addr + offs_n, mask=mask_n, other=-float("inf") + ) + min_larger_p_fil_pivot_mask = ( + tl.abs(logits_blk - min_larger_p_fil_pivot) < 1e-12 + ) & mask_n + num_min_larger_p_fil_pivot += tl.sum(min_larger_p_fil_pivot_mask) + # Check if any of the pivots are equal to k - if k_pivot == -float("inf"): + if k_pivot == float("inf"): if k_pivots_num_0 == k: k_pivot = k_pivot_0 elif k_pivots_num_1 == k: @@ -676,7 +710,7 @@ def top_k_top_p_filter( k_max_range = k_pivot_2 # Check if any of the pivots are equal to P_FIL - if p_fil_pivot == -float("inf"): + if p_fil_pivot == float("inf"): if p_fil_pivots_num_0 == P_FIL: p_fil_pivot = p_fil_pivot_0 elif p_fil_pivots_num_1 == P_FIL: @@ -685,6 +719,14 @@ def top_k_top_p_filter( p_fil_pivot = p_fil_pivot_2 # If none of the pivots are equal to P_FIL, we update the range elif p_fil_pivots_num_2 > P_FIL: + if p_fil_pivots_num_2 - num_min_larger_p_fil_pivot < P_FIL: + # Duplicate pivot detected + p_fil_pivot = p_fil_pivot_2 + # Number of duplicate pivots to keep in the filtered set + num_min_larger_p_fil_pivot = num_min_larger_p_fil_pivot - ( + p_fil_pivots_num_2 - P_FIL + ) + do_deduplicate = 1 p_fil_min_range = p_fil_pivot_2 elif p_fil_pivots_num_1 > P_FIL: p_fil_min_range = p_fil_pivot_1 @@ -699,15 +741,15 @@ def top_k_top_p_filter( num_iters += 1 if num_iters >= 32 or ( - (tl.abs(k_min_range - k_max_range) < 1e-16 and k_pivot != -float("inf")) + (tl.abs(k_min_range - k_max_range) < 1e-16 and k_pivot != float("inf")) and ( tl.abs(p_fil_min_range - p_fil_max_range) < 1e-16 - and p_fil_pivot != -float("inf") 
+ and p_fil_pivot != float("inf") ) ): - if k_pivot == -float("inf"): + if k_pivot == float("inf"): k_pivot = k_pivot_0 - if p_fil_pivot == -float("inf"): + if p_fil_pivot == float("inf"): p_fil_pivot = p_fil_pivot_0 # Third pass: Calculate exp logits and sum with top-k mask @@ -739,19 +781,38 @@ def top_k_top_p_filter( # Fifth pass : Gather filtered values with top-k mask write_pos = tl.zeros((), dtype=tl.int32) sum_excluded_probs = tl.zeros((), dtype=tl.float32) - FILTERED_LOGITS_ROW = FILTERED_LOGITS + row_id * P_FIL - FILTERED_INDICES_ROW = FILTERED_INDICES + row_id * P_FIL - FILTERED_PROBS_ROW = FILTERED_PROBS + row_id * P_FIL + current_num_duplicates = tl.zeros((), dtype=tl.uint32) + FILTERED_LOGITS_ROW = FILTERED_LOGITS + row_id * (P_FIL + 5) + FILTERED_INDICES_ROW = FILTERED_INDICES + row_id * (P_FIL + 5) + FILTERED_PROBS_ROW = FILTERED_PROBS + row_id * (P_FIL + 5) for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) - keep_mask = (logits_blk > p_fil_pivot) & mask_n + keep_mask = logits_blk > p_fil_pivot + if do_deduplicate > 0: + duplicate_mask = ( + tl.abs(logits_blk - min_larger_p_fil_pivot) < 1e-12 + ) & mask_n + duplicate_count = tl.cumsum(duplicate_mask) + current_num_duplicates + duplicate_mask = duplicate_mask & ( + duplicate_count <= num_min_larger_p_fil_pivot + ) + keep_mask = keep_mask & duplicate_mask + current_num_duplicates += tl.sum(duplicate_mask) + keep_mask = keep_mask & mask_n cpos = tl.cumsum(keep_mask) - 1 + write_pos - f_mask = keep_mask & (cpos < P_FIL) - write_idx = tl.where(f_mask, cpos, 0) + f_mask = keep_mask + write_idx = tl.where(f_mask, cpos, P_FIL) + + FILTERED_LOGITS_NO_TOP_K_ROW = FILTERED_LOGITS_NO_TOP_K + row_id * ( + P_FIL + 5 + ) + tl.store( + FILTERED_LOGITS_NO_TOP_K_ROW + write_idx, logits_blk, mask=keep_mask + ) top_k_mask 
= (logits_blk > k_pivot) & mask_n logits_blk = tl.where(top_k_mask, logits_blk, -float("inf")) @@ -761,10 +822,15 @@ def top_k_top_p_filter( tl.store(FILTERED_INDICES_ROW + write_idx, offs_n, mask=f_mask) tl.store(FILTERED_PROBS_ROW + write_idx, probs_blk, mask=f_mask) - sum_excluded_probs += tl.sum(probs_blk * ((~f_mask) & mask_n)) + sum_excluded_probs += tl.sum(probs_blk * (keep_mask & (~f_mask) & mask_n)) write_pos += tl.sum(f_mask, dtype=tl.int32) + tl.store(PFIL_PIVOT + row_id, p_fil_pivot) tl.store(SUM_EXCLUDED_PROBS + row_id, sum_excluded_probs) tl.store(NUM_FILTERED + row_id, write_pos) + tl.store(CURRENT_NUM_DUPLICATES + row_id, current_num_duplicates) + tl.store(NUM_DUPLICATES + row_id, num_min_larger_p_fil_pivot) + tl.store(MIN_LARGER_P_FIL_PIVOT + row_id, min_larger_p_fil_pivot) + tl.store(DO_DEDUPLICATE + row_id, do_deduplicate) def apply_top_k_top_p_filtered( @@ -791,26 +857,50 @@ def apply_top_k_top_p_filtered( (NUM_PROGRAMS, vocab_size), device=logits.device, dtype=torch.float32 ) p_filter = ( - min(int(max_k * 1.2), vocab_size - 1) if k is not None else int(vocab_size / 32) + # vocab_size + min(int(max_k * 1.5), vocab_size - 1) if k is not None else int(vocab_size / 20) ) filtered_logits = torch.full( - (batch_size, p_filter), -float("inf"), device=logits.device + (batch_size, p_filter + 5), -float("inf"), device=logits.device ) + filtered_logits_no_top_k = torch.full( + (batch_size, p_filter + 5), -float("inf"), device=logits.device + ) + filtered_indices = torch.full( - (batch_size, p_filter), p_filter, dtype=torch.int64, device=logits.device + (batch_size, p_filter + 5), p_filter, dtype=torch.int64, device=logits.device ) filtered_probs = torch.full( - (batch_size, p_filter), -float("inf"), device=logits.device + (batch_size, p_filter + 5), -float("inf"), device=logits.device ) sum_excluded_probs = torch.zeros( (batch_size,), device=logits.device, dtype=torch.float32 ) + num_duplicates = torch.zeros( + (batch_size,), device=logits.device, 
dtype=torch.uint32 + ) + min_larger_p_fil_pivot = torch.zeros( + (batch_size,), device=logits.device, dtype=torch.float32 + ) + current_num_duplicates = torch.zeros( + (batch_size,), device=logits.device, dtype=torch.uint32 + ) + do_deduplicate = torch.zeros( + (batch_size,), device=logits.device, dtype=torch.uint32 + ) num_filtered = torch.zeros((batch_size,), device=logits.device, dtype=torch.int32) + pfil_pivot = torch.zeros((batch_size,), device=logits.device, dtype=torch.float32) PERCENTILE_TO_STD_TABLE = torch.tensor( _PERCENTILE_TO_STD_TABLE, device=logits.device ) top_k_top_p_filter[(NUM_PROGRAMS,)]( + do_deduplicate, + current_num_duplicates, + num_duplicates, + min_larger_p_fil_pivot, + filtered_logits_no_top_k, + pfil_pivot, logits, (k is not None), k if k is not None else filtered_indices, @@ -829,19 +919,48 @@ def apply_top_k_top_p_filtered( num_stages=NUM_STAGES, ) - if torch.any(sum_excluded_probs >= p) or torch.any(num_filtered != p_filter): + print(f"p {p}") + print(f"do_deduplicate {do_deduplicate}") + print(f"current_num_duplicates {current_num_duplicates}") + print(f"num_duplicates {num_duplicates}") + print(f"min_larger_p_fil_pivot {min_larger_p_fil_pivot}") + print(f"num_filtered {num_filtered}") + print(f"pfil_pivot {pfil_pivot}") + print(f"Filtered logits no top k {filtered_logits_no_top_k}") + print(f"Filtered logits {filtered_logits}") + print(f"Filtered indices {filtered_indices}") + print(f"Filtered probs {filtered_probs}") + print(f"Sum excluded probs {sum_excluded_probs}") + + filtered_logits = filtered_logits[:, :p_filter] + filtered_indices = filtered_indices[:, :p_filter] + filtered_probs = filtered_probs[:, :p_filter] + + if torch.any(num_filtered != p_filter): + print(f"num_filtered != p_filter: {num_filtered} != {p_filter}") + + if torch.any(sum_excluded_probs >= p): return apply_top_k_top_p(logits, k, p) logits_sort, sort_indices = filtered_logits.sort(dim=-1, descending=False) logits_sort_indices = 
torch.gather(filtered_indices, -1, sort_indices) sorted_probs = torch.gather(filtered_probs, -1, sort_indices) + print("logits_sort", logits_sort) + print("sorted_probs", sorted_probs) + print("logits_sort_indices", logits_sort_indices) + sorted_probs[:, 0] = sorted_probs[:, 0] + sum_excluded_probs probs_sum = torch.cumsum(sorted_probs, dim=-1) + print("probs_sum", probs_sum) top_p_mask = probs_sum <= (1 - p.unsqueeze(dim=-1)) + print("threashold", 1 - p.unsqueeze(dim=-1)) top_p_mask[:, -1] = False + print("top_p_mask", top_p_mask) logits_sort.masked_fill_(top_p_mask, -float("inf")) + print("logits_sort_masked", logits_sort) + logits.fill_(-float("inf")) logits.scatter_(dim=1, index=logits_sort_indices, src=logits_sort) return logits From b936c94d329bc145e6dc3707e3e2786cdec3d079 Mon Sep 17 00:00:00 2001 From: js_park Date: Mon, 17 Nov 2025 17:27:38 -0800 Subject: [PATCH 68/99] Duplication search bugfix Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 60a9380f2242..0d935bc40e57 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -654,6 +654,9 @@ def top_k_top_p_filter( p_fil_pivots_num_1 = tl.zeros((), dtype=tl.uint32) p_fil_pivots_num_2 = tl.zeros((), dtype=tl.uint32) + min_larger_p_fil_pivot = float("inf") + num_min_larger_p_fil_pivot = tl.zeros((), dtype=tl.uint32) + for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < search_range @@ -669,11 +672,11 @@ def top_k_top_p_filter( p_fil_pivots_num_1 += tl.sum(logits_blk > p_fil_pivot_1) p_fil_pivots_num_2 += tl.sum(logits_blk > p_fil_pivot_2) - larger_p_fil_pivot = tl.where( - logits_blk > p_fil_pivot, logits_blk, -float("inf") + larger_p_fil_pivot_2 = tl.where( + logits_blk > p_fil_pivot_2, logits_blk, float("inf") ) min_larger_p_fil_pivot = 
tl.minimum( - min_larger_p_fil_pivot, tl.min(larger_p_fil_pivot) + min_larger_p_fil_pivot, tl.min(larger_p_fil_pivot_2) ) for i in range(0, search_iters): From 3784e60333d9a474bbdb776dcb5dda11a20c6661 Mon Sep 17 00:00:00 2001 From: js_park Date: Mon, 17 Nov 2025 20:33:16 -0800 Subject: [PATCH 69/99] Bugfixes Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 116 ++++++++++++++++++------ 1 file changed, 90 insertions(+), 26 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 0d935bc40e57..531762223fde 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -627,8 +627,15 @@ def top_k_top_p_filter( p_fil_pivot = float("inf") # For duplicate pivot detection min_larger_p_fil_pivot = float("inf") - num_min_larger_p_fil_pivot = tl.zeros((), dtype=tl.uint32) + num_deduplicate_to_keep = tl.zeros((), dtype=tl.uint32) do_deduplicate = tl.zeros((), dtype=tl.int32) + num_min_larger_p_fil_pivot = tl.zeros((), dtype=tl.uint32) + min_larger_p_fil_pivot_2 = float("inf") + num_min_larger_p_fil_pivot_2 = tl.zeros((), dtype=tl.uint32) + min_larger_p_fil_pivot_1 = float("inf") + num_min_larger_p_fil_pivot_1 = tl.zeros((), dtype=tl.uint32) + min_larger_p_fil_pivot_0 = float("inf") + num_min_larger_p_fil_pivot_0 = tl.zeros((), dtype=tl.uint32) if k == VOCAB_SIZE: k_pivot = -float("inf") if P_FIL == VOCAB_SIZE: @@ -654,8 +661,13 @@ def top_k_top_p_filter( p_fil_pivots_num_1 = tl.zeros((), dtype=tl.uint32) p_fil_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - min_larger_p_fil_pivot = float("inf") - num_min_larger_p_fil_pivot = tl.zeros((), dtype=tl.uint32) + if p_fil_pivot == float("inf"): + min_larger_p_fil_pivot_2 = float("inf") + num_min_larger_p_fil_pivot_2 = tl.zeros((), dtype=tl.uint32) + min_larger_p_fil_pivot_1 = float("inf") + num_min_larger_p_fil_pivot_1 = tl.zeros((), dtype=tl.uint32) + min_larger_p_fil_pivot_0 = float("inf") + 
num_min_larger_p_fil_pivot_0 = tl.zeros((), dtype=tl.uint32) for i in range(0, search_iters): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) @@ -672,23 +684,49 @@ def top_k_top_p_filter( p_fil_pivots_num_1 += tl.sum(logits_blk > p_fil_pivot_1) p_fil_pivots_num_2 += tl.sum(logits_blk > p_fil_pivot_2) - larger_p_fil_pivot_2 = tl.where( - logits_blk > p_fil_pivot_2, logits_blk, float("inf") - ) - min_larger_p_fil_pivot = tl.minimum( - min_larger_p_fil_pivot, tl.min(larger_p_fil_pivot_2) - ) + if p_fil_pivot == float("inf"): + larger_p_fil_pivot = tl.where( + (logits_blk > p_fil_pivot_2) & mask_n, logits_blk, float("inf") + ) + min_larger_p_fil_pivot_2 = tl.minimum( + min_larger_p_fil_pivot_2, tl.min(larger_p_fil_pivot) + ) - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - logits_blk = tl.load( - search_addr + offs_n, mask=mask_n, other=-float("inf") - ) - min_larger_p_fil_pivot_mask = ( - tl.abs(logits_blk - min_larger_p_fil_pivot) < 1e-12 - ) & mask_n - num_min_larger_p_fil_pivot += tl.sum(min_larger_p_fil_pivot_mask) + larger_p_fil_pivot = tl.where( + (logits_blk > p_fil_pivot_1) & mask_n, logits_blk, float("inf") + ) + min_larger_p_fil_pivot_1 = tl.minimum( + min_larger_p_fil_pivot_1, tl.min(larger_p_fil_pivot) + ) + + larger_p_fil_pivot = tl.where( + (logits_blk > p_fil_pivot_0) & mask_n, logits_blk, float("inf") + ) + min_larger_p_fil_pivot_0 = tl.minimum( + min_larger_p_fil_pivot_0, tl.min(larger_p_fil_pivot) + ) + + if p_fil_pivot == float("inf"): + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < search_range + logits_blk = tl.load( + search_addr + offs_n, mask=mask_n, other=-float("inf") + ) + min_larger_p_fil_pivot_mask = ( + tl.abs(logits_blk - min_larger_p_fil_pivot_2) < 1e-12 + ) & mask_n + num_min_larger_p_fil_pivot_2 += tl.sum(min_larger_p_fil_pivot_mask) + + min_larger_p_fil_pivot_mask = ( + tl.abs(logits_blk - 
min_larger_p_fil_pivot_1) < 1e-12 + ) & mask_n + num_min_larger_p_fil_pivot_1 += tl.sum(min_larger_p_fil_pivot_mask) + + min_larger_p_fil_pivot_mask = ( + tl.abs(logits_blk - min_larger_p_fil_pivot_0) < 1e-12 + ) & mask_n + num_min_larger_p_fil_pivot_0 += tl.sum(min_larger_p_fil_pivot_mask) # Check if any of the pivots are equal to k if k_pivot == float("inf"): @@ -722,19 +760,41 @@ def top_k_top_p_filter( p_fil_pivot = p_fil_pivot_2 # If none of the pivots are equal to P_FIL, we update the range elif p_fil_pivots_num_2 > P_FIL: - if p_fil_pivots_num_2 - num_min_larger_p_fil_pivot < P_FIL: + if p_fil_pivots_num_2 - num_min_larger_p_fil_pivot_2 < P_FIL: # Duplicate pivot detected p_fil_pivot = p_fil_pivot_2 # Number of duplicate pivots to keep in the filtered set - num_min_larger_p_fil_pivot = num_min_larger_p_fil_pivot - ( + num_deduplicate_to_keep = num_min_larger_p_fil_pivot_2 - ( p_fil_pivots_num_2 - P_FIL ) + min_larger_p_fil_pivot = min_larger_p_fil_pivot_2 + num_min_larger_p_fil_pivot = num_min_larger_p_fil_pivot_2 do_deduplicate = 1 p_fil_min_range = p_fil_pivot_2 elif p_fil_pivots_num_1 > P_FIL: p_fil_min_range = p_fil_pivot_1 + if p_fil_pivots_num_1 - num_min_larger_p_fil_pivot_1 < P_FIL: + # Duplicate pivot detected + p_fil_pivot = p_fil_pivot_1 + # Number of duplicate pivots to keep in the filtered set + num_deduplicate_to_keep = num_min_larger_p_fil_pivot_1 - ( + p_fil_pivots_num_1 - P_FIL + ) + min_larger_p_fil_pivot = min_larger_p_fil_pivot_1 + num_min_larger_p_fil_pivot = num_min_larger_p_fil_pivot_1 + do_deduplicate = 1 elif p_fil_pivots_num_0 > P_FIL: p_fil_min_range = p_fil_pivot_0 + if p_fil_pivots_num_0 - num_min_larger_p_fil_pivot_0 < P_FIL: + # Duplicate pivot detected + p_fil_pivot = p_fil_pivot_0 + # Number of duplicate pivots to keep in the filtered set + num_deduplicate_to_keep = num_min_larger_p_fil_pivot_0 - ( + p_fil_pivots_num_0 - P_FIL + ) + min_larger_p_fil_pivot = min_larger_p_fil_pivot_0 + num_min_larger_p_fil_pivot = 
num_min_larger_p_fil_pivot_0 + do_deduplicate = 1 if p_fil_pivots_num_0 < P_FIL: p_fil_max_range = p_fil_pivot_0 elif p_fil_pivots_num_1 < P_FIL: @@ -799,12 +859,16 @@ def top_k_top_p_filter( duplicate_mask = ( tl.abs(logits_blk - min_larger_p_fil_pivot) < 1e-12 ) & mask_n + duplicate_count = tl.cumsum(duplicate_mask) + current_num_duplicates - duplicate_mask = duplicate_mask & ( - duplicate_count <= num_min_larger_p_fil_pivot + duplicate_remove_mask = duplicate_mask & ( + duplicate_count > num_deduplicate_to_keep ) - keep_mask = keep_mask & duplicate_mask - current_num_duplicates += tl.sum(duplicate_mask) + keep_mask = keep_mask & (~duplicate_remove_mask) + current_num_duplicates += tl.sum( + duplicate_mask & (~duplicate_remove_mask) + ) + keep_mask = keep_mask & mask_n cpos = tl.cumsum(keep_mask) - 1 + write_pos f_mask = keep_mask @@ -861,7 +925,7 @@ def apply_top_k_top_p_filtered( ) p_filter = ( # vocab_size - min(int(max_k * 1.5), vocab_size - 1) if k is not None else int(vocab_size / 20) + min(int(max_k * 1.5), vocab_size - 1) if k is not None else int(vocab_size / 32) ) filtered_logits = torch.full( (batch_size, p_filter + 5), -float("inf"), device=logits.device From b0b6253c429d5d36fde2941562a6ddd42f86616e Mon Sep 17 00:00:00 2001 From: js_park Date: Tue, 18 Nov 2025 00:30:06 -0800 Subject: [PATCH 70/99] PyTorch sort permutes the order of duplicate values when sorting. When two logits have exactly same value, it is not guaranteed which one will be included in the final top-p if they fall on between the top-p threshold. 
Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 38 +++++++++---------------- 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 531762223fde..bd98d24ce9ea 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -535,7 +535,7 @@ def apply_top_k_only_triton( @triton.jit def top_k_top_p_filter( DO_DEDUPLICATE, - CURRENT_NUM_DUPLICATES, + NUM_DUPLICATES_REMOVED, NUM_DUPLICATES, MIN_LARGER_P_FIL_PIVOT, FILTERED_LOGITS_NO_TOP_K, @@ -627,7 +627,7 @@ def top_k_top_p_filter( p_fil_pivot = float("inf") # For duplicate pivot detection min_larger_p_fil_pivot = float("inf") - num_deduplicate_to_keep = tl.zeros((), dtype=tl.uint32) + num_duplicate_to_remove = tl.zeros((), dtype=tl.uint32) do_deduplicate = tl.zeros((), dtype=tl.int32) num_min_larger_p_fil_pivot = tl.zeros((), dtype=tl.uint32) min_larger_p_fil_pivot_2 = float("inf") @@ -764,9 +764,7 @@ def top_k_top_p_filter( # Duplicate pivot detected p_fil_pivot = p_fil_pivot_2 # Number of duplicate pivots to keep in the filtered set - num_deduplicate_to_keep = num_min_larger_p_fil_pivot_2 - ( - p_fil_pivots_num_2 - P_FIL - ) + num_duplicate_to_remove = p_fil_pivots_num_2 - P_FIL min_larger_p_fil_pivot = min_larger_p_fil_pivot_2 num_min_larger_p_fil_pivot = num_min_larger_p_fil_pivot_2 do_deduplicate = 1 @@ -777,9 +775,7 @@ def top_k_top_p_filter( # Duplicate pivot detected p_fil_pivot = p_fil_pivot_1 # Number of duplicate pivots to keep in the filtered set - num_deduplicate_to_keep = num_min_larger_p_fil_pivot_1 - ( - p_fil_pivots_num_1 - P_FIL - ) + num_duplicate_to_remove = p_fil_pivots_num_1 - P_FIL min_larger_p_fil_pivot = min_larger_p_fil_pivot_1 num_min_larger_p_fil_pivot = num_min_larger_p_fil_pivot_1 do_deduplicate = 1 @@ -789,9 +785,7 @@ def top_k_top_p_filter( # Duplicate pivot detected p_fil_pivot = p_fil_pivot_0 # Number of duplicate pivots to keep in 
the filtered set - num_deduplicate_to_keep = num_min_larger_p_fil_pivot_0 - ( - p_fil_pivots_num_0 - P_FIL - ) + num_duplicate_to_remove = p_fil_pivots_num_0 - P_FIL min_larger_p_fil_pivot = min_larger_p_fil_pivot_0 num_min_larger_p_fil_pivot = num_min_larger_p_fil_pivot_0 do_deduplicate = 1 @@ -844,7 +838,7 @@ def top_k_top_p_filter( # Fifth pass : Gather filtered values with top-k mask write_pos = tl.zeros((), dtype=tl.int32) sum_excluded_probs = tl.zeros((), dtype=tl.float32) - current_num_duplicates = tl.zeros((), dtype=tl.uint32) + num_duplicates_removed = tl.zeros((), dtype=tl.uint32) FILTERED_LOGITS_ROW = FILTERED_LOGITS + row_id * (P_FIL + 5) FILTERED_INDICES_ROW = FILTERED_INDICES + row_id * (P_FIL + 5) FILTERED_PROBS_ROW = FILTERED_PROBS + row_id * (P_FIL + 5) @@ -860,14 +854,12 @@ def top_k_top_p_filter( tl.abs(logits_blk - min_larger_p_fil_pivot) < 1e-12 ) & mask_n - duplicate_count = tl.cumsum(duplicate_mask) + current_num_duplicates + duplicate_count = tl.cumsum(duplicate_mask) + num_duplicates_removed duplicate_remove_mask = duplicate_mask & ( - duplicate_count > num_deduplicate_to_keep + duplicate_count <= num_duplicate_to_remove ) keep_mask = keep_mask & (~duplicate_remove_mask) - current_num_duplicates += tl.sum( - duplicate_mask & (~duplicate_remove_mask) - ) + num_duplicates_removed += tl.sum(duplicate_remove_mask) keep_mask = keep_mask & mask_n cpos = tl.cumsum(keep_mask) - 1 + write_pos @@ -894,7 +886,7 @@ def top_k_top_p_filter( tl.store(PFIL_PIVOT + row_id, p_fil_pivot) tl.store(SUM_EXCLUDED_PROBS + row_id, sum_excluded_probs) tl.store(NUM_FILTERED + row_id, write_pos) - tl.store(CURRENT_NUM_DUPLICATES + row_id, current_num_duplicates) + tl.store(NUM_DUPLICATES_REMOVED + row_id, num_duplicates_removed) tl.store(NUM_DUPLICATES + row_id, num_min_larger_p_fil_pivot) tl.store(MIN_LARGER_P_FIL_PIVOT + row_id, min_larger_p_fil_pivot) tl.store(DO_DEDUPLICATE + row_id, do_deduplicate) @@ -937,9 +929,7 @@ def apply_top_k_top_p_filtered( 
filtered_indices = torch.full( (batch_size, p_filter + 5), p_filter, dtype=torch.int64, device=logits.device ) - filtered_probs = torch.full( - (batch_size, p_filter + 5), -float("inf"), device=logits.device - ) + filtered_probs = torch.full((batch_size, p_filter + 5), 0.0, device=logits.device) sum_excluded_probs = torch.zeros( (batch_size,), device=logits.device, dtype=torch.float32 ) @@ -949,7 +939,7 @@ def apply_top_k_top_p_filtered( min_larger_p_fil_pivot = torch.zeros( (batch_size,), device=logits.device, dtype=torch.float32 ) - current_num_duplicates = torch.zeros( + num_duplicates_removed = torch.zeros( (batch_size,), device=logits.device, dtype=torch.uint32 ) do_deduplicate = torch.zeros( @@ -963,7 +953,7 @@ def apply_top_k_top_p_filtered( top_k_top_p_filter[(NUM_PROGRAMS,)]( do_deduplicate, - current_num_duplicates, + num_duplicates_removed, num_duplicates, min_larger_p_fil_pivot, filtered_logits_no_top_k, @@ -988,7 +978,7 @@ def apply_top_k_top_p_filtered( print(f"p {p}") print(f"do_deduplicate {do_deduplicate}") - print(f"current_num_duplicates {current_num_duplicates}") + print(f"num_duplicates_removed {num_duplicates_removed}") print(f"num_duplicates {num_duplicates}") print(f"min_larger_p_fil_pivot {min_larger_p_fil_pivot}") print(f"num_filtered {num_filtered}") From cd98ab906875457efc99cc1490041f2ec0cfe2ef Mon Sep 17 00:00:00 2001 From: js_park Date: Tue, 18 Nov 2025 15:30:36 -0800 Subject: [PATCH 71/99] Original pytorch implemntation apply softmax after sorting, which produce different probabilities to applying softmax before sorting, due to floating point non-associativity. 
Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 55 +++++++++++++++---------- 1 file changed, 34 insertions(+), 21 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index bd98d24ce9ea..ad5a210b3e0b 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -208,13 +208,17 @@ def apply_top_k_top_p( if p is not None: # Apply top-p. + # Note: Running softmax on "logits_sort" produces different probability + # values compared to running softmax on the original unsorted logits as the + # non-associativity of floating-points yields different sum(exp(logits)). probs_sort = logits_sort.softmax(dim=-1) probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort) + print(f"original probs_sum {probs_sum[:, -100:]}") top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) # at least one top_p_mask[:, -1] = False logits_sort.masked_fill_(top_p_mask, -float("inf")) - + print(f"original logits_sort {logits_sort[:, -100:]}") # Re-sort the probabilities. logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) return logits @@ -904,7 +908,14 @@ def apply_top_k_top_p_filtered( # If k is too large, speedup is not significant as the filtered set is large. max_k = k.max().item() if k is not None else 0 - if max_k > vocab_size / 10: + # Our softmax result is different from the original PyTorch top-p implementation , + # as it runs softmax after a sort which produces different sum(exp(logits)) + # compared to our softmax result which runs softmax on the original unsorted logits. + # If p is too large, the top-p cutoff falls in the tail section of the distribution, + # which consists of very small probabilities which has larger relative errors + # compared to the sorted PyTorch top-p probabilities. As such, we fallback to + # the original PyTorch top-p implementation for accuracy when p is too large. 
+ if max_k > vocab_size / 10 or (k is None and p.max().item() > 0.97): return apply_top_k_top_p(logits, k, p) BLOCK_SIZE = 8192 @@ -916,7 +927,6 @@ def apply_top_k_top_p_filtered( (NUM_PROGRAMS, vocab_size), device=logits.device, dtype=torch.float32 ) p_filter = ( - # vocab_size min(int(max_k * 1.5), vocab_size - 1) if k is not None else int(vocab_size / 32) ) filtered_logits = torch.full( @@ -976,18 +986,19 @@ def apply_top_k_top_p_filtered( num_stages=NUM_STAGES, ) - print(f"p {p}") - print(f"do_deduplicate {do_deduplicate}") - print(f"num_duplicates_removed {num_duplicates_removed}") - print(f"num_duplicates {num_duplicates}") - print(f"min_larger_p_fil_pivot {min_larger_p_fil_pivot}") - print(f"num_filtered {num_filtered}") - print(f"pfil_pivot {pfil_pivot}") - print(f"Filtered logits no top k {filtered_logits_no_top_k}") - print(f"Filtered logits {filtered_logits}") - print(f"Filtered indices {filtered_indices}") - print(f"Filtered probs {filtered_probs}") - print(f"Sum excluded probs {sum_excluded_probs}") + # print(f"p {p}") + # print(f"p_filter {p_filter}") + # print(f"do_deduplicate {do_deduplicate}") + # print(f"num_duplicates_removed {num_duplicates_removed}") + # print(f"num_duplicates {num_duplicates}") + # print(f"min_larger_p_fil_pivot {min_larger_p_fil_pivot}") + # print(f"num_filtered {num_filtered}") + # print(f"pfil_pivot {pfil_pivot}") + # print(f"Filtered logits no top k {filtered_logits_no_top_k}") + # print(f"Filtered logits {filtered_logits}") + # print(f"Filtered indices {filtered_indices}") + # print(f"Filtered probs {filtered_probs}") + # print(f"Sum excluded probs {sum_excluded_probs}") filtered_logits = filtered_logits[:, :p_filter] filtered_indices = filtered_indices[:, :p_filter] @@ -1003,20 +1014,22 @@ def apply_top_k_top_p_filtered( logits_sort_indices = torch.gather(filtered_indices, -1, sort_indices) sorted_probs = torch.gather(filtered_probs, -1, sort_indices) - print("logits_sort", logits_sort) - print("sorted_probs", 
sorted_probs) - print("logits_sort_indices", logits_sort_indices) + torch.set_printoptions(threshold=float("inf")) + print("logits_sort", logits_sort[:, -100:]) + print("sorted_probs", sorted_probs[:, -100:]) + print("logits_sort_indices", logits_sort_indices[:, -100:]) + torch.set_printoptions(threshold=None) sorted_probs[:, 0] = sorted_probs[:, 0] + sum_excluded_probs probs_sum = torch.cumsum(sorted_probs, dim=-1) - print("probs_sum", probs_sum) + print("probs_sum", probs_sum[:, -100:]) top_p_mask = probs_sum <= (1 - p.unsqueeze(dim=-1)) print("threashold", 1 - p.unsqueeze(dim=-1)) top_p_mask[:, -1] = False - print("top_p_mask", top_p_mask) + print("top_p_mask", top_p_mask[:, -100:]) logits_sort.masked_fill_(top_p_mask, -float("inf")) - print("logits_sort_masked", logits_sort) + print("logits_sort_masked", logits_sort[:, -100:]) logits.fill_(-float("inf")) logits.scatter_(dim=1, index=logits_sort_indices, src=logits_sort) From b72e2076ae06feddfcc49a86403632a344e6ebe6 Mon Sep 17 00:00:00 2001 From: js_park Date: Tue, 18 Nov 2025 16:01:07 -0800 Subject: [PATCH 72/99] Helper scripts Signed-off-by: js_park --- compare.py | 321 ++++++++++++++++++++++++ graph.py | 232 +++++++++++++++++ vllm/v1/sample/ops/topk_topp_sampler.py | 105 ++------ 3 files changed, 567 insertions(+), 91 deletions(-) create mode 100644 compare.py create mode 100644 graph.py diff --git a/compare.py b/compare.py new file mode 100644 index 000000000000..b9d1e694c957 --- /dev/null +++ b/compare.py @@ -0,0 +1,321 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from datetime import datetime +from itertools import product + +import regex as re +import torch + +from vllm.v1.sample.ops.topk_topp_sampler import ( + apply_top_k_top_p, + apply_top_k_top_p_triton, +) + + +def g_str(s): + return "\033[32m" + s + "\033[0m" + + +def r_str(s): + return "\033[31m" + s + "\033[0m" + + +def y_str(s): + return "\033[33m" + s + "\033[0m" + + +def 
b_str(s): + return "\033[34m" + s + "\033[0m" + + +def print_to_log(s, log_file): + print(s) + # Remove the color codes + s = re.sub(r"\033\[[0-9;]*m", "", s) + with open(log_file, "a") as f: + f.write(s + "\n") + + +def test_accuracy(logits, k, p, func_list, log_file): + input_logit_list = [logits.clone().detach() for i in range(len(func_list))] + original_logits = func_list[0](input_logit_list[0], k, p) + output_correct_list = [] + for i in range(1, len(func_list)): + output_logits = func_list[i](input_logit_list[i], k, p) + + torch.cuda.synchronize() + is_correct = True + # original_logits_bin = original_logits.view(torch.int32) + # output_logits_bin = output_logits.view(torch.int32) + # is_correct = torch.all(original_logits_bin == output_logits_bin) + # is_correct = is_correct and torch.allclose( + # output_logits, original_logits + # ) + output_logits_sorted = torch.sort(output_logits, descending=True).values + original_logits_sorted = torch.sort(original_logits, descending=True).values + is_correct = is_correct and torch.allclose( + output_logits_sorted, original_logits_sorted + ) + output_correct_list.append(is_correct) + func_name = func_list[i].__name__ + + if not is_correct: + print_to_log( + r_str("Error: logits are not close on " + f"{func_name}"), + log_file, + ) + + # Check for NaN values first + output_has_nan = torch.isnan(output_logits).any().item() + original_has_nan = torch.isnan(original_logits).any().item() + output_nan_count = torch.isnan(output_logits).sum().item() + original_nan_count = torch.isnan(original_logits).sum().item() + + print_to_log( + "NaN check:\n" + + f" output_logits has NaN: {output_has_nan} (count: {output_nan_count})\n" + + f" original_logits has NaN: {original_has_nan} (count: {original_nan_count})\n" + + " Note: torch.allclose returns False if either tensor contains NaN (unless equal_nan=True)", + log_file, + ) + + if output_has_nan or original_has_nan: + # Show where NaN values are + if output_has_nan: + 
output_nan_positions = torch.where(torch.isnan(output_logits)) + print_to_log( + f" output_logits NaN positions (first 10): " + f"{list(zip(output_nan_positions[0][:10].tolist(), output_nan_positions[1][:10].tolist()))}", + log_file, + ) + if original_has_nan: + original_nan_positions = torch.where(torch.isnan(original_logits)) + print_to_log( + f" original_logits NaN positions (first 10): " + f"{list(zip(original_nan_positions[0][:10].tolist(), original_nan_positions[1][:10].tolist()))}", + log_file, + ) + + error = torch.abs(output_logits - original_logits) + # Handle NaN in error computation + error_has_nan = torch.isnan(error).any().item() + if error_has_nan: + error_nan_count = torch.isnan(error).sum().item() + print_to_log( + f" error tensor has NaN: True (count: {error_nan_count})", + log_file, + ) + # Use masked operations for NaN handling (compatible with all PyTorch versions) + valid_error = error[~torch.isnan(error)] + if valid_error.numel() > 0: + max_error = torch.max(valid_error).item() + mean_error = torch.mean(valid_error).item() + else: + max_error = float("nan") + mean_error = float("nan") + else: + max_error = torch.max(error).item() + mean_error = torch.mean(error).item() + + # Use the same tolerance as torch.allclose (rtol=1e-05, atol=1e-08) + atol = 1e-08 + rtol = 1e-05 + # torch.allclose checks: |input - other| <= atol + rtol * |other| + # Exclude NaN from tolerance check + valid_mask = ~torch.isnan(original_logits) & ~torch.isnan(output_logits) + tolerance = atol + rtol * torch.abs(original_logits) + error_mask = (error > tolerance) & valid_mask + + print_to_log( + f"Max absolute error: {max_error:.2e}\n" + + f"Mean absolute error: {mean_error:.2e}\n" + + f"torch.allclose tolerance: rtol={rtol}, atol={atol}", + log_file, + ) + + error_rows = torch.where(error_mask)[0] + error_rows = torch.unique(error_rows) + num_error_rows = error_rows.shape[0] + error_cols = torch.where(error_mask)[1] + error_cols = torch.unique(error_cols) + 
num_error_cols = error_cols.shape[0] + print_to_log( + f"num_error_rows: {num_error_rows} - {error_rows}\n" + + f"num_error_cols: {num_error_cols} - {error_cols}", + log_file, + ) + + if num_error_rows > 0: + row_to_show = 5 if num_error_rows > 5 else num_error_rows + logits_to_show = torch.sort( + output_logits[error_rows], descending=True + ).values + + logits_to_show = logits_to_show[:row_to_show, :50] + print_to_log(f"logits: {logits_to_show}", log_file) + original_logits_to_show = torch.sort( + original_logits[error_rows], descending=True + ).values + original_logits_to_show = original_logits_to_show[:row_to_show, :50] + print_to_log(f"original_logits: {original_logits_to_show}", log_file) + error_to_show = error[error_rows][:row_to_show, :50] + print_to_log(f"error (abs diff): {error_to_show}", log_file) + else: + # If no errors found with the mask, show the largest errors anyway + print_to_log( + "No errors found with tolerance mask, showing top errors:", log_file + ) + # Handle NaN in topk - replace NaN with -inf so they're not selected + error_for_topk = error.clone() + error_for_topk[torch.isnan(error_for_topk)] = float("-inf") + top_errors, top_indices = torch.topk( + error_for_topk.flatten(), min(20, error.numel()) + ) + print_to_log(f"Top 20 absolute errors: {top_errors}", log_file) + for idx, err_val in zip(top_indices, top_errors): + row_idx = idx.item() // error.shape[1] + col_idx = idx.item() % error.shape[1] + output_val = output_logits[row_idx, col_idx].item() + original_val = original_logits[row_idx, col_idx].item() + err_val_item = err_val.item() + # Check if values are NaN + output_str = ( + f"{output_val:.10f}" + if not torch.isnan(output_logits[row_idx, col_idx]) + else "NaN" + ) + original_str = ( + f"{original_val:.10f}" + if not torch.isnan(original_logits[row_idx, col_idx]) + else "NaN" + ) + error_str = ( + f"{err_val_item:.2e}" + if not torch.isnan(error[row_idx, col_idx]) + else "NaN" + ) + print_to_log( + f" Position [{row_idx}, 
{col_idx}]: " + f"output={output_str}, " + f"original={original_str}, " + f"error={error_str}", + log_file, + ) + # raise ValueError("Logits are not close") + return output_correct_list + + +def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): + # We must clone the logits for each run to avoid modifying the original + warmup_tensor = [logits.clone().detach() for _ in range(num_warmup)] + for _ in range(num_warmup): + test_func(warmup_tensor[_], k, p) + torch.cuda.synchronize() + + input_logits = [logits.clone().detach() for _ in range(num_runs)] + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start.record() + for _ in range(num_runs): + input_logits[_] = test_func(input_logits[_], k, p) + end.record() + torch.cuda.synchronize() + time_taken = start.elapsed_time(end) / num_runs + + return time_taken + + +if __name__ == "__main__": + date_str = datetime.now().strftime("%Y%m%d_%H%M%S") + + batch_size_list = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] + vocab_size_list = [32768, 65536, 102400, 128256] + p_list = [None, "RAND", 0.1, 0.4, 0.7, 0.9, 0.95, 0.99] + k_list = [None, "RAND", 5, 20, 50, 200, 500, 3000] + func_list = [apply_top_k_top_p, apply_top_k_top_p_triton] + + log_file = f"triton_topk_topp_test_{date_str}.log" + csv_file = f"triton_topk_topp_test_{date_str}.csv" + + print_to_log(y_str("Testing TopKTopPSampler with Triton"), log_file) + print_to_log(y_str("batch_size_list:") + f"{batch_size_list}", log_file) + print_to_log(y_str("vocab_size_list:") + f"{vocab_size_list}", log_file) + print_to_log(y_str("p_list:") + f"{p_list}", log_file) + print_to_log(y_str("k_list:") + f"{k_list}", log_file) + + print_to_log(y_str("log_file:") + f"{log_file}", log_file) + print_to_log(y_str("csv_file:") + f"{csv_file}", log_file) + + with open(csv_file, "w") as f: + f.write( + "dist_generator,batch_size,vocab_size,p,k,triton_correct," + 
"torch_time_taken,triton_time_taken,triton_speedup\n" + ) + + for batch_size, vocab_size, p, k in product( + batch_size_list, vocab_size_list, p_list, k_list + ): + if p is None and k is None: + continue + + logits_randn = torch.randn(batch_size, vocab_size, device="cuda") * 10 + top_5_logits = torch.topk(logits_randn, 5, dim=-1).values + + logits_list = [("RANDN", logits_randn)] + + if p == "RAND": + p_tensor = torch.rand((batch_size,), device="cuda") * 0.98 + 0.01 + elif p is not None: + p_tensor = torch.full((batch_size,), p, device="cuda") + else: + p_tensor = None + + if k == "RAND": + k_tensor = torch.randint( + 1, int(vocab_size / 4) - 1, (batch_size,), device="cuda" + ) + elif k is not None: + k_tensor = torch.full((batch_size,), k, device="cuda") + else: + k_tensor = None + + for dist_generator, logits in logits_list: + print_to_log(y_str("--------------------------------"), log_file) + print_to_log( + g_str("Testing ") + + f"{dist_generator}" + + y_str(" with batch_size: ") + + f"{batch_size}" + + y_str(" vocab_size: ") + + f"{vocab_size}" + + y_str(" p: ") + + f"{p}" + + y_str(" k: ") + + f"{k}", + log_file, + ) + correct_list = test_accuracy( + logits, k_tensor, p_tensor, func_list, log_file + ) + time_list = [] + for func in func_list: + time_taken = test_time(logits, k_tensor, p_tensor, test_func=func) + time_list.append(time_taken) + print_to_log(b_str("torch_time_taken: ") + f"{time_list[0]}", log_file) + print_to_log(b_str("triton_time_taken: ") + f"{time_list[1]}", log_file) + print_to_log( + g_str("test Speedup over Torch: ") + + f"{time_list[0] / time_list[1]:.8f}x", + log_file, + ) + with open(csv_file, "a") as f: + p_str = "NONE" if p is None else str(p) + k_str = "NONE" if k is None else str(k) + f.write( + f"{dist_generator},{batch_size},{vocab_size},{p_str},{k_str}," + f"{correct_list[0]},{time_list[0]},{time_list[1]}," + f"{time_list[0] / time_list[1]:.8f}\n" + ) + print_to_log(y_str("--------------------------------\n"), log_file) diff 
--git a/graph.py b/graph.py new file mode 100644 index 000000000000..3c1039293fbf --- /dev/null +++ b/graph.py @@ -0,0 +1,232 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import matplotlib.pyplot as plt +import pandas as pd + +input_path = "triton_topk_topp_test_wo_fallback.csv" +output_name = "speedup_analysis_wo_fallback.png" + + +def load_and_parse_data(csv_file): + """Load CSV data and parse it into a structured format.""" + df = pd.read_csv(csv_file, dtype={"p": str, "k": str}) + print(df.head()) + print(df.columns) + print(df.info()) + print(df.describe()) + print(df.isnull().sum()) + print(df.duplicated().sum()) + print(df.shape) + print(df.head()) + return df + + +def get_filtered_data(df, vocab_size, p_val, k_val): + """Filter data for specific vocab_size, p, and k values.""" + # Handle None values properly + if p_val is None: + p_condition = df["p"] == "NONE" + else: + p_condition = df["p"] == str(p_val) + + if k_val is None: + k_condition = df["k"] == "NONE" + else: + k_condition = df["k"] == str(k_val) + + filtered_df = df[ + (df["vocab_size"] == vocab_size) & p_condition & k_condition + ].copy() + + return filtered_df.sort_values("batch_size") + + +def create_speedup_plots(column_configs, vocab_sizes): + """Create 4x4 grid of speedup vs batch size plots.""" + # Load data + csv_file = input_path + df = load_and_parse_data(csv_file) + + # We'll calculate y-axis limits per subplot now + + # Create figure with subplots + fig, axes = plt.subplots(4, 4, figsize=(20, 16)) + fig.suptitle("Speedup vs Batch Size", fontsize=20, fontweight="bold") + + # Plot each combination + for row, vocab_size in enumerate(vocab_sizes): + for col, config in enumerate(column_configs): + ax = axes[row, col] + + # Get filtered data + data = get_filtered_data(df, vocab_size, config["p"], config["k"]) + + if not data.empty: + # Calculate y-axis limit for this specific subplot + local_max_speedup = 
data["triton_speedup"].max() + local_y_max = local_max_speedup * 1.1 if local_max_speedup > 0 else 10.0 + + # Plot speedup vs batch size + ax.plot( + data["batch_size"], + data["triton_speedup"], + "bo-", + linewidth=2, + markersize=6, + ) + ax.set_xscale("log", base=2) + ax.set_ylim(0.0, local_y_max) # Set y-axis range from 0 to local max + ax.grid(True, alpha=0.3) + + # Add horizontal line at speedup=1 + ax.axhline( + y=1, + color="red", + linestyle="--", + linewidth=2, + alpha=0.7, + label="Speedup=1", + ) + + # Set labels and title + if row == 3: # Bottom row + ax.set_xlabel("Batch Size", fontsize=12) + if col == 0: # Left column + ax.set_ylabel("Speedup", fontsize=12) + + # Set title for top row + if row == 0: + ax.set_title(config["title"], fontsize=14, fontweight="bold") + + # Add vocab size label on the left + if col == 0: + vocab_size_str = f"Vocab Size {vocab_size}" + + ax.text( + -0.2, + 0.5, + vocab_size_str, + transform=ax.transAxes, + fontsize=14, + fontweight="bold", + ha="center", + va="center", + rotation=90, + ) + + # Set reasonable axis limits + if len(data) > 0: + batch_sizes = data["batch_size"].values + + ax.set_xlim(batch_sizes.min() * 0.8, batch_sizes.max() * 1.2) + # Y-axis is already set to 0-10 above + + # Format x-axis ticks + ax.set_xticks([1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]) + ax.set_xticklabels([1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, ""]) + + # Add legend only to the first subplot + if row == 0 and col == 0: + ax.legend(loc="upper left") + + else: + # No data available - use default y-axis range + default_y_max = 10.0 + ax.text( + 0.5, + 0.5, + "No Data\nAvailable", + transform=ax.transAxes, + fontsize=12, + ha="center", + va="center", + color="red", + ) + ax.set_xlim(1, 2048) + ax.set_ylim(0.0, default_y_max) # Set y-axis range from 0 to default + ax.set_xscale("log", base=2) + ax.grid(True, alpha=0.3) + + # Add horizontal line at speedup=1 + ax.axhline(y=1, color="red", linestyle="--", linewidth=2, alpha=0.7) + 
+ if row == 3: # Bottom row + ax.set_xlabel("Batch Size", fontsize=12) + if col == 0: # Left column + ax.set_ylabel("triton_speedup", fontsize=12) + if row == 0: + ax.set_title(config["title"], fontsize=12, fontweight="bold") + if col == 0: + ax.text( + -0.2, + 0.5, + f"Vocab Size {vocab_size}", + transform=ax.transAxes, + fontsize=14, + fontweight="bold", + ha="center", + va="center", + rotation=90, + ) + + # Adjust layout + plt.tight_layout() + plt.subplots_adjust(top=0.93, left=0.08) + + # Save the plot + output_file = f"./{output_name}" + plt.savefig(output_file, dpi=300, bbox_inches="tight") + print(f"Speedup analysis plot saved to: {output_file}") + + # Show the plot + plt.show() + + return fig + + +def print_data_summary(column_configs, vocab_sizes): + """Print a summary of the available data.""" + csv_file = input_path + df = load_and_parse_data(csv_file) + + print("Data Summary:") + print(f"Total rows: {len(df)}") + print(f"Unique batch sizes: {sorted(df['batch_size'].unique())}") + print(f"Unique vocab sizes: {sorted(df['vocab_size'].unique())}") + print(f"Unique p values: {sorted([p for p in df['p'].unique() if p != 'nan'])}") + print(f"Unique k values: {sorted([k for k in df['k'].unique() if k != 'nan'])}") + print() + + print("Data availability matrix:") + print("Rows: Vocab sizes, Columns: Parameter combinations") + print("Values: Number of data points available") + print() + + header = f"{'Vocab Size':<12}" + for config in column_configs: + header += f"{config['title']:<15}" + print(header) + print("-" * len(header)) + + for vocab_size in vocab_sizes: + row = f"{vocab_size:<12}" + for config in column_configs: + data = get_filtered_data(df, vocab_size, config["p"], config["k"]) + row += f"{len(data):<15}" + print(row) + + +if __name__ == "__main__": + column_configs = [ + {"p": None, "k": 50, "title": "P=None, K=50"}, + {"p": 0.9, "k": None, "title": "P=0.9, K=None"}, + {"p": 0.9, "k": 50, "title": "P=0.9, K=50"}, + {"p": "RAND", "k": 3000, "title": 
"P=RAND, K=3000"}, + ] + vocab_sizes = [16384, 65536, 102400, 128256] + # Print data summary first + print_data_summary(column_configs, vocab_sizes) + print("\n" + "=" * 80 + "\n") + + # Create the plots + create_speedup_plots(column_configs, vocab_sizes) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index ad5a210b3e0b..6ce6466d89af 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -213,12 +213,10 @@ def apply_top_k_top_p( # non-associativity of floating-points yields different sum(exp(logits)). probs_sort = logits_sort.softmax(dim=-1) probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort) - print(f"original probs_sum {probs_sum[:, -100:]}") top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) # at least one top_p_mask[:, -1] = False logits_sort.masked_fill_(top_p_mask, -float("inf")) - print(f"original logits_sort {logits_sort[:, -100:]}") # Re-sort the probabilities. logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) return logits @@ -361,7 +359,7 @@ def apply_top_k_top_p_triton( return logits if p is None and k is not None: return apply_top_k_only_triton(logits, k) - # Fallback to torch for small batch sizes for top-p + # Fallback to torch for small batch sizes or small vocab sizes for top-p if logits.shape[0] < 16 or logits.shape[1] < 32768: return apply_top_k_top_p(logits, k, p) return apply_top_k_top_p_filtered(logits, k, p) @@ -538,12 +536,6 @@ def apply_top_k_only_triton( @triton.jit def top_k_top_p_filter( - DO_DEDUPLICATE, - NUM_DUPLICATES_REMOVED, - NUM_DUPLICATES, - MIN_LARGER_P_FIL_PIVOT, - FILTERED_LOGITS_NO_TOP_K, - PFIL_PIVOT, LOGITS, DO_TOP_K, K, @@ -554,7 +546,6 @@ def top_k_top_p_filter( FILTERED_LOGITS, FILTERED_INDICES, FILTERED_PROBS, - NUM_FILTERED, PERCENTILE_TO_STD_TABLE, VOCAB_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr, @@ -843,9 +834,9 @@ def top_k_top_p_filter( write_pos = tl.zeros((), dtype=tl.int32) 
sum_excluded_probs = tl.zeros((), dtype=tl.float32) num_duplicates_removed = tl.zeros((), dtype=tl.uint32) - FILTERED_LOGITS_ROW = FILTERED_LOGITS + row_id * (P_FIL + 5) - FILTERED_INDICES_ROW = FILTERED_INDICES + row_id * (P_FIL + 5) - FILTERED_PROBS_ROW = FILTERED_PROBS + row_id * (P_FIL + 5) + FILTERED_LOGITS_ROW = FILTERED_LOGITS + row_id * P_FIL + FILTERED_INDICES_ROW = FILTERED_INDICES + row_id * P_FIL + FILTERED_PROBS_ROW = FILTERED_PROBS + row_id * P_FIL for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE @@ -870,13 +861,6 @@ def top_k_top_p_filter( f_mask = keep_mask write_idx = tl.where(f_mask, cpos, P_FIL) - FILTERED_LOGITS_NO_TOP_K_ROW = FILTERED_LOGITS_NO_TOP_K + row_id * ( - P_FIL + 5 - ) - tl.store( - FILTERED_LOGITS_NO_TOP_K_ROW + write_idx, logits_blk, mask=keep_mask - ) - top_k_mask = (logits_blk > k_pivot) & mask_n logits_blk = tl.where(top_k_mask, logits_blk, -float("inf")) @@ -887,13 +871,7 @@ def top_k_top_p_filter( sum_excluded_probs += tl.sum(probs_blk * (keep_mask & (~f_mask) & mask_n)) write_pos += tl.sum(f_mask, dtype=tl.int32) - tl.store(PFIL_PIVOT + row_id, p_fil_pivot) tl.store(SUM_EXCLUDED_PROBS + row_id, sum_excluded_probs) - tl.store(NUM_FILTERED + row_id, write_pos) - tl.store(NUM_DUPLICATES_REMOVED + row_id, num_duplicates_removed) - tl.store(NUM_DUPLICATES + row_id, num_min_larger_p_fil_pivot) - tl.store(MIN_LARGER_P_FIL_PIVOT + row_id, min_larger_p_fil_pivot) - tl.store(DO_DEDUPLICATE + row_id, do_deduplicate) def apply_top_k_top_p_filtered( @@ -908,14 +886,16 @@ def apply_top_k_top_p_filtered( # If k is too large, speedup is not significant as the filtered set is large. max_k = k.max().item() if k is not None else 0 - # Our softmax result is different from the original PyTorch top-p implementation , - # as it runs softmax after a sort which produces different sum(exp(logits)) - # compared to our softmax result which runs softmax on the original unsorted logits. 
+ + # Our softmax result is different from the original PyTorch top-p implementation + # which runs softmax after a sort compared to our softmax result which runs + # softmax on the original unsorted logits, yielding different sum(exp(logits)) + # values due to the non-associativity of floating-points. # If p is too large, the top-p cutoff falls in the tail section of the distribution, # which consists of very small probabilities which has larger relative errors - # compared to the sorted PyTorch top-p probabilities. As such, we fallback to + # compared to the original PyTorch top-p probabilities. As such, we fallback to # the original PyTorch top-p implementation for accuracy when p is too large. - if max_k > vocab_size / 10 or (k is None and p.max().item() > 0.97): + if max_k > vocab_size / 4 or (k is None and p.max().item() > 0.995): return apply_top_k_top_p(logits, k, p) BLOCK_SIZE = 8192 @@ -930,44 +910,20 @@ def apply_top_k_top_p_filtered( min(int(max_k * 1.5), vocab_size - 1) if k is not None else int(vocab_size / 32) ) filtered_logits = torch.full( - (batch_size, p_filter + 5), -float("inf"), device=logits.device + (batch_size, p_filter), -float("inf"), device=logits.device ) - filtered_logits_no_top_k = torch.full( - (batch_size, p_filter + 5), -float("inf"), device=logits.device - ) - filtered_indices = torch.full( - (batch_size, p_filter + 5), p_filter, dtype=torch.int64, device=logits.device + (batch_size, p_filter), p_filter, dtype=torch.int64, device=logits.device ) - filtered_probs = torch.full((batch_size, p_filter + 5), 0.0, device=logits.device) + filtered_probs = torch.full((batch_size, p_filter), 0.0, device=logits.device) sum_excluded_probs = torch.zeros( (batch_size,), device=logits.device, dtype=torch.float32 ) - num_duplicates = torch.zeros( - (batch_size,), device=logits.device, dtype=torch.uint32 - ) - min_larger_p_fil_pivot = torch.zeros( - (batch_size,), device=logits.device, dtype=torch.float32 - ) - num_duplicates_removed = 
torch.zeros( - (batch_size,), device=logits.device, dtype=torch.uint32 - ) - do_deduplicate = torch.zeros( - (batch_size,), device=logits.device, dtype=torch.uint32 - ) - num_filtered = torch.zeros((batch_size,), device=logits.device, dtype=torch.int32) - pfil_pivot = torch.zeros((batch_size,), device=logits.device, dtype=torch.float32) PERCENTILE_TO_STD_TABLE = torch.tensor( _PERCENTILE_TO_STD_TABLE, device=logits.device ) top_k_top_p_filter[(NUM_PROGRAMS,)]( - do_deduplicate, - num_duplicates_removed, - num_duplicates, - min_larger_p_fil_pivot, - filtered_logits_no_top_k, - pfil_pivot, logits, (k is not None), k if k is not None else filtered_indices, @@ -978,7 +934,6 @@ def apply_top_k_top_p_filtered( filtered_logits, filtered_indices, filtered_probs, - num_filtered, PERCENTILE_TO_STD_TABLE, VOCAB_SIZE=vocab_size, BLOCK_SIZE=BLOCK_SIZE, @@ -986,27 +941,6 @@ def apply_top_k_top_p_filtered( num_stages=NUM_STAGES, ) - # print(f"p {p}") - # print(f"p_filter {p_filter}") - # print(f"do_deduplicate {do_deduplicate}") - # print(f"num_duplicates_removed {num_duplicates_removed}") - # print(f"num_duplicates {num_duplicates}") - # print(f"min_larger_p_fil_pivot {min_larger_p_fil_pivot}") - # print(f"num_filtered {num_filtered}") - # print(f"pfil_pivot {pfil_pivot}") - # print(f"Filtered logits no top k {filtered_logits_no_top_k}") - # print(f"Filtered logits {filtered_logits}") - # print(f"Filtered indices {filtered_indices}") - # print(f"Filtered probs {filtered_probs}") - # print(f"Sum excluded probs {sum_excluded_probs}") - - filtered_logits = filtered_logits[:, :p_filter] - filtered_indices = filtered_indices[:, :p_filter] - filtered_probs = filtered_probs[:, :p_filter] - - if torch.any(num_filtered != p_filter): - print(f"num_filtered != p_filter: {num_filtered} != {p_filter}") - if torch.any(sum_excluded_probs >= p): return apply_top_k_top_p(logits, k, p) @@ -1014,23 +948,12 @@ def apply_top_k_top_p_filtered( logits_sort_indices = torch.gather(filtered_indices, -1, 
sort_indices) sorted_probs = torch.gather(filtered_probs, -1, sort_indices) - torch.set_printoptions(threshold=float("inf")) - print("logits_sort", logits_sort[:, -100:]) - print("sorted_probs", sorted_probs[:, -100:]) - print("logits_sort_indices", logits_sort_indices[:, -100:]) - torch.set_printoptions(threshold=None) - sorted_probs[:, 0] = sorted_probs[:, 0] + sum_excluded_probs probs_sum = torch.cumsum(sorted_probs, dim=-1) - print("probs_sum", probs_sum[:, -100:]) top_p_mask = probs_sum <= (1 - p.unsqueeze(dim=-1)) - print("threashold", 1 - p.unsqueeze(dim=-1)) top_p_mask[:, -1] = False - print("top_p_mask", top_p_mask[:, -100:]) logits_sort.masked_fill_(top_p_mask, -float("inf")) - print("logits_sort_masked", logits_sort[:, -100:]) - logits.fill_(-float("inf")) logits.scatter_(dim=1, index=logits_sort_indices, src=logits_sort) return logits From b1152c156c6978309b8736ff10e5d61e87b78a7e Mon Sep 17 00:00:00 2001 From: js_park Date: Tue, 18 Nov 2025 16:02:44 -0800 Subject: [PATCH 73/99] Helper scripts removed Signed-off-by: js_park --- compare.py | 321 ----------------------------------------------------- graph.py | 232 -------------------------------------- 2 files changed, 553 deletions(-) delete mode 100644 compare.py delete mode 100644 graph.py diff --git a/compare.py b/compare.py deleted file mode 100644 index b9d1e694c957..000000000000 --- a/compare.py +++ /dev/null @@ -1,321 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from datetime import datetime -from itertools import product - -import regex as re -import torch - -from vllm.v1.sample.ops.topk_topp_sampler import ( - apply_top_k_top_p, - apply_top_k_top_p_triton, -) - - -def g_str(s): - return "\033[32m" + s + "\033[0m" - - -def r_str(s): - return "\033[31m" + s + "\033[0m" - - -def y_str(s): - return "\033[33m" + s + "\033[0m" - - -def b_str(s): - return "\033[34m" + s + "\033[0m" - - -def print_to_log(s, log_file): - print(s) - # 
Remove the color codes - s = re.sub(r"\033\[[0-9;]*m", "", s) - with open(log_file, "a") as f: - f.write(s + "\n") - - -def test_accuracy(logits, k, p, func_list, log_file): - input_logit_list = [logits.clone().detach() for i in range(len(func_list))] - original_logits = func_list[0](input_logit_list[0], k, p) - output_correct_list = [] - for i in range(1, len(func_list)): - output_logits = func_list[i](input_logit_list[i], k, p) - - torch.cuda.synchronize() - is_correct = True - # original_logits_bin = original_logits.view(torch.int32) - # output_logits_bin = output_logits.view(torch.int32) - # is_correct = torch.all(original_logits_bin == output_logits_bin) - # is_correct = is_correct and torch.allclose( - # output_logits, original_logits - # ) - output_logits_sorted = torch.sort(output_logits, descending=True).values - original_logits_sorted = torch.sort(original_logits, descending=True).values - is_correct = is_correct and torch.allclose( - output_logits_sorted, original_logits_sorted - ) - output_correct_list.append(is_correct) - func_name = func_list[i].__name__ - - if not is_correct: - print_to_log( - r_str("Error: logits are not close on " + f"{func_name}"), - log_file, - ) - - # Check for NaN values first - output_has_nan = torch.isnan(output_logits).any().item() - original_has_nan = torch.isnan(original_logits).any().item() - output_nan_count = torch.isnan(output_logits).sum().item() - original_nan_count = torch.isnan(original_logits).sum().item() - - print_to_log( - "NaN check:\n" - + f" output_logits has NaN: {output_has_nan} (count: {output_nan_count})\n" - + f" original_logits has NaN: {original_has_nan} (count: {original_nan_count})\n" - + " Note: torch.allclose returns False if either tensor contains NaN (unless equal_nan=True)", - log_file, - ) - - if output_has_nan or original_has_nan: - # Show where NaN values are - if output_has_nan: - output_nan_positions = torch.where(torch.isnan(output_logits)) - print_to_log( - f" output_logits NaN positions 
(first 10): " - f"{list(zip(output_nan_positions[0][:10].tolist(), output_nan_positions[1][:10].tolist()))}", - log_file, - ) - if original_has_nan: - original_nan_positions = torch.where(torch.isnan(original_logits)) - print_to_log( - f" original_logits NaN positions (first 10): " - f"{list(zip(original_nan_positions[0][:10].tolist(), original_nan_positions[1][:10].tolist()))}", - log_file, - ) - - error = torch.abs(output_logits - original_logits) - # Handle NaN in error computation - error_has_nan = torch.isnan(error).any().item() - if error_has_nan: - error_nan_count = torch.isnan(error).sum().item() - print_to_log( - f" error tensor has NaN: True (count: {error_nan_count})", - log_file, - ) - # Use masked operations for NaN handling (compatible with all PyTorch versions) - valid_error = error[~torch.isnan(error)] - if valid_error.numel() > 0: - max_error = torch.max(valid_error).item() - mean_error = torch.mean(valid_error).item() - else: - max_error = float("nan") - mean_error = float("nan") - else: - max_error = torch.max(error).item() - mean_error = torch.mean(error).item() - - # Use the same tolerance as torch.allclose (rtol=1e-05, atol=1e-08) - atol = 1e-08 - rtol = 1e-05 - # torch.allclose checks: |input - other| <= atol + rtol * |other| - # Exclude NaN from tolerance check - valid_mask = ~torch.isnan(original_logits) & ~torch.isnan(output_logits) - tolerance = atol + rtol * torch.abs(original_logits) - error_mask = (error > tolerance) & valid_mask - - print_to_log( - f"Max absolute error: {max_error:.2e}\n" - + f"Mean absolute error: {mean_error:.2e}\n" - + f"torch.allclose tolerance: rtol={rtol}, atol={atol}", - log_file, - ) - - error_rows = torch.where(error_mask)[0] - error_rows = torch.unique(error_rows) - num_error_rows = error_rows.shape[0] - error_cols = torch.where(error_mask)[1] - error_cols = torch.unique(error_cols) - num_error_cols = error_cols.shape[0] - print_to_log( - f"num_error_rows: {num_error_rows} - {error_rows}\n" - + 
f"num_error_cols: {num_error_cols} - {error_cols}", - log_file, - ) - - if num_error_rows > 0: - row_to_show = 5 if num_error_rows > 5 else num_error_rows - logits_to_show = torch.sort( - output_logits[error_rows], descending=True - ).values - - logits_to_show = logits_to_show[:row_to_show, :50] - print_to_log(f"logits: {logits_to_show}", log_file) - original_logits_to_show = torch.sort( - original_logits[error_rows], descending=True - ).values - original_logits_to_show = original_logits_to_show[:row_to_show, :50] - print_to_log(f"original_logits: {original_logits_to_show}", log_file) - error_to_show = error[error_rows][:row_to_show, :50] - print_to_log(f"error (abs diff): {error_to_show}", log_file) - else: - # If no errors found with the mask, show the largest errors anyway - print_to_log( - "No errors found with tolerance mask, showing top errors:", log_file - ) - # Handle NaN in topk - replace NaN with -inf so they're not selected - error_for_topk = error.clone() - error_for_topk[torch.isnan(error_for_topk)] = float("-inf") - top_errors, top_indices = torch.topk( - error_for_topk.flatten(), min(20, error.numel()) - ) - print_to_log(f"Top 20 absolute errors: {top_errors}", log_file) - for idx, err_val in zip(top_indices, top_errors): - row_idx = idx.item() // error.shape[1] - col_idx = idx.item() % error.shape[1] - output_val = output_logits[row_idx, col_idx].item() - original_val = original_logits[row_idx, col_idx].item() - err_val_item = err_val.item() - # Check if values are NaN - output_str = ( - f"{output_val:.10f}" - if not torch.isnan(output_logits[row_idx, col_idx]) - else "NaN" - ) - original_str = ( - f"{original_val:.10f}" - if not torch.isnan(original_logits[row_idx, col_idx]) - else "NaN" - ) - error_str = ( - f"{err_val_item:.2e}" - if not torch.isnan(error[row_idx, col_idx]) - else "NaN" - ) - print_to_log( - f" Position [{row_idx}, {col_idx}]: " - f"output={output_str}, " - f"original={original_str}, " - f"error={error_str}", - log_file, - ) - # 
raise ValueError("Logits are not close") - return output_correct_list - - -def test_time(logits, k, p, test_func, num_runs=30, num_warmup=5): - # We must clone the logits for each run to avoid modifying the original - warmup_tensor = [logits.clone().detach() for _ in range(num_warmup)] - for _ in range(num_warmup): - test_func(warmup_tensor[_], k, p) - torch.cuda.synchronize() - - input_logits = [logits.clone().detach() for _ in range(num_runs)] - - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() - start.record() - for _ in range(num_runs): - input_logits[_] = test_func(input_logits[_], k, p) - end.record() - torch.cuda.synchronize() - time_taken = start.elapsed_time(end) / num_runs - - return time_taken - - -if __name__ == "__main__": - date_str = datetime.now().strftime("%Y%m%d_%H%M%S") - - batch_size_list = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] - vocab_size_list = [32768, 65536, 102400, 128256] - p_list = [None, "RAND", 0.1, 0.4, 0.7, 0.9, 0.95, 0.99] - k_list = [None, "RAND", 5, 20, 50, 200, 500, 3000] - func_list = [apply_top_k_top_p, apply_top_k_top_p_triton] - - log_file = f"triton_topk_topp_test_{date_str}.log" - csv_file = f"triton_topk_topp_test_{date_str}.csv" - - print_to_log(y_str("Testing TopKTopPSampler with Triton"), log_file) - print_to_log(y_str("batch_size_list:") + f"{batch_size_list}", log_file) - print_to_log(y_str("vocab_size_list:") + f"{vocab_size_list}", log_file) - print_to_log(y_str("p_list:") + f"{p_list}", log_file) - print_to_log(y_str("k_list:") + f"{k_list}", log_file) - - print_to_log(y_str("log_file:") + f"{log_file}", log_file) - print_to_log(y_str("csv_file:") + f"{csv_file}", log_file) - - with open(csv_file, "w") as f: - f.write( - "dist_generator,batch_size,vocab_size,p,k,triton_correct," - "torch_time_taken,triton_time_taken,triton_speedup\n" - ) - - for batch_size, vocab_size, p, k in product( - batch_size_list, vocab_size_list, p_list, k_list - ): 
- if p is None and k is None: - continue - - logits_randn = torch.randn(batch_size, vocab_size, device="cuda") * 10 - top_5_logits = torch.topk(logits_randn, 5, dim=-1).values - - logits_list = [("RANDN", logits_randn)] - - if p == "RAND": - p_tensor = torch.rand((batch_size,), device="cuda") * 0.98 + 0.01 - elif p is not None: - p_tensor = torch.full((batch_size,), p, device="cuda") - else: - p_tensor = None - - if k == "RAND": - k_tensor = torch.randint( - 1, int(vocab_size / 4) - 1, (batch_size,), device="cuda" - ) - elif k is not None: - k_tensor = torch.full((batch_size,), k, device="cuda") - else: - k_tensor = None - - for dist_generator, logits in logits_list: - print_to_log(y_str("--------------------------------"), log_file) - print_to_log( - g_str("Testing ") - + f"{dist_generator}" - + y_str(" with batch_size: ") - + f"{batch_size}" - + y_str(" vocab_size: ") - + f"{vocab_size}" - + y_str(" p: ") - + f"{p}" - + y_str(" k: ") - + f"{k}", - log_file, - ) - correct_list = test_accuracy( - logits, k_tensor, p_tensor, func_list, log_file - ) - time_list = [] - for func in func_list: - time_taken = test_time(logits, k_tensor, p_tensor, test_func=func) - time_list.append(time_taken) - print_to_log(b_str("torch_time_taken: ") + f"{time_list[0]}", log_file) - print_to_log(b_str("triton_time_taken: ") + f"{time_list[1]}", log_file) - print_to_log( - g_str("test Speedup over Torch: ") - + f"{time_list[0] / time_list[1]:.8f}x", - log_file, - ) - with open(csv_file, "a") as f: - p_str = "NONE" if p is None else str(p) - k_str = "NONE" if k is None else str(k) - f.write( - f"{dist_generator},{batch_size},{vocab_size},{p_str},{k_str}," - f"{correct_list[0]},{time_list[0]},{time_list[1]}," - f"{time_list[0] / time_list[1]:.8f}\n" - ) - print_to_log(y_str("--------------------------------\n"), log_file) diff --git a/graph.py b/graph.py deleted file mode 100644 index 3c1039293fbf..000000000000 --- a/graph.py +++ /dev/null @@ -1,232 +0,0 @@ -# SPDX-License-Identifier: 
Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import matplotlib.pyplot as plt -import pandas as pd - -input_path = "triton_topk_topp_test_wo_fallback.csv" -output_name = "speedup_analysis_wo_fallback.png" - - -def load_and_parse_data(csv_file): - """Load CSV data and parse it into a structured format.""" - df = pd.read_csv(csv_file, dtype={"p": str, "k": str}) - print(df.head()) - print(df.columns) - print(df.info()) - print(df.describe()) - print(df.isnull().sum()) - print(df.duplicated().sum()) - print(df.shape) - print(df.head()) - return df - - -def get_filtered_data(df, vocab_size, p_val, k_val): - """Filter data for specific vocab_size, p, and k values.""" - # Handle None values properly - if p_val is None: - p_condition = df["p"] == "NONE" - else: - p_condition = df["p"] == str(p_val) - - if k_val is None: - k_condition = df["k"] == "NONE" - else: - k_condition = df["k"] == str(k_val) - - filtered_df = df[ - (df["vocab_size"] == vocab_size) & p_condition & k_condition - ].copy() - - return filtered_df.sort_values("batch_size") - - -def create_speedup_plots(column_configs, vocab_sizes): - """Create 4x4 grid of speedup vs batch size plots.""" - # Load data - csv_file = input_path - df = load_and_parse_data(csv_file) - - # We'll calculate y-axis limits per subplot now - - # Create figure with subplots - fig, axes = plt.subplots(4, 4, figsize=(20, 16)) - fig.suptitle("Speedup vs Batch Size", fontsize=20, fontweight="bold") - - # Plot each combination - for row, vocab_size in enumerate(vocab_sizes): - for col, config in enumerate(column_configs): - ax = axes[row, col] - - # Get filtered data - data = get_filtered_data(df, vocab_size, config["p"], config["k"]) - - if not data.empty: - # Calculate y-axis limit for this specific subplot - local_max_speedup = data["triton_speedup"].max() - local_y_max = local_max_speedup * 1.1 if local_max_speedup > 0 else 10.0 - - # Plot speedup vs batch size - ax.plot( - data["batch_size"], - 
data["triton_speedup"], - "bo-", - linewidth=2, - markersize=6, - ) - ax.set_xscale("log", base=2) - ax.set_ylim(0.0, local_y_max) # Set y-axis range from 0 to local max - ax.grid(True, alpha=0.3) - - # Add horizontal line at speedup=1 - ax.axhline( - y=1, - color="red", - linestyle="--", - linewidth=2, - alpha=0.7, - label="Speedup=1", - ) - - # Set labels and title - if row == 3: # Bottom row - ax.set_xlabel("Batch Size", fontsize=12) - if col == 0: # Left column - ax.set_ylabel("Speedup", fontsize=12) - - # Set title for top row - if row == 0: - ax.set_title(config["title"], fontsize=14, fontweight="bold") - - # Add vocab size label on the left - if col == 0: - vocab_size_str = f"Vocab Size {vocab_size}" - - ax.text( - -0.2, - 0.5, - vocab_size_str, - transform=ax.transAxes, - fontsize=14, - fontweight="bold", - ha="center", - va="center", - rotation=90, - ) - - # Set reasonable axis limits - if len(data) > 0: - batch_sizes = data["batch_size"].values - - ax.set_xlim(batch_sizes.min() * 0.8, batch_sizes.max() * 1.2) - # Y-axis is already set to 0-10 above - - # Format x-axis ticks - ax.set_xticks([1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]) - ax.set_xticklabels([1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, ""]) - - # Add legend only to the first subplot - if row == 0 and col == 0: - ax.legend(loc="upper left") - - else: - # No data available - use default y-axis range - default_y_max = 10.0 - ax.text( - 0.5, - 0.5, - "No Data\nAvailable", - transform=ax.transAxes, - fontsize=12, - ha="center", - va="center", - color="red", - ) - ax.set_xlim(1, 2048) - ax.set_ylim(0.0, default_y_max) # Set y-axis range from 0 to default - ax.set_xscale("log", base=2) - ax.grid(True, alpha=0.3) - - # Add horizontal line at speedup=1 - ax.axhline(y=1, color="red", linestyle="--", linewidth=2, alpha=0.7) - - if row == 3: # Bottom row - ax.set_xlabel("Batch Size", fontsize=12) - if col == 0: # Left column - ax.set_ylabel("triton_speedup", fontsize=12) - if row == 0: - 
ax.set_title(config["title"], fontsize=12, fontweight="bold") - if col == 0: - ax.text( - -0.2, - 0.5, - f"Vocab Size {vocab_size}", - transform=ax.transAxes, - fontsize=14, - fontweight="bold", - ha="center", - va="center", - rotation=90, - ) - - # Adjust layout - plt.tight_layout() - plt.subplots_adjust(top=0.93, left=0.08) - - # Save the plot - output_file = f"./{output_name}" - plt.savefig(output_file, dpi=300, bbox_inches="tight") - print(f"Speedup analysis plot saved to: {output_file}") - - # Show the plot - plt.show() - - return fig - - -def print_data_summary(column_configs, vocab_sizes): - """Print a summary of the available data.""" - csv_file = input_path - df = load_and_parse_data(csv_file) - - print("Data Summary:") - print(f"Total rows: {len(df)}") - print(f"Unique batch sizes: {sorted(df['batch_size'].unique())}") - print(f"Unique vocab sizes: {sorted(df['vocab_size'].unique())}") - print(f"Unique p values: {sorted([p for p in df['p'].unique() if p != 'nan'])}") - print(f"Unique k values: {sorted([k for k in df['k'].unique() if k != 'nan'])}") - print() - - print("Data availability matrix:") - print("Rows: Vocab sizes, Columns: Parameter combinations") - print("Values: Number of data points available") - print() - - header = f"{'Vocab Size':<12}" - for config in column_configs: - header += f"{config['title']:<15}" - print(header) - print("-" * len(header)) - - for vocab_size in vocab_sizes: - row = f"{vocab_size:<12}" - for config in column_configs: - data = get_filtered_data(df, vocab_size, config["p"], config["k"]) - row += f"{len(data):<15}" - print(row) - - -if __name__ == "__main__": - column_configs = [ - {"p": None, "k": 50, "title": "P=None, K=50"}, - {"p": 0.9, "k": None, "title": "P=0.9, K=None"}, - {"p": 0.9, "k": 50, "title": "P=0.9, K=50"}, - {"p": "RAND", "k": 3000, "title": "P=RAND, K=3000"}, - ] - vocab_sizes = [16384, 65536, 102400, 128256] - # Print data summary first - print_data_summary(column_configs, vocab_sizes) - print("\n" + 
"=" * 80 + "\n") - - # Create the plots - create_speedup_plots(column_configs, vocab_sizes) From d2d56a12e238f557230f4689c70381254bb48970 Mon Sep 17 00:00:00 2001 From: js_park Date: Tue, 18 Nov 2025 16:03:47 -0800 Subject: [PATCH 74/99] Change hyperparameters Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 6ce6466d89af..92de3991a970 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -895,7 +895,7 @@ def apply_top_k_top_p_filtered( # which consists of very small probabilities which has larger relative errors # compared to the original PyTorch top-p probabilities. As such, we fallback to # the original PyTorch top-p implementation for accuracy when p is too large. - if max_k > vocab_size / 4 or (k is None and p.max().item() > 0.995): + if max_k > vocab_size / 4 or (k is None and p.max().item() > 0.99): return apply_top_k_top_p(logits, k, p) BLOCK_SIZE = 8192 From 7643eabd306a17eaa1b72e39ffba3f65c47cc0c4 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sun, 18 Jan 2026 10:11:01 -0800 Subject: [PATCH 75/99] [Perf] Triton-based top-p/top-k masking Signed-off-by: Nick Hill --- benchmarks/benchmark_topk_topp.py | 448 ++++++++++++++++++++++ tests/v1/sample/test_topk_topp_sampler.py | 184 ++++++++- vllm/v1/sample/ops/topk_topp_sampler.py | 56 ++- vllm/v1/sample/ops/topk_topp_triton.py | 377 ++++++++++++++++++ 4 files changed, 1040 insertions(+), 25 deletions(-) create mode 100644 benchmarks/benchmark_topk_topp.py create mode 100644 vllm/v1/sample/ops/topk_topp_triton.py diff --git a/benchmarks/benchmark_topk_topp.py b/benchmarks/benchmark_topk_topp.py new file mode 100644 index 000000000000..6191f2902c73 --- /dev/null +++ b/benchmarks/benchmark_topk_topp.py @@ -0,0 +1,448 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Benchmark comparing Triton vs PyTorch sort-based top-k/top-p implementations. + +Compares: +- apply_top_k_top_p_triton (Triton binary search) +- apply_top_k_top_p (PyTorch sort-based) + +Scenarios: +- top_k only (whole batch, partial batch) +- top_p only (whole batch, partial batch) +- mix of top_k and top_p +""" + +import argparse +import gc +from dataclasses import dataclass + +import torch + +from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_pytorch +from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton + + +@dataclass +class BenchmarkConfig: + """Configuration for a benchmark run.""" + + name: str + batch_size: int + vocab_size: int + # k and p can be tensors or None + k_values: torch.Tensor | None # [batch_size] or None + p_values: torch.Tensor | None # [batch_size] or None + description: str + ops_pct: float = 0.0 # Percentage of ops relative to batch size + + +def calculate_ops_pct( + k_values: torch.Tensor | None, + p_values: torch.Tensor | None, + vocab_size: int, + batch_size: int, +) -> float: + """ + Calculate the percentage of active top-k and top-p operations. + + Returns percentage where 100% = batch_size ops. + E.g., if all rows have both top-k and top-p active, returns 200%. 
+ """ + active_ops = 0 + + if k_values is not None: + # Count rows where k < vocab_size (active top-k filtering) + active_ops += (k_values < vocab_size).sum().item() + + if p_values is not None: + # Count rows where p < 1.0 (active top-p filtering) + active_ops += (p_values < 1.0).sum().item() + + return (active_ops / batch_size) * 100 if batch_size > 0 else 0.0 + + +def create_logits( + batch_size: int, vocab_size: int, device: str = "cuda" +) -> torch.Tensor: + """Create random logits tensor.""" + return torch.randn(batch_size, vocab_size, dtype=torch.float32, device=device) + + +def measure_memory() -> tuple[int, int]: + """Return (allocated, reserved) memory in bytes.""" + torch.cuda.synchronize() + return torch.cuda.memory_allocated(), torch.cuda.max_memory_allocated() + + +def reset_memory_stats(): + """Reset peak memory statistics.""" + torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + gc.collect() + + +def benchmark_function( + func, + logits: torch.Tensor, + k: torch.Tensor | None, + p: torch.Tensor | None, + warmup_iters: int = 5, + benchmark_iters: int = 20, +) -> tuple[float, int]: + """ + Benchmark a function and return (avg_time_ms, peak_memory_bytes). + + Returns average time in milliseconds and peak memory usage. 
+ """ + # Warmup + for _ in range(warmup_iters): + logits_copy = logits.clone() + func(logits_copy, k, p) + torch.cuda.synchronize() + + # Reset memory stats before benchmark + reset_memory_stats() + + # Benchmark + start_events = [ + torch.cuda.Event(enable_timing=True) for _ in range(benchmark_iters) + ] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(benchmark_iters)] + + for i in range(benchmark_iters): + logits_copy = logits.clone() + start_events[i].record() + func(logits_copy, k, p) + end_events[i].record() + + torch.cuda.synchronize() + + # Calculate timing + times = [ + start_events[i].elapsed_time(end_events[i]) for i in range(benchmark_iters) + ] + avg_time = sum(times) / len(times) + + # Get peak memory + _, peak_memory = measure_memory() + + return avg_time, peak_memory + + +def create_benchmark_configs( + batch_sizes: list[int], + vocab_sizes: list[int], + device: str = "cuda", +) -> list[BenchmarkConfig]: + """Create all benchmark configurations.""" + configs = [] + + for vocab_size in vocab_sizes: + for batch_size in batch_sizes: + # 1. Top-k only - whole batch (all rows have k < vocab_size) + k_all = torch.full((batch_size,), 50, dtype=torch.int32, device=device) + configs.append( + BenchmarkConfig( + name=f"topk_whole_b{batch_size}_v{vocab_size // 1000}k", + batch_size=batch_size, + vocab_size=vocab_size, + k_values=k_all, + p_values=None, + description=f"Top-k only (whole batch, k=50), " + f"batch={batch_size}, vocab={vocab_size}", + ops_pct=calculate_ops_pct(k_all, None, vocab_size, batch_size), + ) + ) + + # 2. 
Top-k only - partial batch (half have k=50, half have k=vocab_size) + k_partial = torch.full((batch_size,), 50, dtype=torch.int32, device=device) + k_partial[batch_size // 2 :] = vocab_size # No filtering for second half + configs.append( + BenchmarkConfig( + name=f"topk_partial_b{batch_size}_v{vocab_size // 1000}k", + batch_size=batch_size, + vocab_size=vocab_size, + k_values=k_partial, + p_values=None, + description=f"Top-k only (partial batch, 50% k=50, 50% k=vocab), " + f"batch={batch_size}, vocab={vocab_size}", + ops_pct=calculate_ops_pct(k_partial, None, vocab_size, batch_size), + ) + ) + + # 3. Top-p only - whole batch (all rows have p < 1.0) + p_all = torch.full((batch_size,), 0.9, dtype=torch.float32, device=device) + configs.append( + BenchmarkConfig( + name=f"topp_whole_b{batch_size}_v{vocab_size // 1000}k", + batch_size=batch_size, + vocab_size=vocab_size, + k_values=None, + p_values=p_all, + description=f"Top-p only (whole batch, p=0.9), " + f"batch={batch_size}, vocab={vocab_size}", + ops_pct=calculate_ops_pct(None, p_all, vocab_size, batch_size), + ) + ) + + # 4. Top-p only - partial batch (half have p=0.9, half have p=1.0) + p_partial = torch.full( + (batch_size,), 0.9, dtype=torch.float32, device=device + ) + p_partial[batch_size // 2 :] = 1.0 # No filtering for second half + configs.append( + BenchmarkConfig( + name=f"topp_partial_b{batch_size}_v{vocab_size // 1000}k", + batch_size=batch_size, + vocab_size=vocab_size, + k_values=None, + p_values=p_partial, + description=f"Top-p only (partial batch, 50% p=0.9, 50% p=1.0), " + f"batch={batch_size}, vocab={vocab_size}", + ops_pct=calculate_ops_pct(None, p_partial, vocab_size, batch_size), + ) + ) + + # 5. 
Mix of top-k and top-p (both applied to whole batch) + k_mix = torch.full((batch_size,), 100, dtype=torch.int32, device=device) + p_mix = torch.full((batch_size,), 0.9, dtype=torch.float32, device=device) + configs.append( + BenchmarkConfig( + name=f"topk_topp_whole_b{batch_size}_v{vocab_size // 1000}k", + batch_size=batch_size, + vocab_size=vocab_size, + k_values=k_mix, + p_values=p_mix, + description=f"Top-k + Top-p (whole batch, k=100, p=0.9), " + f"batch={batch_size}, vocab={vocab_size}", + ops_pct=calculate_ops_pct(k_mix, p_mix, vocab_size, batch_size), + ) + ) + + # 6. Mix with partial application (some rows k only, some p only, some both) + k_mixed = torch.full( + (batch_size,), vocab_size, dtype=torch.int32, device=device + ) + p_mixed = torch.full((batch_size,), 1.0, dtype=torch.float32, device=device) + # First third: k only + third = batch_size // 3 + k_mixed[:third] = 50 + # Second third: p only + p_mixed[third : 2 * third] = 0.9 + # Last third: both k and p + k_mixed[2 * third :] = 100 + p_mixed[2 * third :] = 0.9 + configs.append( + BenchmarkConfig( + name=f"mixed_partial_b{batch_size}_v{vocab_size // 1000}k", + batch_size=batch_size, + vocab_size=vocab_size, + k_values=k_mixed, + p_values=p_mixed, + description=f"Mixed partial (1/3 k=50, 1/3 p=0.9, 1/3 both), " + f"batch={batch_size}, vocab={vocab_size}", + ops_pct=calculate_ops_pct(k_mixed, p_mixed, vocab_size, batch_size), + ) + ) + + return configs + + +def format_memory(bytes_val: int) -> str: + """Format memory in human-readable form.""" + if bytes_val >= 1024**3: + return f"{bytes_val / (1024**3):.2f} GB" + elif bytes_val >= 1024**2: + return f"{bytes_val / (1024**2):.2f} MB" + elif bytes_val >= 1024: + return f"{bytes_val / 1024:.2f} KB" + return f"{bytes_val} B" + + +def run_benchmark( + configs: list[BenchmarkConfig], + warmup_iters: int = 5, + benchmark_iters: int = 20, + verbose: bool = True, +): + """Run all benchmarks and print results.""" + results = [] + + print("=" * 100) + 
print("Top-k/Top-p Benchmark: Triton vs PyTorch Sort-based") + print("=" * 100) + print() + + for config in configs: + if verbose: + print(f"Running: {config.description}") + + # Create fresh logits for this config + logits = create_logits(config.batch_size, config.vocab_size) + + # Benchmark Triton + reset_memory_stats() + triton_time, triton_mem = benchmark_function( + apply_top_k_top_p_triton, + logits, + config.k_values, + config.p_values, + warmup_iters, + benchmark_iters, + ) + + # Benchmark PyTorch + reset_memory_stats() + pytorch_time, pytorch_mem = benchmark_function( + apply_top_k_top_pytorch, + logits, + config.k_values, + config.p_values, + warmup_iters, + benchmark_iters, + ) + + speedup = pytorch_time / triton_time if triton_time > 0 else float("inf") + mem_ratio = pytorch_mem / triton_mem if triton_mem > 0 else float("inf") + + result = { + "config": config, + "triton_time_ms": triton_time, + "pytorch_time_ms": pytorch_time, + "triton_mem": triton_mem, + "pytorch_mem": pytorch_mem, + "speedup": speedup, + "mem_ratio": mem_ratio, + } + results.append(result) + + if verbose: + print(f" Triton: {triton_time:.3f} ms, {format_memory(triton_mem)}") + print(f" PyTorch: {pytorch_time:.3f} ms, {format_memory(pytorch_mem)}") + print(f" Speedup: {speedup:.2f}x, Memory ratio: {mem_ratio:.2f}x") + print() + + # Clean up + del logits + reset_memory_stats() + + return results + + +def print_summary_table(results: list[dict]): + """Print a summary table of results.""" + print() + print("=" * 130) + print("SUMMARY TABLE") + print("=" * 130) + print() + + # Header + header = ( + f"{'Scenario':<40} {'Batch':>6} {'Vocab':>7} {'Ops%':>6} " + f"{'Triton (ms)':>12} {'PyTorch (ms)':>13} {'Speedup':>8} " + f"{'Tri Mem':>10} {'Pyt Mem':>10}" + ) + print(header) + print("-" * 130) + + # Group by scenario type + current_vocab = None + for result in results: + config = result["config"] + + # Add separator between vocab sizes + if current_vocab != config.vocab_size: + if 
current_vocab is not None:
+                print("-" * 130)
+            current_vocab = config.vocab_size
+
+        scenario = config.name.split("_b")[0]  # Extract scenario name
+        print(
+            f"{scenario:<40} {config.batch_size:>6} {config.vocab_size:>7} "
+            f"{config.ops_pct:>5.0f}% "
+            f"{result['triton_time_ms']:>12.3f} {result['pytorch_time_ms']:>13.3f} "
+            f"{result['speedup']:>7.2f}x "
+            f"{format_memory(result['triton_mem']):>10} "
+            f"{format_memory(result['pytorch_mem']):>10}"
+        )
+
+    print("=" * 130)
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Benchmark Triton vs PyTorch sort-based top-k/top-p implementations"
+    )
+    parser.add_argument(
+        "--batch-sizes",
+        type=int,
+        nargs="+",
+        default=[1, 4, 16, 24, 32, 48, 56, 64, 96, 128, 192, 256, 512, 1024],
+        help="Batch sizes to test (default: 1 4 16 24 32 48 56 64 96 128 192 256 512 1024)",
+    )
+    parser.add_argument(
+        "--vocab-sizes",
+        type=int,
+        nargs="+",
+        default=[32768, 131072],  # 32k, 128k
+        help="Vocabulary sizes to test (default: 32768 131072)",
+    )
+    parser.add_argument(
+        "--warmup-iters",
+        type=int,
+        default=5,
+        help="Number of warmup iterations (default: 5)",
+    )
+    parser.add_argument(
+        "--benchmark-iters",
+        type=int,
+        default=20,
+        help="Number of benchmark iterations (default: 20)",
+    )
+    parser.add_argument(
+        "--quiet",
+        action="store_true",
+        help="Only print summary table",
+    )
+
+    args = parser.parse_args()
+
+    # Print configuration
+    print(f"Batch sizes: {args.batch_sizes}")
+    print(f"Vocab sizes: {args.vocab_sizes}")
+    print(f"Warmup iterations: {args.warmup_iters}")
+    print(f"Benchmark iterations: {args.benchmark_iters}")
+    print()
+
+    # Check CUDA
+    if not torch.cuda.is_available():
+        print("ERROR: CUDA is not available. 
This benchmark requires a GPU.") + return + + device_name = torch.cuda.get_device_name(0) + print(f"GPU: {device_name}") + print() + + # Create configs + configs = create_benchmark_configs( + args.batch_sizes, + args.vocab_sizes, + ) + + # Run benchmarks + results = run_benchmark( + configs, + warmup_iters=args.warmup_iters, + benchmark_iters=args.benchmark_iters, + verbose=not args.quiet, + ) + + # Print summary + print_summary_table(results) + + +if __name__ == "__main__": + main() diff --git a/tests/v1/sample/test_topk_topp_sampler.py b/tests/v1/sample/test_topk_topp_sampler.py index f50ef6102204..46b551501bd2 100644 --- a/tests/v1/sample/test_topk_topp_sampler.py +++ b/tests/v1/sample/test_topk_topp_sampler.py @@ -5,8 +5,9 @@ from torch import Generator from vllm.platforms import current_platform -from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p +from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_pytorch +CUDA_DEVICE = "cuda" if current_platform.is_cuda() else None DEVICE = current_platform.device_type BATCH_SIZE = 1024 @@ -39,11 +40,11 @@ def test_topk_impl_equivalence(): ) # Top-k only implementation - result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None) + result1 = apply_top_k_top_pytorch(logits=logits.clone(), k=k, p=None) # Top-p + top-k no_op_top_p = torch.tensor([1.0]) - result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p) + result2 = apply_top_k_top_pytorch(logits=logits.clone(), k=k, p=no_op_top_p) assert torch.allclose(result1, result2) @@ -93,7 +94,7 @@ def test_flashinfer_sampler(): torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool), 1.0 ) - python_logits = apply_top_k_top_p( + python_logits = apply_top_k_top_pytorch( logits=logits.clone(), k=k_values, p=p_values, @@ -115,3 +116,178 @@ def test_flashinfer_sampler(): assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), ( "FlashInfer and Python sampling implementations do not match!" 
) + + +# ============================================================================= +# Triton kernel tests +# ============================================================================= + + +@pytest.mark.skipif(CUDA_DEVICE is None, reason="CUDA not available") +class TestTritonTopkTopp: + """Tests for the Triton top-k/top-p kernel.""" + + @pytest.fixture(autouse=True) + def setup(self): + """Set up test fixtures.""" + torch.set_default_device(CUDA_DEVICE) + self.generator = Generator(device=CUDA_DEVICE).manual_seed(42) + + def _compare_results( + self, + logits: torch.Tensor, + k: torch.Tensor | None, + p: torch.Tensor | None, + ): + """Compare Triton kernel results with PyTorch sorting implementation. + + For top-k only, we expect exact match. + For top-p (with or without top-k), we allow small differences due to + floating-point precision in probability sum calculations. + """ + from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton + + # Clone logits for both implementations + logits_pytorch = logits.clone() + logits_triton = logits.clone().to(torch.float32) + + # Apply PyTorch sorting implementation + result_pytorch = apply_top_k_top_pytorch(logits_pytorch, k, p) + + # Apply Triton kernel + k_i32 = k.to(torch.int32) if k is not None else None + p_f32 = p.to(torch.float32) if p is not None else None + result_triton = apply_top_k_top_p_triton(logits_triton, k_i32, p_f32) + + # Compare kept counts per row + pytorch_kept = (result_pytorch != float("-inf")).sum(dim=-1) + triton_kept = (result_triton != float("-inf")).sum(dim=-1) + + if p is None: + # Top-k only: expect exact match + assert torch.equal(pytorch_kept, triton_kept), ( + f"Top-k mask mismatch: PyTorch kept {pytorch_kept.tolist()}, " + f"Triton kept {triton_kept.tolist()}" + ) + else: + # Top-p involved: allow small differences + # Either < 1% of kept values OR < 5 values absolute + max_diff = (pytorch_kept - triton_kept).abs().max().item() + max_kept = pytorch_kept.max().item() + 
if max_kept > 0 and max_diff > 3: + diff_pct = max_diff / max_kept * 100 + assert diff_pct < 0.5, ( + f"Top-p mask difference too large: {diff_pct:.2f}% " + f"(max diff {max_diff} values out of {max_kept})" + ) + + @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024]) + @pytest.mark.parametrize("vocab_size", [1024, 32000, 128256]) + def test_topk_only(self, batch_size: int, vocab_size: int): + """Test top-k only (p=None).""" + logits = torch.randn( + batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + ) + k = torch.randint( + 1, min(100, vocab_size), (batch_size,), generator=self.generator + ) + # Randomly disable top-k for some rows (~25%) + disable_mask = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0 + k.masked_fill_(disable_mask, vocab_size) + + self._compare_results(logits, k, p=None) + + @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024]) + @pytest.mark.parametrize("vocab_size", [1024, 32000, 128256]) + def test_topp_only(self, batch_size: int, vocab_size: int): + """Test top-p only (k=None).""" + logits = torch.randn( + batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + ) + p = torch.rand(batch_size, generator=self.generator) * 0.9 + 0.1 # [0.1, 1.0] + # Randomly disable top-p for some rows (~25%) + disable_mask = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0 + p.masked_fill_(disable_mask, 1.0) + + self._compare_results(logits, k=None, p=p) + + @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024]) + @pytest.mark.parametrize("vocab_size", [1024, 32000, 128256]) + def test_topk_and_topp(self, batch_size: int, vocab_size: int): + """Test combined top-k and top-p.""" + logits = torch.randn( + batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + ) + k = torch.randint( + 1, min(100, vocab_size), (batch_size,), generator=self.generator + ) + p = torch.rand(batch_size, generator=self.generator) * 0.9 + 0.1 # [0.1, 1.0] + + # 
Randomly disable top-k for some rows (~25%) + disable_k = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0 + k.masked_fill_(disable_k, vocab_size) + # Randomly disable top-p for some rows (~25%) + disable_p = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0 + p.masked_fill_(disable_p, 1.0) + + self._compare_results(logits, k, p) + + def test_both_disabled(self): + """Test when both k and p are None (should be no-op).""" + from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton + + logits = torch.randn(32, 1024, generator=self.generator, dtype=torch.float32) + logits_clone = logits.clone() + + result = apply_top_k_top_p_triton(logits_clone, k=None, p=None) + + assert torch.equal(result, logits), "Should be no-op when both k and p are None" + + def test_extreme_k_values(self): + """Test edge cases for k values.""" + batch_size, vocab_size = 16, 1024 + logits = torch.randn( + batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + ) + + # k=1 (keep only top 1) + k = torch.ones(batch_size, dtype=torch.int32) + self._compare_results(logits.clone(), k, p=None) + + # k=vocab_size (keep all) + k = torch.full((batch_size,), vocab_size, dtype=torch.int32) + self._compare_results(logits.clone(), k, p=None) + + # Mixed extreme values + k = torch.tensor([1, vocab_size, 2, vocab_size - 1] * 4, dtype=torch.int32) + self._compare_results(logits.clone(), k, p=None) + + def test_extreme_p_values(self): + """Test edge cases for p values.""" + batch_size, vocab_size = 16, 1024 + logits = torch.randn( + batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + ) + + # p close to 0 (very restrictive) + p = torch.full((batch_size,), 0.01, dtype=torch.float32) + self._compare_results(logits.clone(), k=None, p=p) + + # p=1.0 (keep all) + p = torch.ones(batch_size, dtype=torch.float32) + self._compare_results(logits.clone(), k=None, p=p) + + # Mixed values + p = torch.tensor([0.1, 0.5, 0.9, 1.0] * 4, 
dtype=torch.float32) + self._compare_results(logits.clone(), k=None, p=p) + + def test_large_batch(self): + """Test with a large batch size.""" + batch_size, vocab_size = 512, 32000 + logits = torch.randn( + batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + ) + k = torch.randint(1, 50, (batch_size,), generator=self.generator) + p = torch.rand(batch_size, generator=self.generator) * 0.5 + 0.5 + + self._compare_results(logits, k, p) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 03da3e565e49..8459667ef2ea 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -11,6 +11,7 @@ from vllm.config.model import LogprobsMode from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform +from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton logger = init_logger(__name__) @@ -87,8 +88,6 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None: else: self.forward = self.forward_native - self.apply_top_k_top_p = apply_top_k_top_p - def forward_native( self, logits: torch.Tensor, @@ -101,7 +100,7 @@ def forward_native( The logits tensor may be updated in-place. """ - logits = self.apply_top_k_top_p(logits, k, p) + logits = apply_top_k_top_p(logits, k, p) logits_to_return = None if self.logprobs_mode == "processed_logits": logits_to_return = logits @@ -149,7 +148,7 @@ def forward_cpu( The logits tensor may be updated in-place. 
""" - logits = self.apply_top_k_top_p(logits, k, p) + logits = apply_top_k_top_pytorch(logits, k, p, allow_cpu_sync=True) logits_to_return = None if self.logprobs_mode == "processed_logits": logits_to_return = logits @@ -158,14 +157,14 @@ def forward_cpu( if len(generators) != logits.shape[0]: return compiled_random_sample(logits), logits_to_return - else: - probs = logits.softmax(dim=-1, dtype=torch.float32) - q = torch.empty_like(probs) - q.exponential_() - for i, generator in generators.items(): - q[i].exponential_(generator=generator) - return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return + probs = logits.softmax(dim=-1, dtype=torch.float32) + q = torch.empty_like(probs) + q.exponential_() + for i, generator in generators.items(): + q[i].exponential_(generator=generator) + + return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return def forward_hip( self, @@ -241,9 +240,26 @@ def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor: def apply_top_k_top_p( + logits: torch.Tensor, k: torch.Tensor | None, p: torch.Tensor | None +) -> torch.Tensor: + if p is None and k is None: + return logits + + # Rough empirical heuristic + batch_size, vocab_size = logits.shape + both_k_and_p = p is not None and k is not None + threshold = vocab_size // (1024 if both_k_and_p else 2048) + if batch_size < threshold: + # Use pytorch sort implementation for smaller batch sizes. + return apply_top_k_top_pytorch(logits, k, p) + return apply_top_k_top_p_triton(logits, k, p) + + +def apply_top_k_top_pytorch( logits: torch.Tensor, k: torch.Tensor | None, p: torch.Tensor | None, + allow_cpu_sync: bool = False, ) -> torch.Tensor: """Apply top-k and top-p masks to the logits. @@ -256,8 +272,9 @@ def apply_top_k_top_p( if k is None: return logits - # Avoid sorting vocab for top-k only case. - return apply_top_k_only(logits, k) + if allow_cpu_sync: + # Avoid sorting vocab for top-k only case. 
+ return apply_top_k_only(logits, k) logits_sort, logits_idx = logits.sort(dim=-1, descending=False) @@ -279,18 +296,16 @@ def apply_top_k_top_p( logits_sort.masked_fill_(top_p_mask, -float("inf")) # Re-sort the probabilities. - logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) - return logits + return logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) -def apply_top_k_only( - logits: torch.Tensor, - k: torch.Tensor, -) -> torch.Tensor: +def apply_top_k_only(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor: """ Apply top-k mask to the logits. This implementation doesn't involve sorting the entire vocab. + Note however that it involves a GPU->CPU sync which can be detrimental for + async scheduling performance. The logits tensor may be updated in-place. """ @@ -304,8 +319,7 @@ def apply_top_k_only( top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long()) # Handle non-topk rows. top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf")) - logits.masked_fill_(logits < top_k_mask, -float("inf")) - return logits + return logits.masked_fill_(logits < top_k_mask, -float("inf")) def random_sample( diff --git a/vllm/v1/sample/ops/topk_topp_triton.py b/vllm/v1/sample/ops/topk_topp_triton.py new file mode 100644 index 000000000000..7206a2356bd8 --- /dev/null +++ b/vllm/v1/sample/ops/topk_topp_triton.py @@ -0,0 +1,377 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Combined Top-K and Top-P Triton kernels. + +These kernels apply top-k filtering first, then top-p on the remaining values. +This is more efficient than sorting the entire vocabulary. + +Algorithm: +1. Find k-th largest logit using binary search → top-k threshold +2. Mask logits below threshold, compute softmax (only k values contribute) +3. Find probability threshold for top-p using binary search +4. 
Apply final mask + +Complexity: O(vocab_size * (k_iters + p_iters)) where iters ≈ 16-20 +""" + +import torch + +from vllm.triton_utils import tl, triton + + +@triton.jit +def _topk_topp_kernel( + # Input/output + logits_ptr, + # Parameters per row + k_ptr, + p_ptr, + # Dimensions + logits_stride: tl.constexpr, + vocab_size: tl.constexpr, + # Mask value + mask_value: tl.constexpr, + # Algorithm parameters + BLOCK_SIZE: tl.constexpr, + K_ITERS: tl.constexpr, + P_ITERS: tl.constexpr, + # Feature flags (when False, use default values instead of loading) + TOPK_ENABLED: tl.constexpr, + TOPP_ENABLED: tl.constexpr, +): + """ + Combined top-k and top-p masking kernel. + + Applies top-k first (by logit value), then top-p (by probability). + Optimized to skip softmax computation when p >= 1.0. + """ + row_idx = tl.program_id(0) + + if TOPK_ENABLED: # noqa: SIM108 + k = tl.load(k_ptr + row_idx) + else: + k = vocab_size # Default: keep all (no top-k filtering) + + if TOPP_ENABLED: # noqa: SIM108 + p = tl.load(p_ptr + row_idx) + else: + p = 1.0 # Default: keep all (no top-p filtering) + + row_ptr = logits_ptr + row_idx * logits_stride + + # Determine which operations to apply + apply_topk = k < vocab_size + apply_topp = p < 1.0 + + # Early exit if nothing to do + if (not apply_topk) and (not apply_topp): + return + + # ========================================================================= + # Phase 1: Find top-k threshold using binary search on logits + # OPTIMIZATION: Fuse min/max finding with first binary search iteration + # by counting values > 0 during min/max pass (saves 1 memory pass) + # ========================================================================= + + topk_threshold = float("-inf") + + if apply_topk: + # Fused pass: find min/max AND count values > 0 (first binary search step) + max_logit = float("-inf") + min_logit = float("inf") + count_above_zero = tl.zeros([1], dtype=tl.int32) + + for i in range(0, vocab_size, BLOCK_SIZE): + offs = i + tl.arange(0, 
BLOCK_SIZE) + mask = offs < vocab_size + vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf")) + max_logit = tl.maximum(max_logit, tl.max(vals)) + vals_min = tl.where(mask, vals, float("inf")) + min_logit = tl.minimum(min_logit, tl.min(vals_min)) + # Count values > 0 (fused first binary search iteration) + count_above_zero += tl.sum((vals > 0.0).to(tl.int32)) + + # Use count_above_zero to set initial bounds (equivalent to first iteration) + # If count_above_zero >= k, the k-th largest is > 0, so raise lo to 0 + # Otherwise, the k-th largest is <= 0, so lower hi to 0 + if tl.sum(count_above_zero) >= k: + lo = 0.0 + hi = max_logit + else: + lo = min_logit + hi = 0.0 + + # Continue with remaining K_ITERS-1 binary search iterations + for _ in range(K_ITERS - 1): + mid = (lo + hi) * 0.5 + count_gt = tl.zeros([1], dtype=tl.int32) + for i in range(0, vocab_size, BLOCK_SIZE): + offs = i + tl.arange(0, BLOCK_SIZE) + mask = offs < vocab_size + vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf")) + count_gt += tl.sum((vals > mid).to(tl.int32)) + if tl.sum(count_gt) >= k: + lo = mid + else: + hi = mid + + # Refine to exact k-th largest value. 
+ count_gt_lo = tl.zeros([1], dtype=tl.int32) + min_above_lo = float("inf") + max_at_or_below_hi = float("-inf") + for i in range(0, vocab_size, BLOCK_SIZE): + offs = i + tl.arange(0, BLOCK_SIZE) + mask = offs < vocab_size + vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf")) + count_gt_lo += tl.sum((vals > lo).to(tl.int32)) + vals_above_lo = tl.where(vals > lo, vals, float("inf")) + min_above_lo = tl.minimum(min_above_lo, tl.min(vals_above_lo)) + vals_at_or_below_hi = tl.where(vals <= hi, vals, float("-inf")) + max_at_or_below_hi = tl.maximum( + max_at_or_below_hi, tl.max(vals_at_or_below_hi) + ) + + if tl.sum(count_gt_lo) == k: + topk_threshold = min_above_lo + else: + topk_threshold = max_at_or_below_hi + + # ========================================================================= + # If no top-p, apply top-k mask and return early + # ========================================================================= + + if not apply_topp: + for i in range(0, vocab_size, BLOCK_SIZE): + offs = i + tl.arange(0, BLOCK_SIZE) + mask = offs < vocab_size + vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf")) + result = tl.where(vals >= topk_threshold, vals, mask_value) + tl.store(row_ptr + offs, result, mask=mask) + return + + # ========================================================================= + # Phase 2: Compute softmax using online softmax (single pass) + # ========================================================================= + # Online softmax computes max and exp_sum in one pass by rescaling + # the running sum when a new max is found. + # + # Key insight: We need to handle the case where softmax_max is -inf + # (no valid values seen yet). In this case, -inf - (-inf) = nan, + # so we must skip blocks with no valid values. 
+ + softmax_max = float("-inf") + exp_sum = 0.0 + + for i in range(0, vocab_size, BLOCK_SIZE): + offs = i + tl.arange(0, BLOCK_SIZE) + mask = offs < vocab_size + vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf")) + # Apply top-k mask + vals = tl.where(vals >= topk_threshold, vals, float("-inf")) + + # Find block max + block_max = tl.max(vals) + + # Skip blocks with no valid values (all -inf) + # This avoids nan from -inf - (-inf) + if block_max > float("-inf"): + # Update running max and rescale sum if needed + new_max = tl.maximum(softmax_max, block_max) + + # Rescale previous sum: sum * exp(old_max - new_max) + # When softmax_max is -inf (first valid block), exp(-inf - finite) = 0, + # which is correct since exp_sum starts at 0. + exp_sum = exp_sum * tl.exp(softmax_max - new_max) + softmax_max = new_max + + # Add current block's contribution (normalized by new max) + exp_sum += tl.sum(tl.exp(vals - softmax_max)) + + log_exp_sum = tl.log(exp_sum) + + # ========================================================================= + # Phase 3: Find top-p threshold using binary search on probabilities + # OPTIMIZATION: Fuse min/max finding with first binary search iteration + # by computing prob mass > 0.5 during min/max pass (saves 1 memory pass) + # ========================================================================= + + # Fused pass: find min/max log-probs AND sum probs > 0.5 (first iteration) + max_log_prob = float("-inf") + min_log_prob = float("inf") + log_half = -0.6931471805599453 # log(0.5) + prob_sum_above_half = 0.0 + + for i in range(0, vocab_size, BLOCK_SIZE): + offs = i + tl.arange(0, BLOCK_SIZE) + mask = offs < vocab_size + vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf")) + + # Only consider top-k values + is_topk = vals >= topk_threshold + + # log_prob = logit - softmax_max - log(exp_sum) + log_probs = vals - softmax_max - log_exp_sum + + log_probs_masked = tl.where(is_topk, log_probs, float("-inf")) + max_log_prob = 
tl.maximum(max_log_prob, tl.max(log_probs_masked)) + + log_probs_for_min = tl.where(is_topk & mask, log_probs, float("inf")) + min_log_prob = tl.minimum(min_log_prob, tl.min(log_probs_for_min)) + + # Sum probability mass above 0.5 (fused first binary search iteration) + probs = tl.exp(log_probs) + above_half = (log_probs > log_half) & is_topk + prob_sum_above_half += tl.sum(tl.where(above_half, probs, 0.0)) + + # Use prob_sum_above_half to set initial bounds (equivalent to first iteration) + if prob_sum_above_half >= p: + lo_lp = log_half + hi_lp = max_log_prob + else: + lo_lp = min_log_prob + hi_lp = log_half + + # Continue with remaining P_ITERS-1 binary search iterations + for _ in range(P_ITERS - 1): + mid_lp = (lo_lp + hi_lp) * 0.5 + + # Sum probabilities strictly > mid_lp + prob_sum_gt = 0.0 + for i in range(0, vocab_size, BLOCK_SIZE): + offs = i + tl.arange(0, BLOCK_SIZE) + mask = offs < vocab_size + vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf")) + + is_topk = vals >= topk_threshold + log_probs = vals - softmax_max - log_exp_sum + probs = tl.exp(log_probs) + + # Only sum probs that are strictly > threshold and in top-k + above_threshold = (log_probs > mid_lp) & is_topk + prob_sum_gt += tl.sum(tl.where(above_threshold, probs, 0.0)) + + # If sum of probs strictly above mid >= p, raise threshold + if prob_sum_gt >= p: + lo_lp = mid_lp + else: + hi_lp = mid_lp + + # Refine to exact threshold using combined approach (same as top-k). + # After binary search: prob_sum(> lo_lp) >= p, prob_sum(> hi_lp) < p. + # Count how many distinct log-probs are > lo_lp to determine which refinement. 
+ count_gt_lo_lp = tl.zeros([1], dtype=tl.int32) + min_lp_above_lo = float("inf") + max_lp_at_or_below_hi = float("-inf") + for i in range(0, vocab_size, BLOCK_SIZE): + offs = i + tl.arange(0, BLOCK_SIZE) + mask = offs < vocab_size + vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf")) + + is_topk = vals >= topk_threshold + log_probs = vals - softmax_max - log_exp_sum + + above_lo = is_topk & (log_probs > lo_lp) + count_gt_lo_lp += tl.sum(above_lo.to(tl.int32)) + + lp_above_lo = tl.where(above_lo, log_probs, float("inf")) + min_lp_above_lo = tl.minimum(min_lp_above_lo, tl.min(lp_above_lo)) + + at_or_below_hi = is_topk & (log_probs <= hi_lp) + lp_at_or_below_hi = tl.where(at_or_below_hi, log_probs, float("-inf")) + max_lp_at_or_below_hi = tl.maximum( + max_lp_at_or_below_hi, tl.max(lp_at_or_below_hi) + ) + + # For top-p, use min if there are values > lo, otherwise use max. + # This handles edge cases where lo/hi converge to the same side. + if tl.sum(count_gt_lo_lp) > 0 and min_lp_above_lo < float("inf"): + topp_log_threshold = min_lp_above_lo + else: + topp_log_threshold = max_lp_at_or_below_hi + + # ========================================================================= + # Phase 4: Apply combined mask + # ========================================================================= + + for i in range(0, vocab_size, BLOCK_SIZE): + offs = i + tl.arange(0, BLOCK_SIZE) + mask = offs < vocab_size + vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf")) + + # Apply top-k mask + keep = vals >= topk_threshold + + # Apply top-p mask + log_probs = vals - softmax_max - log_exp_sum + keep = keep & (log_probs >= topp_log_threshold) + + result = tl.where(keep, vals, mask_value) + tl.store(row_ptr + offs, result, mask=mask) + + +def apply_top_k_top_p_triton( + logits: torch.Tensor, + k: torch.Tensor | None, + p: torch.Tensor | None, + mask_value: float = float("-inf"), +) -> torch.Tensor: + """ + Apply combined top-k and top-p masking using Triton. 
+ + Top-k is applied first (by logit value), then top-p is applied + to the remaining k values (by probability). + + Args: + logits: [n, vocab_size] float32 tensor, modified in-place + k: [n] int32 tensor of top-k values per row, or None to disable top-k + p: [n] float32 tensor of top-p values per row (0 to 1), + or None to disable top-p + mask_value: Value for masked positions (default: -inf) + + Returns: + The logits tensor (modified in-place) + """ + assert logits.ndim == 2 + assert logits.dtype == torch.float32 + assert logits.is_cuda + + n, vocab_size = logits.shape + + topk_enabled = k is not None + topp_enabled = p is not None + + if n == 0 or (k is None and p is None): + return logits + + if k is not None: + assert k.ndim == 1 and k.shape[0] == n and k.is_cuda + k_ptr = k.to(torch.int32) + else: + k_ptr = logits # Dummy pointer (won't be read) + + if p is not None: + assert p.ndim == 1 and p.shape[0] == n and p.is_cuda + p_ptr = p.to(torch.float32) + else: + p_ptr = logits # Dummy pointer (won't be read) + + BLOCK_SIZE = 1024 + K_ITERS = 16 + P_ITERS = 16 # TODO or 12 + + _topk_topp_kernel[(n,)]( + logits, + k_ptr, + p_ptr, + logits_stride=logits.stride(0), + vocab_size=vocab_size, + mask_value=mask_value, + BLOCK_SIZE=BLOCK_SIZE, + K_ITERS=K_ITERS, + P_ITERS=P_ITERS, + TOPK_ENABLED=topk_enabled, + TOPP_ENABLED=topp_enabled, + ) + + return logits From 5a241a69304a8e397bbfb85e19bfc243130069c6 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sun, 18 Jan 2026 21:48:28 -0800 Subject: [PATCH 76/99] fix doc Signed-off-by: Nick Hill --- vllm/v1/sample/ops/topk_topp_triton.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_triton.py b/vllm/v1/sample/ops/topk_topp_triton.py index 7206a2356bd8..11b1f015323a 100644 --- a/vllm/v1/sample/ops/topk_topp_triton.py +++ b/vllm/v1/sample/ops/topk_topp_triton.py @@ -48,26 +48,28 @@ def _topk_topp_kernel( """ row_idx = tl.program_id(0) - if TOPK_ENABLED: 
# noqa: SIM108 + if TOPK_ENABLED: k = tl.load(k_ptr + row_idx) + apply_topk = k < vocab_size else: - k = vocab_size # Default: keep all (no top-k filtering) + # Default: keep all (no top-k filtering) + k = vocab_size + apply_topk = False - if TOPP_ENABLED: # noqa: SIM108 + if TOPP_ENABLED: p = tl.load(p_ptr + row_idx) + apply_topp = p < 1.0 else: - p = 1.0 # Default: keep all (no top-p filtering) - - row_ptr = logits_ptr + row_idx * logits_stride - - # Determine which operations to apply - apply_topk = k < vocab_size - apply_topp = p < 1.0 + # Default: keep all (no top-p filtering) + p = 1.0 + apply_topp = False # Early exit if nothing to do if (not apply_topk) and (not apply_topp): return + row_ptr = logits_ptr + row_idx * logits_stride + # ========================================================================= # Phase 1: Find top-k threshold using binary search on logits # OPTIMIZATION: Fuse min/max finding with first binary search iteration @@ -326,7 +328,7 @@ def apply_top_k_top_p_triton( logits: [n, vocab_size] float32 tensor, modified in-place k: [n] int32 tensor of top-k values per row, or None to disable top-k p: [n] float32 tensor of top-p values per row (0 to 1), - or None to disable top-p + or None to disable top-p mask_value: Value for masked positions (default: -inf) Returns: @@ -341,7 +343,7 @@ def apply_top_k_top_p_triton( topk_enabled = k is not None topp_enabled = p is not None - if n == 0 or (k is None and p is None): + if n == 0 or not (topk_enabled or topp_enabled): return logits if k is not None: From b017713d1681dcb681c1241a54b46e58ecc27163 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 20 Jan 2026 17:35:37 -0800 Subject: [PATCH 77/99] fix method name, only use triton when supported Signed-off-by: Nick Hill --- benchmarks/benchmark_topk_topp.py | 4 ++-- tests/v1/sample/test_topk_topp_sampler.py | 10 ++++----- vllm/v1/sample/ops/topk_topp_sampler.py | 25 ++++++++++++++--------- 3 files changed, 22 insertions(+), 17 deletions(-) diff 
--git a/benchmarks/benchmark_topk_topp.py b/benchmarks/benchmark_topk_topp.py index 6191f2902c73..dae0458d01ee 100644 --- a/benchmarks/benchmark_topk_topp.py +++ b/benchmarks/benchmark_topk_topp.py @@ -20,7 +20,7 @@ import torch -from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_pytorch +from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton @@ -296,7 +296,7 @@ def run_benchmark( # Benchmark PyTorch reset_memory_stats() pytorch_time, pytorch_mem = benchmark_function( - apply_top_k_top_pytorch, + apply_top_k_top_p_pytorch, logits, config.k_values, config.p_values, diff --git a/tests/v1/sample/test_topk_topp_sampler.py b/tests/v1/sample/test_topk_topp_sampler.py index 46b551501bd2..257c8735c461 100644 --- a/tests/v1/sample/test_topk_topp_sampler.py +++ b/tests/v1/sample/test_topk_topp_sampler.py @@ -5,7 +5,7 @@ from torch import Generator from vllm.platforms import current_platform -from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_pytorch +from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch CUDA_DEVICE = "cuda" if current_platform.is_cuda() else None DEVICE = current_platform.device_type @@ -40,11 +40,11 @@ def test_topk_impl_equivalence(): ) # Top-k only implementation - result1 = apply_top_k_top_pytorch(logits=logits.clone(), k=k, p=None) + result1 = apply_top_k_top_p_pytorch(logits=logits.clone(), k=k, p=None) # Top-p + top-k no_op_top_p = torch.tensor([1.0]) - result2 = apply_top_k_top_pytorch(logits=logits.clone(), k=k, p=no_op_top_p) + result2 = apply_top_k_top_p_pytorch(logits=logits.clone(), k=k, p=no_op_top_p) assert torch.allclose(result1, result2) @@ -94,7 +94,7 @@ def test_flashinfer_sampler(): torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool), 1.0 ) - python_logits = apply_top_k_top_pytorch( + python_logits = apply_top_k_top_p_pytorch( logits=logits.clone(), k=k_values, p=p_values, @@ 
-152,7 +152,7 @@ def _compare_results( logits_triton = logits.clone().to(torch.float32) # Apply PyTorch sorting implementation - result_pytorch = apply_top_k_top_pytorch(logits_pytorch, k, p) + result_pytorch = apply_top_k_top_p_pytorch(logits_pytorch, k, p) # Apply Triton kernel k_i32 = k.to(torch.int32) if k is not None else None diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 8459667ef2ea..eaf6dcc3c520 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -11,7 +11,10 @@ from vllm.config.model import LogprobsMode from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform -from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton logger = init_logger(__name__) @@ -148,7 +151,7 @@ def forward_cpu( The logits tensor may be updated in-place. """ - logits = apply_top_k_top_pytorch(logits, k, p, allow_cpu_sync=True) + logits = apply_top_k_top_p_pytorch(logits, k, p, allow_cpu_sync=True) logits_to_return = None if self.logprobs_mode == "processed_logits": logits_to_return = logits @@ -246,16 +249,18 @@ def apply_top_k_top_p( return logits # Rough empirical heuristic - batch_size, vocab_size = logits.shape - both_k_and_p = p is not None and k is not None - threshold = vocab_size // (1024 if both_k_and_p else 2048) - if batch_size < threshold: - # Use pytorch sort implementation for smaller batch sizes. - return apply_top_k_top_pytorch(logits, k, p) - return apply_top_k_top_p_triton(logits, k, p) + if HAS_TRITON: + batch_size, vocab_size = logits.shape + both_k_and_p = p is not None and k is not None + threshold = vocab_size // (1024 if both_k_and_p else 2048) + if batch_size >= threshold: + # Use the Triton kernel for larger batch sizes. 
+ return apply_top_k_top_p_triton(logits, k, p) + + return apply_top_k_top_p_pytorch(logits, k, p) -def apply_top_k_top_pytorch( +def apply_top_k_top_p_pytorch( logits: torch.Tensor, k: torch.Tensor | None, p: torch.Tensor | None, From e067cbfd0b5044d77270612b9760f98fd408b28c Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 23 Jan 2026 08:53:25 -0800 Subject: [PATCH 78/99] fix precision Signed-off-by: Nick Hill --- vllm/v1/sample/ops/topk_topp_triton.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_triton.py b/vllm/v1/sample/ops/topk_topp_triton.py index 11b1f015323a..ac759944bc80 100644 --- a/vllm/v1/sample/ops/topk_topp_triton.py +++ b/vllm/v1/sample/ops/topk_topp_triton.py @@ -119,6 +119,8 @@ def _topk_topp_kernel( hi = mid # Refine to exact k-th largest value. + # After binary search: lo < k-th value <= hi (approximately). + # Find the actual logit values at these boundaries. count_gt_lo = tl.zeros([1], dtype=tl.int32) min_above_lo = float("inf") max_at_or_below_hi = float("-inf") @@ -359,8 +361,10 @@ def apply_top_k_top_p_triton( p_ptr = logits # Dummy pointer (won't be read) BLOCK_SIZE = 1024 - K_ITERS = 16 - P_ITERS = 16 # TODO or 12 + # K_ITERS must be large enough to distinguish adjacent logit values. 
+ # With randn logits (range ~8), 20 iterations gives precision ~8/2^19 ≈ 1.5e-5 + K_ITERS = 18 + P_ITERS = 14 _topk_topp_kernel[(n,)]( logits, From 463afa652826fdb536970d442b2b4ed1c477cfd7 Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 1 Feb 2026 17:54:46 -0800 Subject: [PATCH 79/99] Copied topk + topp impl Signed-off-by: js_park --- tests/v1/sample/test_topk_topp_sampler.py | 176 ++--- vllm/v1/sample/ops/topk_topp_triton.py | 857 ++++++++++++++-------- 2 files changed, 642 insertions(+), 391 deletions(-) diff --git a/tests/v1/sample/test_topk_topp_sampler.py b/tests/v1/sample/test_topk_topp_sampler.py index fc6412949dd9..6c19bc179327 100644 --- a/tests/v1/sample/test_topk_topp_sampler.py +++ b/tests/v1/sample/test_topk_topp_sampler.py @@ -186,35 +186,35 @@ def _compare_results( f"(max diff {max_diff} values out of {max_kept})" ) - @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024]) - @pytest.mark.parametrize("vocab_size", [1024, 32000, 128256]) - def test_topk_only(self, batch_size: int, vocab_size: int): - """Test top-k only (p=None).""" - logits = torch.randn( - batch_size, vocab_size, generator=self.generator, dtype=torch.float32 - ) - k = torch.randint( - 1, min(100, vocab_size), (batch_size,), generator=self.generator - ) - # Randomly disable top-k for some rows (~25%) - disable_mask = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0 - k.masked_fill_(disable_mask, vocab_size) - - self._compare_results(logits, k, p=None) - - @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024]) - @pytest.mark.parametrize("vocab_size", [1024, 32000, 128256]) - def test_topp_only(self, batch_size: int, vocab_size: int): - """Test top-p only (k=None).""" - logits = torch.randn( - batch_size, vocab_size, generator=self.generator, dtype=torch.float32 - ) - p = torch.rand(batch_size, generator=self.generator) * 0.9 + 0.1 # [0.1, 1.0] - # Randomly disable top-p for some rows (~25%) - disable_mask = torch.randint(0, 4, (batch_size,), 
generator=self.generator) == 0 - p.masked_fill_(disable_mask, 1.0) - - self._compare_results(logits, k=None, p=p) + # @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024]) + # @pytest.mark.parametrize("vocab_size", [1024, 32000, 128256]) + # def test_topk_only(self, batch_size: int, vocab_size: int): + # """Test top-k only (p=None).""" + # logits = torch.randn( + # batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + # ) + # k = torch.randint( + # 1, min(100, vocab_size), (batch_size,), generator=self.generator + # ) + # # Randomly disable top-k for some rows (~25%) + # disable_mask = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0 + # k.masked_fill_(disable_mask, vocab_size) + + # self._compare_results(logits, k, p=None) + + # @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024]) + # @pytest.mark.parametrize("vocab_size", [1024, 32000, 128256]) + # def test_topp_only(self, batch_size: int, vocab_size: int): + # """Test top-p only (k=None).""" + # logits = torch.randn( + # batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + # ) + # p = torch.rand(batch_size, generator=self.generator) * 0.9 + 0.1 # [0.1, 1.0] + # # Randomly disable top-p for some rows (~25%) + # disable_mask = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0 + # p.masked_fill_(disable_mask, 1.0) + + # self._compare_results(logits, k=None, p=p) @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024]) @pytest.mark.parametrize("vocab_size", [1024, 32000, 128256]) @@ -237,62 +237,62 @@ def test_topk_and_topp(self, batch_size: int, vocab_size: int): self._compare_results(logits, k, p) - def test_both_disabled(self): - """Test when both k and p are None (should be no-op).""" - from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton - - logits = torch.randn(32, 1024, generator=self.generator, dtype=torch.float32) - logits_clone = logits.clone() - - result = 
apply_top_k_top_p_triton(logits_clone, k=None, p=None) - - assert torch.equal(result, logits), "Should be no-op when both k and p are None" - - def test_extreme_k_values(self): - """Test edge cases for k values.""" - batch_size, vocab_size = 16, 1024 - logits = torch.randn( - batch_size, vocab_size, generator=self.generator, dtype=torch.float32 - ) - - # k=1 (keep only top 1) - k = torch.ones(batch_size, dtype=torch.int32) - self._compare_results(logits.clone(), k, p=None) - - # k=vocab_size (keep all) - k = torch.full((batch_size,), vocab_size, dtype=torch.int32) - self._compare_results(logits.clone(), k, p=None) - - # Mixed extreme values - k = torch.tensor([1, vocab_size, 2, vocab_size - 1] * 4, dtype=torch.int32) - self._compare_results(logits.clone(), k, p=None) - - def test_extreme_p_values(self): - """Test edge cases for p values.""" - batch_size, vocab_size = 16, 1024 - logits = torch.randn( - batch_size, vocab_size, generator=self.generator, dtype=torch.float32 - ) - - # p close to 0 (very restrictive) - p = torch.full((batch_size,), 0.01, dtype=torch.float32) - self._compare_results(logits.clone(), k=None, p=p) - - # p=1.0 (keep all) - p = torch.ones(batch_size, dtype=torch.float32) - self._compare_results(logits.clone(), k=None, p=p) - - # Mixed values - p = torch.tensor([0.1, 0.5, 0.9, 1.0] * 4, dtype=torch.float32) - self._compare_results(logits.clone(), k=None, p=p) - - def test_large_batch(self): - """Test with a large batch size.""" - batch_size, vocab_size = 512, 32000 - logits = torch.randn( - batch_size, vocab_size, generator=self.generator, dtype=torch.float32 - ) - k = torch.randint(1, 50, (batch_size,), generator=self.generator) - p = torch.rand(batch_size, generator=self.generator) * 0.5 + 0.5 - - self._compare_results(logits, k, p) + # def test_both_disabled(self): + # """Test when both k and p are None (should be no-op).""" + # from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton + + # logits = torch.randn(32, 1024, 
generator=self.generator, dtype=torch.float32) + # logits_clone = logits.clone() + + # result = apply_top_k_top_p_triton(logits_clone, k=None, p=None) + + # assert torch.equal(result, logits), "Should be no-op when both k and p are None" + + # def test_extreme_k_values(self): + # """Test edge cases for k values.""" + # batch_size, vocab_size = 16, 1024 + # logits = torch.randn( + # batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + # ) + + # # k=1 (keep only top 1) + # k = torch.ones(batch_size, dtype=torch.int32) + # self._compare_results(logits.clone(), k, p=None) + + # # k=vocab_size (keep all) + # k = torch.full((batch_size,), vocab_size, dtype=torch.int32) + # self._compare_results(logits.clone(), k, p=None) + + # # Mixed extreme values + # k = torch.tensor([1, vocab_size, 2, vocab_size - 1] * 4, dtype=torch.int32) + # self._compare_results(logits.clone(), k, p=None) + + # def test_extreme_p_values(self): + # """Test edge cases for p values.""" + # batch_size, vocab_size = 16, 1024 + # logits = torch.randn( + # batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + # ) + + # # p close to 0 (very restrictive) + # p = torch.full((batch_size,), 0.01, dtype=torch.float32) + # self._compare_results(logits.clone(), k=None, p=p) + + # # p=1.0 (keep all) + # p = torch.ones(batch_size, dtype=torch.float32) + # self._compare_results(logits.clone(), k=None, p=p) + + # # Mixed values + # p = torch.tensor([0.1, 0.5, 0.9, 1.0] * 4, dtype=torch.float32) + # self._compare_results(logits.clone(), k=None, p=p) + + # def test_large_batch(self): + # """Test with a large batch size.""" + # batch_size, vocab_size = 512, 32000 + # logits = torch.randn( + # batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + # ) + # k = torch.randint(1, 50, (batch_size,), generator=self.generator) + # p = torch.rand(batch_size, generator=self.generator) * 0.5 + 0.5 + + # self._compare_results(logits, k, p) diff --git 
a/vllm/v1/sample/ops/topk_topp_triton.py b/vllm/v1/sample/ops/topk_topp_triton.py index ac759944bc80..5290169b4f12 100644 --- a/vllm/v1/sample/ops/topk_topp_triton.py +++ b/vllm/v1/sample/ops/topk_topp_triton.py @@ -20,299 +20,526 @@ from vllm.triton_utils import tl, triton +_TRITON_TABLE_CACHE: dict[ + tuple[torch.device, torch.dtype], (torch.Tensor, torch.Tensor) +] = {} +_TRITON_BUFFER_CACHE: dict[ + tuple[torch.device, torch.dtype, int, int], torch.Tensor +] = {} + +_NORMAL_CDF_TO_SIGMA_TABLE = [ + 3.656, 3.650, 3.650, 3.650, 3.626, 3.626, 3.626, 3.514, 3.514, 3.503, + 3.503, 3.434, 3.434, 3.428, 3.428, 3.387, 3.380, 3.380, 3.376, 3.373, + 3.373, 3.356, 3.354, 3.354, 3.291, 3.249, 3.234, 3.214, 3.198, 3.198, + 3.185, 3.177, 3.177, 3.165, 3.164, 3.161, 3.138, 3.120, 3.115, 3.113, + 3.093, 3.066, 3.054, 3.043, 3.037, 3.023, 2.993, 2.991, 2.976, 2.970, + 2.952, 2.946, 2.932, 2.908, 2.902, 2.895, 2.886, 2.874, 2.861, 2.844, + 2.836, 2.810, 2.801, 2.790, 2.784, 2.779, 2.767, 2.757, 2.745, 2.733, + 2.723, 2.716, 2.693, 2.678, 2.671, 2.656, 2.649, 2.629, 2.611, 2.595, + 2.592, 2.585, 2.574, 2.550, 2.543, 2.534, 2.521, 2.518, 2.497, 2.485, + 2.468, 2.450, 2.441, 2.430, 2.412, 2.402, 2.389, 2.383, 2.377, 2.364, + 2.349, 2.338, 2.332, 2.319, 2.310, 2.301, 2.282, 2.274, 2.266, 2.250, + 2.242, 2.236, 2.226, 2.215, 2.207, 2.196, 2.179, 2.171, 2.162, 2.147, + 2.135, 2.121, 2.109, 2.095, 2.085, 2.073, 2.063, 2.045, 2.030, 2.016, + 2.003, 1.992, 1.983, 1.972, 1.960, 1.949, 1.940, 1.928, 1.912, 1.897, + 1.881, 1.869, 1.854, 1.838, 1.824, 1.807, 1.792, 1.779, 1.764, 1.751, + 1.739, 1.726, 1.711, 1.697, 1.685, 1.668, 1.652, 1.636, 1.622, 1.603, + 1.585, 1.568, 1.551, 1.534, 1.513, 1.499, 1.480, 1.464, 1.441, 1.422, + 1.394, 1.373, 1.347, 1.320, 1.296, 1.270, 1.246, 1.219, 1.190, 1.163, + 1.135, 1.104, 1.073, 1.041, 1.006, 0.969, 0.931, 0.894, 0.851, 0.806, + 0.757, 0.702, 0.643, 0.574, 0.498, 0.405, 0.288, 0.134, -0.110, -3.813 +] + +_PERCENTILE_TO_STD_TABLE = [ + 2.576, 2.319, 
2.178, 2.064, 1.968, 1.892, 1.819, 1.757, 1.708, 1.659, + 1.616, 1.568, 1.526, 1.492, 1.456, 1.420, 1.382, 1.342, 1.309, 1.280, + 1.249, 1.221, 1.193, 1.169, 1.145, 1.121, 1.095, 1.073, 1.050, 1.030, + 1.008, 0.987, 0.966, 0.945, 0.926, 0.910, 0.891, 0.871, 0.854, 0.837, + 0.819, 0.803, 0.784, 0.767, 0.753, 0.734, 0.719, 0.702, 0.690, 0.675, + 0.658, 0.640, 0.625, 0.609, 0.595, 0.578, 0.564, 0.550, 0.537, 0.521, + 0.509, 0.495, 0.481, 0.466, 0.453, 0.439, 0.424, 0.410, 0.397, 0.383, + 0.370, 0.356, 0.343, 0.330, 0.316, 0.302, 0.289, 0.274, 0.261, 0.247, + 0.235, 0.223, 0.209, 0.196, 0.184, 0.172, 0.159, 0.149, 0.137, 0.124, + 0.112, 0.100, 0.086, 0.074, 0.062, 0.050, 0.035, 0.023, 0.009, -0.003, + -0.015, -0.027, -0.039, -0.052, -0.063, -0.074, -0.085, -0.097, -0.109, -0.122, + -0.134, -0.147, -0.158, -0.171, -0.184, -0.196, -0.210, -0.223, -0.235, -0.248, + -0.261, -0.275, -0.289, -0.302, -0.317, -0.328, -0.341, -0.353, -0.368, -0.382, + -0.396, -0.410, -0.426, -0.439, -0.452, -0.465, -0.480, -0.493, -0.507, -0.521, + -0.537, -0.551, -0.568, -0.582, -0.597, -0.614, -0.628, -0.643, -0.658, -0.673, + -0.691, -0.706, -0.721, -0.738, -0.754, -0.769, -0.789, -0.808, -0.824, -0.838, + -0.857, -0.877, -0.893, -0.912, -0.929, -0.947, -0.965, -0.983, -1.003, -1.027, + -1.050, -1.070, -1.092, -1.117, -1.139, -1.162, -1.189, -1.216, -1.241, -1.272, + -1.300, -1.330, -1.367, -1.404, -1.441, -1.485, -1.523, -1.564, -1.607, -1.658, + -1.710, -1.778, -1.832, -1.901, -1.978, -2.068, -2.174, -2.325, -2.577, -3.813 +] +# fmt: on + @triton.jit def _topk_topp_kernel( - # Input/output - logits_ptr, - # Parameters per row - k_ptr, - p_ptr, - # Dimensions - logits_stride: tl.constexpr, - vocab_size: tl.constexpr, - # Mask value - mask_value: tl.constexpr, - # Algorithm parameters + LOGITS, + BUFFER, + PERCENTILE_TO_STD_TABLE, + NORMAL_CDF_TO_SIGMA_TABLE, + K, + P, + BATCH_SIZE, + VOCAB_SIZE: tl.constexpr, + MASK_VALUE: tl.constexpr, BLOCK_SIZE: tl.constexpr, - K_ITERS: tl.constexpr, - 
P_ITERS: tl.constexpr, - # Feature flags (when False, use default values instead of loading) + BLOCK_SIZE_TRUNC: tl.constexpr, TOPK_ENABLED: tl.constexpr, TOPP_ENABLED: tl.constexpr, ): - """ - Combined top-k and top-p masking kernel. - - Applies top-k first (by logit value), then top-p (by probability). - Optimized to skip softmax computation when p >= 1.0. - """ - row_idx = tl.program_id(0) - - if TOPK_ENABLED: - k = tl.load(k_ptr + row_idx) - apply_topk = k < vocab_size - else: - # Default: keep all (no top-k filtering) - k = vocab_size - apply_topk = False - - if TOPP_ENABLED: - p = tl.load(p_ptr + row_idx) - apply_topp = p < 1.0 - else: - # Default: keep all (no top-p filtering) - p = 1.0 - apply_topp = False - - # Early exit if nothing to do - if (not apply_topk) and (not apply_topp): - return - - row_ptr = logits_ptr + row_idx * logits_stride - - # ========================================================================= - # Phase 1: Find top-k threshold using binary search on logits - # OPTIMIZATION: Fuse min/max finding with first binary search iteration - # by counting values > 0 during min/max pass (saves 1 memory pass) - # ========================================================================= - - topk_threshold = float("-inf") - - if apply_topk: - # Fused pass: find min/max AND count values > 0 (first binary search step) - max_logit = float("-inf") - min_logit = float("inf") - count_above_zero = tl.zeros([1], dtype=tl.int32) - - for i in range(0, vocab_size, BLOCK_SIZE): - offs = i + tl.arange(0, BLOCK_SIZE) - mask = offs < vocab_size - vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf")) - max_logit = tl.maximum(max_logit, tl.max(vals)) - vals_min = tl.where(mask, vals, float("inf")) - min_logit = tl.minimum(min_logit, tl.min(vals_min)) - # Count values > 0 (fused first binary search iteration) - count_above_zero += tl.sum((vals > 0.0).to(tl.int32)) - - # Use count_above_zero to set initial bounds (equivalent to first iteration) - # If 
count_above_zero >= k, the k-th largest is > 0, so raise lo to 0 - # Otherwise, the k-th largest is <= 0, so lower hi to 0 - if tl.sum(count_above_zero) >= k: - lo = 0.0 - hi = max_logit - else: - lo = min_logit - hi = 0.0 - - # Continue with remaining K_ITERS-1 binary search iterations - for _ in range(K_ITERS - 1): - mid = (lo + hi) * 0.5 - count_gt = tl.zeros([1], dtype=tl.int32) - for i in range(0, vocab_size, BLOCK_SIZE): - offs = i + tl.arange(0, BLOCK_SIZE) - mask = offs < vocab_size - vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf")) - count_gt += tl.sum((vals > mid).to(tl.int32)) - if tl.sum(count_gt) >= k: - lo = mid + NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE + pid = tl.program_id(0) + num_programs = tl.num_programs(0) + for row_id in tl.range(pid, BATCH_SIZE, num_programs): + p = tl.load(P + row_id) + k = tl.load(K + row_id) + if p < 1.0 or k != VOCAB_SIZE: + LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE + BUFFER_ROW = BUFFER + pid * VOCAB_SIZE + + max_logit = -float("inf") + min_logit = float("inf") + + # Zeroth pass: Compute avg and std from a sample block + offs = tl.arange(0, BLOCK_SIZE) + mask_n = offs < VOCAB_SIZE + num_valid = tl.sum(mask_n) + logits_blk0 = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) + avg_logit = tl.sum(logits_blk0) / num_valid + sq_avg_logit = tl.sum(logits_blk0 * logits_blk0) / num_valid + std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) + std_logit = tl.maximum(std_logit, 0.0) + + # Calculate outlier pivot t for Gaussian sigma-truncation + percentile = tl.cast(k / VOCAB_SIZE * 200, tl.uint32) + percentile = tl.minimum(percentile, 199) + sigma = tl.load(PERCENTILE_TO_STD_TABLE + percentile) + sigma = sigma + tl.abs(sigma) * -0.2 + outlier_pivot = avg_logit + std_logit * sigma + num_outliers = tl.zeros((), dtype=tl.uint32) + + # First pass: compute max and min logits and gather outliers + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + 
mask_n = offs_n < VOCAB_SIZE + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=avg_logit) + + max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + + outlier_mask = (logits_blk > outlier_pivot) & mask_n + cumulative_pos = tl.cast( + tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + num_outliers += tl.sum(outlier_mask) + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(BUFFER_ROW + write_pos, logits_blk, mask=outlier_mask) + + # Second passes: Quaternary search for pivots (nlog_4(batch_size)) + num_iters = 0 + k_pivot = float("inf") + k_pivots_num = tl.zeros((), dtype=tl.uint32) + min_larger = float("inf") + num_min_larger = tl.zeros((), dtype=tl.uint32) + if num_outliers > k: + max_range = max_logit + min_range = outlier_pivot + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32 + ) + while k_pivot == float("inf"): + k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) + min_larger_0 = float("inf") + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) + min_larger_1 = float("inf") + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range + k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) + min_larger_2 = float("inf") + num_min_larger_2 = tl.zeros((), dtype=tl.uint32) + + # First pass: Calculate k_pivots_num and min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + logits_blk2 = tl.load( + BUFFER_ROW + offs_n, mask=mask_n_2, other=-float("inf") + ) + + k_pivots_num_0 += tl.sum(logits_blk2 > k_pivot_0) + k_pivots_num_1 += tl.sum(logits_blk2 > k_pivot_1) + k_pivots_num_2 += 
tl.sum(logits_blk2 > k_pivot_2) + + min_larger_0 = tl.minimum(min_larger_0, tl.min(logits_blk2)) + min_larger_1 = tl.minimum(min_larger_1, tl.min(logits_blk2)) + min_larger_2 = tl.minimum(min_larger_2, tl.min(logits_blk2)) + + # Second pass: Calculate num_min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + logits_blk2 = tl.load( + BUFFER_ROW + offs_n, mask=mask_n_2, other=-float("inf") + ) + + num_min_larger_0 += tl.sum(tl.abs(logits_blk2 - min_larger_0) < 1e-12) + num_min_larger_1 += tl.sum(tl.abs(logits_blk2 - min_larger_1) < 1e-12) + num_min_larger_2 += tl.sum(tl.abs(logits_blk2 - min_larger_2) < 1e-12) + + # Check if any of the pivots satisfy termination condition + if k_pivots_num_0 >= k: + if k_pivots_num_0 - (min_larger_0 * num_min_larger_0) < k: + k_pivot = k_pivot_0 + k_pivots_num = k_pivots_num_0 + min_larger = min_larger_0 + num_min_larger = num_min_larger_0 + if k_pivots_num_1 >= k: + if k_pivots_num_1 - (min_larger_1 * num_min_larger_1) < k: + k_pivot = k_pivot_1 + k_pivots_num = k_pivots_num_1 + min_larger = min_larger_1 + num_min_larger = num_min_larger_1 + if k_pivots_num_2 >= k: + if k_pivots_num_2 - (min_larger_2 * num_min_larger_2) < k: + k_pivot = k_pivot_2 + k_pivots_num = k_pivots_num_2 + min_larger = min_larger_2 + num_min_larger = num_min_larger_2 + + # Update range + if k_pivots_num_2 > k: + min_range = k_pivot_2 + elif k_pivots_num_1 > k: + min_range = k_pivot_1 + elif k_pivots_num_0 > k: + min_range = k_pivot_0 + + if k_pivots_num_0 < k: + max_range = k_pivot_0 + elif k_pivots_num_1 < k: + max_range = k_pivot_1 + elif k_pivots_num_2 < k: + max_range = k_pivot_2 + + num_iters += 1 + if num_iters >= 20 or tl.abs(min_range - max_range) < 1e-12: + k_pivot = (max_range + min_range) / 2.0 else: - hi = mid - - # Refine to exact k-th largest value. - # After binary search: lo < k-th value <= hi (approximately). 
- # Find the actual logit values at these boundaries. - count_gt_lo = tl.zeros([1], dtype=tl.int32) - min_above_lo = float("inf") - max_at_or_below_hi = float("-inf") - for i in range(0, vocab_size, BLOCK_SIZE): - offs = i + tl.arange(0, BLOCK_SIZE) - mask = offs < vocab_size - vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf")) - count_gt_lo += tl.sum((vals > lo).to(tl.int32)) - vals_above_lo = tl.where(vals > lo, vals, float("inf")) - min_above_lo = tl.minimum(min_above_lo, tl.min(vals_above_lo)) - vals_at_or_below_hi = tl.where(vals <= hi, vals, float("-inf")) - max_at_or_below_hi = tl.maximum( - max_at_or_below_hi, tl.max(vals_at_or_below_hi) + # If top-k outlier gathering failed, search whole logit space + max_range = max_logit + min_range = min_logit + while k_pivot == float("inf"): + k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) + min_larger_0 = float("inf") + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) + min_larger_1 = float("inf") + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range + k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) + min_larger_2 = float("inf") + num_min_larger_2 = tl.zeros((), dtype=tl.uint32) + + # First pass: Calculate k_pivots_num and min_larger + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk2 = tl.load( + LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") + ) + + k_pivots_num_0 += tl.sum(logits_blk2 > k_pivot_0) + k_pivots_num_1 += tl.sum(logits_blk2 > k_pivot_1) + k_pivots_num_2 += tl.sum(logits_blk2 > k_pivot_2) + + min_larger_0 = tl.minimum(min_larger_0, tl.min(logits_blk2)) + min_larger_1 = tl.minimum(min_larger_1, tl.min(logits_blk2)) + min_larger_2 = tl.minimum(min_larger_2, tl.min(logits_blk2)) + 
+ # Second pass: Calculate num_min_larger + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk2 = tl.load( + LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") + ) + + num_min_larger_0 += tl.sum(tl.abs(logits_blk2 - min_larger_0) < 1e-12) + num_min_larger_1 += tl.sum(tl.abs(logits_blk2 - min_larger_1) < 1e-12) + num_min_larger_2 += tl.sum(tl.abs(logits_blk2 - min_larger_2) < 1e-12) + + # Check if any of the pivots satisfy termination condition + if k_pivots_num_0 >= k: + if k_pivots_num_0 - (min_larger_0 * num_min_larger_0) < k: + k_pivot = k_pivot_0 + k_pivots_num = k_pivots_num_0 + min_larger = min_larger_0 + num_min_larger = num_min_larger_0 + if k_pivots_num_1 >= k: + if k_pivots_num_1 - (min_larger_1 * num_min_larger_1) < k: + k_pivot = k_pivot_1 + k_pivots_num = k_pivots_num_1 + min_larger = min_larger_1 + num_min_larger = num_min_larger_1 + if k_pivots_num_2 >= k: + if k_pivots_num_2 - (min_larger_2 * num_min_larger_2) < k: + k_pivot = k_pivot_2 + k_pivots_num = k_pivots_num_2 + min_larger = min_larger_2 + num_min_larger = num_min_larger_2 + + # Update range + if k_pivots_num_2 > k: + min_range = k_pivot_2 + elif k_pivots_num_1 > k: + min_range = k_pivot_1 + elif k_pivots_num_0 > k: + min_range = k_pivot_0 + + if k_pivots_num_0 < k: + max_range = k_pivot_0 + elif k_pivots_num_1 < k: + max_range = k_pivot_1 + elif k_pivots_num_2 < k: + max_range = k_pivot_2 + + num_iters += 1 + if num_iters >= 20 or tl.abs(min_range - max_range) < 1e-12: + k_pivot = (max_range + min_range) / 2.0 + + duplicate_logit = min_larger + num_duplicate_logit = num_min_larger + num_keep = num_duplicate_logit - (k_pivots_num - k) + num_kept = tl.zeros((), dtype=tl.uint32) + + #### TOP-P SAMPLING #### + + min_logit = k_pivot + sum_exp_logits = 0.0 + num_outliers_2 = tl.zeros((), dtype=tl.uint32) + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE_TRUNC - 1) 
// BLOCK_SIZE_TRUNC, tl.int32 ) - if tl.sum(count_gt_lo) == k: - topk_threshold = min_above_lo - else: - topk_threshold = max_at_or_below_hi - - # ========================================================================= - # If no top-p, apply top-k mask and return early - # ========================================================================= - - if not apply_topp: - for i in range(0, vocab_size, BLOCK_SIZE): - offs = i + tl.arange(0, BLOCK_SIZE) - mask = offs < vocab_size - vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf")) - result = tl.where(vals >= topk_threshold, vals, mask_value) - tl.store(row_ptr + offs, result, mask=mask) - return - - # ========================================================================= - # Phase 2: Compute softmax using online softmax (single pass) - # ========================================================================= - # Online softmax computes max and exp_sum in one pass by rescaling - # the running sum when a new max is found. - # - # Key insight: We need to handle the case where softmax_max is -inf - # (no valid values seen yet). In this case, -inf - (-inf) = nan, - # so we must skip blocks with no valid values. - - softmax_max = float("-inf") - exp_sum = 0.0 - - for i in range(0, vocab_size, BLOCK_SIZE): - offs = i + tl.arange(0, BLOCK_SIZE) - mask = offs < vocab_size - vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf")) - # Apply top-k mask - vals = tl.where(vals >= topk_threshold, vals, float("-inf")) - - # Find block max - block_max = tl.max(vals) - - # Skip blocks with no valid values (all -inf) - # This avoids nan from -inf - (-inf) - if block_max > float("-inf"): - # Update running max and rescale sum if needed - new_max = tl.maximum(softmax_max, block_max) - - # Rescale previous sum: sum * exp(old_max - new_max) - # When softmax_max is -inf (first valid block), exp(-inf - finite) = 0, - # which is correct since exp_sum starts at 0. 
- exp_sum = exp_sum * tl.exp(softmax_max - new_max) - softmax_max = new_max - - # Add current block's contribution (normalized by new max) - exp_sum += tl.sum(tl.exp(vals - softmax_max)) - - log_exp_sum = tl.log(exp_sum) - - # ========================================================================= - # Phase 3: Find top-p threshold using binary search on probabilities - # OPTIMIZATION: Fuse min/max finding with first binary search iteration - # by computing prob mass > 0.5 during min/max pass (saves 1 memory pass) - # ========================================================================= - - # Fused pass: find min/max log-probs AND sum probs > 0.5 (first iteration) - max_log_prob = float("-inf") - min_log_prob = float("inf") - log_half = -0.6931471805599453 # log(0.5) - prob_sum_above_half = 0.0 - - for i in range(0, vocab_size, BLOCK_SIZE): - offs = i + tl.arange(0, BLOCK_SIZE) - mask = offs < vocab_size - vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf")) - - # Only consider top-k values - is_topk = vals >= topk_threshold - - # log_prob = logit - softmax_max - log(exp_sum) - log_probs = vals - softmax_max - log_exp_sum - - log_probs_masked = tl.where(is_topk, log_probs, float("-inf")) - max_log_prob = tl.maximum(max_log_prob, tl.max(log_probs_masked)) - - log_probs_for_min = tl.where(is_topk & mask, log_probs, float("inf")) - min_log_prob = tl.minimum(min_log_prob, tl.min(log_probs_for_min)) - - # Sum probability mass above 0.5 (fused first binary search iteration) - probs = tl.exp(log_probs) - above_half = (log_probs > log_half) & is_topk - prob_sum_above_half += tl.sum(tl.where(above_half, probs, 0.0)) - - # Use prob_sum_above_half to set initial bounds (equivalent to first iteration) - if prob_sum_above_half >= p: - lo_lp = log_half - hi_lp = max_log_prob - else: - lo_lp = min_log_prob - hi_lp = log_half - - # Continue with remaining P_ITERS-1 binary search iterations - for _ in range(P_ITERS - 1): - mid_lp = (lo_lp + hi_lp) * 0.5 - - # Sum 
probabilities strictly > mid_lp - prob_sum_gt = 0.0 - for i in range(0, vocab_size, BLOCK_SIZE): - offs = i + tl.arange(0, BLOCK_SIZE) - mask = offs < vocab_size - vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf")) - - is_topk = vals >= topk_threshold - log_probs = vals - softmax_max - log_exp_sum - probs = tl.exp(log_probs) - - # Only sum probs that are strictly > threshold and in top-k - above_threshold = (log_probs > mid_lp) & is_topk - prob_sum_gt += tl.sum(tl.where(above_threshold, probs, 0.0)) - - # If sum of probs strictly above mid >= p, raise threshold - if prob_sum_gt >= p: - lo_lp = mid_lp - else: - hi_lp = mid_lp - - # Refine to exact threshold using combined approach (same as top-k). - # After binary search: prob_sum(> lo_lp) >= p, prob_sum(> hi_lp) < p. - # Count how many distinct log-probs are > lo_lp to determine which refinement. - count_gt_lo_lp = tl.zeros([1], dtype=tl.int32) - min_lp_above_lo = float("inf") - max_lp_at_or_below_hi = float("-inf") - for i in range(0, vocab_size, BLOCK_SIZE): - offs = i + tl.arange(0, BLOCK_SIZE) - mask = offs < vocab_size - vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf")) - - is_topk = vals >= topk_threshold - log_probs = vals - softmax_max - log_exp_sum - - above_lo = is_topk & (log_probs > lo_lp) - count_gt_lo_lp += tl.sum(above_lo.to(tl.int32)) - - lp_above_lo = tl.where(above_lo, log_probs, float("inf")) - min_lp_above_lo = tl.minimum(min_lp_above_lo, tl.min(lp_above_lo)) - - at_or_below_hi = is_topk & (log_probs <= hi_lp) - lp_at_or_below_hi = tl.where(at_or_below_hi, log_probs, float("-inf")) - max_lp_at_or_below_hi = tl.maximum( - max_lp_at_or_below_hi, tl.max(lp_at_or_below_hi) - ) - - # For top-p, use min if there are values > lo, otherwise use max. - # This handles edge cases where lo/hi converge to the same side. 
- if tl.sum(count_gt_lo_lp) > 0 and min_lp_above_lo < float("inf"): - topp_log_threshold = min_lp_above_lo - else: - topp_log_threshold = max_lp_at_or_below_hi - - # ========================================================================= - # Phase 4: Apply combined mask - # ========================================================================= - - for i in range(0, vocab_size, BLOCK_SIZE): - offs = i + tl.arange(0, BLOCK_SIZE) - mask = offs < vocab_size - vals = tl.load(row_ptr + offs, mask=mask, other=float("-inf")) - - # Apply top-k mask - keep = vals >= topk_threshold - - # Apply top-p mask - log_probs = vals - softmax_max - log_exp_sum - keep = keep & (log_probs >= topp_log_threshold) - - result = tl.where(keep, vals, mask_value) - tl.store(row_ptr + offs, result, mask=mask) - + # Third pass: Calculate exp logits and sum, gather top-k outliers + if num_outliers > k: + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + + probs_blk = tl.load(BUFFER_ROW + offs_n, + mask=mask_n_2, + other=-float('inf')) + + outlier_mask = (probs_blk > min_logit) & mask_n_2 + + # Duplicate logit handling + if num_keep < num_duplicate_logit: + duplicate_mask = tl.abs(probs_blk - duplicate_logit) < 1e-12 + duplicate_count = tl.cumsum(duplicate_mask) + num_kept + duplicate_keep_mask = (duplicate_count <= num_keep) & duplicate_mask + duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask + outlier_mask = outlier_mask & (~duplicate_remove_mask) + num_kept += tl.sum(duplicate_keep_mask) + + probs_blk = tl.where(outlier_mask, probs_blk, -float('inf')) + probs_blk = probs_blk - max_logit + probs_blk = tl.exp(probs_blk) + sum_exp_logits += tl.sum(probs_blk) + + # Fourth pass: Calculate BUFFER and get outliers + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + + probs_blk = tl.load(BUFFER_ROW + offs_n, + 
mask=mask_n_2, + other=-float('inf')) + + outlier_mask = (probs_blk > min_logit) & mask_n_2 + probs_blk = probs_blk - max_logit + probs_blk = tl.exp(probs_blk) + probs_blk = probs_blk / sum_exp_logits + tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n_2) + else: + # If top-k outlier gathering failed, retry gathering using top-k pivot + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + + probs_blk = tl.load(LOGITS_ROW + offs_n, + mask=mask_n, + other=-float('inf')) + + outlier_mask = (probs_blk > min_logit) & mask_n + + # Duplicate logit handling + duplicate_mask = tl.abs(probs_blk - duplicate_logit) < 1e-12 + duplicate_count = tl.cumsum(duplicate_mask) + num_kept + duplicate_keep_mask = (duplicate_count <= num_keep) & duplicate_mask + duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask + outlier_mask = outlier_mask & (~duplicate_remove_mask) + num_kept += tl.sum(duplicate_keep_mask) + + probs_blk = tl.where(outlier_mask, probs_blk, -float('inf')) + probs_blk = probs_blk - max_logit + probs_blk = tl.exp(probs_blk) + sum_exp_logits += tl.sum(probs_blk) + + cumulative_pos = tl.cast( + tl.cumsum(outlier_mask) - 1 + num_outliers_2, tl.int32) + num_outliers_2 += tl.sum(outlier_mask) + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(BUFFER_ROW + write_pos, probs_blk, mask=outlier_mask) + + search_range = tl.cast(num_outliers_2, tl.int32) + search_iters = tl.cast( + (num_outliers_2 + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32) + + # Fourth pass: Calculate BUFFER and get outliers + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + probs_blk = probs_blk / sum_exp_logits + tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n_2) + + + max_range = tl.exp(max_logit - max_logit) / sum_exp_logits + min_range = 
tl.exp(min_logit - max_logit) / sum_exp_logits + + p_pivot = 1.0 + num_iters = 0 + min_larger_prob = 1.0 + num_min_larger = tl.zeros((), dtype=tl.uint32) + p_pivots_sum = 0.0 + + # Fifth passes: Search for p_pivot + while p_pivot == 1.0: + p_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + p_pivots_sum_0 = 0.0 + min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + p_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + p_pivots_sum_1 = 0.0 + min_larger_1 = 1.0 + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + p_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range + p_pivots_sum_2 = 0.0 + min_larger_2 = 1.0 + num_min_larger_2 = tl.zeros((), dtype=tl.uint32) + + + # First pass: Calculate p_pivots_sum and min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + + p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) + masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) + + p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) + masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) + min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) + + p_pivots_sum_2 += tl.sum(probs_blk * (probs_blk > p_pivot_2)) + masked_larger_2 = tl.where(probs_blk > p_pivot_2, probs_blk, 1.0) + min_larger_2 = tl.minimum(min_larger_2, tl.min(masked_larger_2)) + + + # Second pass: Calculate num_min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-12) + num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-12) + num_min_larger_2 += 
tl.sum(tl.abs(probs_blk - min_larger_2) < 1e-12) + + # Check if any of the pivots satisfy termination condition + if p_pivots_sum_2 >= p: + if p_pivots_sum_2 - (min_larger_2 * num_min_larger_2) < p: + p_pivot = p_pivot_2 + min_larger_prob = min_larger_2 + num_min_larger = num_min_larger_2 + p_pivots_sum = p_pivots_sum_2 + if p_pivots_sum_1 >= p: + if p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: + p_pivot = p_pivot_1 + min_larger_prob = min_larger_1 + num_min_larger = num_min_larger_1 + p_pivots_sum = p_pivots_sum_1 + if p_pivots_sum_0 >= p: + if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: + p_pivot = p_pivot_0 + min_larger_prob = min_larger_0 + num_min_larger = num_min_larger_0 + p_pivots_sum = p_pivots_sum_0 + + # Update range + if p_pivots_sum_2 > p: + min_range = p_pivot_2 + elif p_pivots_sum_1 > p: + min_range = p_pivot_1 + elif p_pivots_sum_0 > p: + min_range = p_pivot_0 + + if p_pivots_sum_0 < p: + max_range = p_pivot_0 + elif p_pivots_sum_1 < p: + max_range = p_pivot_1 + elif p_pivots_sum_2 < p: + max_range = p_pivot_2 + + num_iters += 1 + if (max_range - min_range) < 1e-12 or num_iters >= 20: + p_pivot = (max_range + min_range) / 2.0 + + duplicate_logit = tl.log(min_larger_prob * sum_exp_logits) + max_logit + num_duplicate_logit = num_min_larger + num_keep = num_duplicate_logit - tl.cast((p_pivots_sum - p) / min_larger_prob, tl.uint32) + num_kept = tl.zeros((), dtype=tl.uint32) + + p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit + + # Sixth pass: Apply mask + + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) + keep_mask = (logits_blk > p_pivot) & mask_n + + # Duplicate logit handling + if num_keep < num_duplicate_logit: + duplicate_mask = (tl.abs(logits_blk - duplicate_logit) < 1e-12) & mask_n + duplicate_count = tl.cumsum(duplicate_mask) + num_kept + duplicate_keep_mask = 
(duplicate_count <= num_duplicate_logit) & duplicate_mask + duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask + num_kept += tl.sum(duplicate_keep_mask) + keep_mask = keep_mask & (~duplicate_remove_mask) + + logits_blk = tl.where(keep_mask, logits_blk, MASK_VALUE) + tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) def apply_top_k_top_p_triton( logits: torch.Tensor, @@ -327,9 +554,9 @@ def apply_top_k_top_p_triton( to the remaining k values (by probability). Args: - logits: [n, vocab_size] float32 tensor, modified in-place - k: [n] int32 tensor of top-k values per row, or None to disable top-k - p: [n] float32 tensor of top-p values per row (0 to 1), + logits: [batch_size, vocab_size] float32 tensor, modified in-place + k: [batch_size] int32 tensor of top-k values per row, or None to disable top-k + p: [batch_size] float32 tensor of top-p values per row (0 to 1), or None to disable top-p mask_value: Value for masked positions (default: -inf) @@ -340,42 +567,66 @@ def apply_top_k_top_p_triton( assert logits.dtype == torch.float32 assert logits.is_cuda - n, vocab_size = logits.shape + batch_size, vocab_size = logits.shape + num_sm = torch.cuda.get_device_properties(logits.device).multi_processor_count + NUM_PROGRAMS = min(num_sm, batch_size) topk_enabled = k is not None topp_enabled = p is not None - if n == 0 or not (topk_enabled or topp_enabled): + if batch_size == 0 or not (topk_enabled or topp_enabled): return logits if k is not None: - assert k.ndim == 1 and k.shape[0] == n and k.is_cuda + assert k.ndim == 1 and k.shape[0] == batch_size and k.is_cuda k_ptr = k.to(torch.int32) else: k_ptr = logits # Dummy pointer (won't be read) if p is not None: - assert p.ndim == 1 and p.shape[0] == n and p.is_cuda + assert p.ndim == 1 and p.shape[0] == batch_size and p.is_cuda p_ptr = p.to(torch.float32) else: p_ptr = logits # Dummy pointer (won't be read) - BLOCK_SIZE = 1024 - # K_ITERS must be large enough to distinguish adjacent logit values. 
- # With randn logits (range ~8), 20 iterations gives precision ~8/2^19 ≈ 1.5e-5 - K_ITERS = 18 - P_ITERS = 14 + num_sm = torch.cuda.get_device_properties(logits.device).multi_processor_count + NUM_PROGRAMS = min(num_sm, batch_size) + + buf_key = (logits.device, logits.dtype, NUM_PROGRAMS, vocab_size) + buffer = _TRITON_BUFFER_CACHE.get(buf_key) + if buffer is None or buffer.numel() < NUM_PROGRAMS * vocab_size: + buffer = torch.empty( + (NUM_PROGRAMS, vocab_size), device=logits.device, dtype=logits.dtype + ) + _TRITON_BUFFER_CACHE[buf_key] = buffer + + # Cache percentile table per device. + tbl_key = (logits.device, torch.float32) + tables = _TRITON_TABLE_CACHE.get(tbl_key) + if tables is None: + normal_cdf_to_sigma_table = torch.tensor( + _NORMAL_CDF_TO_SIGMA_TABLE, device=logits.device, dtype=torch.float32 + ) + percentile_to_std_table = torch.tensor( + _PERCENTILE_TO_STD_TABLE, device=logits.device, dtype=torch.float32 + ) + _TRITON_TABLE_CACHE[tbl_key] = (normal_cdf_to_sigma_table, percentile_to_std_table) + else: + normal_cdf_to_sigma_table, percentile_to_std_table = tables + - _topk_topp_kernel[(n,)]( + _topk_topp_kernel[(NUM_PROGRAMS,)]( logits, + buffer, + percentile_to_std_table, + normal_cdf_to_sigma_table, k_ptr, p_ptr, - logits_stride=logits.stride(0), - vocab_size=vocab_size, - mask_value=mask_value, - BLOCK_SIZE=BLOCK_SIZE, - K_ITERS=K_ITERS, - P_ITERS=P_ITERS, + BATCH_SIZE=batch_size, + MASK_VALUE=mask_value, + VOCAB_SIZE=vocab_size, + BLOCK_SIZE=8192, + BLOCK_SIZE_TRUNC=4096, TOPK_ENABLED=topk_enabled, TOPP_ENABLED=topp_enabled, ) From 9a5f30d7b8af59b723d7bc072c2dee4468c95c15 Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 1 Feb 2026 17:55:12 -0800 Subject: [PATCH 80/99] Copied topk + topp impl Signed-off-by: js_park --- tests/v1/sample/test_topk_topp_sampler.py | 28 +++++++++++------------ 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/v1/sample/test_topk_topp_sampler.py b/tests/v1/sample/test_topk_topp_sampler.py index 
6c19bc179327..5476cf1cf961 100644 --- a/tests/v1/sample/test_topk_topp_sampler.py +++ b/tests/v1/sample/test_topk_topp_sampler.py @@ -186,21 +186,21 @@ def _compare_results( f"(max diff {max_diff} values out of {max_kept})" ) - # @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024]) - # @pytest.mark.parametrize("vocab_size", [1024, 32000, 128256]) - # def test_topk_only(self, batch_size: int, vocab_size: int): - # """Test top-k only (p=None).""" - # logits = torch.randn( - # batch_size, vocab_size, generator=self.generator, dtype=torch.float32 - # ) - # k = torch.randint( - # 1, min(100, vocab_size), (batch_size,), generator=self.generator - # ) - # # Randomly disable top-k for some rows (~25%) - # disable_mask = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0 - # k.masked_fill_(disable_mask, vocab_size) + @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024]) + @pytest.mark.parametrize("vocab_size", [1024, 32000, 128256]) + def test_topk_only(self, batch_size: int, vocab_size: int): + """Test top-k only (p=None).""" + logits = torch.randn( + batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + ) + k = torch.randint( + 1, min(100, vocab_size), (batch_size,), generator=self.generator + ) + # Randomly disable top-k for some rows (~25%) + disable_mask = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0 + k.masked_fill_(disable_mask, vocab_size) - # self._compare_results(logits, k, p=None) + self._compare_results(logits, k, p=None) # @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024]) # @pytest.mark.parametrize("vocab_size", [1024, 32000, 128256]) From 65874cce90efcb9744d41bd1699f4a74c8cdcc91 Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 1 Feb 2026 18:52:58 -0800 Subject: [PATCH 81/99] Topp wrong Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_triton.py | 1128 +++++++++++++++--------- 1 file changed, 698 insertions(+), 430 deletions(-) diff --git 
a/vllm/v1/sample/ops/topk_topp_triton.py b/vllm/v1/sample/ops/topk_topp_triton.py index 5290169b4f12..8f60febfec17 100644 --- a/vllm/v1/sample/ops/topk_topp_triton.py +++ b/vllm/v1/sample/ops/topk_topp_triton.py @@ -3,16 +3,9 @@ """ Combined Top-K and Top-P Triton kernels. -These kernels apply top-k filtering first, then top-p on the remaining values. -This is more efficient than sorting the entire vocabulary. +Based on the paper "Qrita: High-performance Top-k and Top-p Algorithm for GPUs +using Pivot-based Truncation and Selection" By Park et al. -Algorithm: -1. Find k-th largest logit using binary search → top-k threshold -2. Mask logits below threshold, compute softmax (only k values contribute) -3. Find probability threshold for top-p using binary search -4. Apply final mask - -Complexity: O(vocab_size * (k_iters + p_iters)) where iters ≈ 16-20 """ import torch @@ -94,452 +87,727 @@ def _topk_topp_kernel( pid = tl.program_id(0) num_programs = tl.num_programs(0) for row_id in tl.range(pid, BATCH_SIZE, num_programs): - p = tl.load(P + row_id) - k = tl.load(K + row_id) - if p < 1.0 or k != VOCAB_SIZE: - LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE - BUFFER_ROW = BUFFER + pid * VOCAB_SIZE - - max_logit = -float("inf") - min_logit = float("inf") - - # Zeroth pass: Compute avg and std from a sample block - offs = tl.arange(0, BLOCK_SIZE) - mask_n = offs < VOCAB_SIZE - num_valid = tl.sum(mask_n) - logits_blk0 = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) - avg_logit = tl.sum(logits_blk0) / num_valid - sq_avg_logit = tl.sum(logits_blk0 * logits_blk0) / num_valid - std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - std_logit = tl.maximum(std_logit, 0.0) - - # Calculate outlier pivot t for Gaussian sigma-truncation - percentile = tl.cast(k / VOCAB_SIZE * 200, tl.uint32) - percentile = tl.minimum(percentile, 199) - sigma = tl.load(PERCENTILE_TO_STD_TABLE + percentile) - sigma = sigma + tl.abs(sigma) * -0.2 - outlier_pivot = avg_logit + std_logit * sigma - 
num_outliers = tl.zeros((), dtype=tl.uint32) - - # First pass: compute max and min logits and gather outliers - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=avg_logit) - - max_logit = tl.maximum(max_logit, tl.max(logits_blk)) - min_logit = tl.minimum(min_logit, tl.min(logits_blk)) - - outlier_mask = (logits_blk > outlier_pivot) & mask_n - cumulative_pos = tl.cast( - tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) - num_outliers += tl.sum(outlier_mask) - write_pos = tl.where(outlier_mask, cumulative_pos, -1) - tl.store(BUFFER_ROW + write_pos, logits_blk, mask=outlier_mask) - - # Second passes: Quaternary search for pivots (nlog_4(batch_size)) - num_iters = 0 - k_pivot = float("inf") - k_pivots_num = tl.zeros((), dtype=tl.uint32) - min_larger = float("inf") - num_min_larger = tl.zeros((), dtype=tl.uint32) - if num_outliers > k: - max_range = max_logit - min_range = outlier_pivot - search_range = tl.cast(num_outliers, tl.int32) - search_iters = tl.cast( - (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32 - ) - while k_pivot == float("inf"): - k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range - k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) - min_larger_0 = float("inf") - num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - - k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range - k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) - min_larger_1 = float("inf") - num_min_larger_1 = tl.zeros((), dtype=tl.uint32) - - k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range - k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - min_larger_2 = float("inf") - num_min_larger_2 = tl.zeros((), dtype=tl.uint32) - - # First pass: Calculate k_pivots_num and min_larger - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) - mask_n_2 = offs_n < search_range - logits_blk2 = 
tl.load( - BUFFER_ROW + offs_n, mask=mask_n_2, other=-float("inf") - ) - k_pivots_num_0 += tl.sum(logits_blk2 > k_pivot_0) - k_pivots_num_1 += tl.sum(logits_blk2 > k_pivot_1) - k_pivots_num_2 += tl.sum(logits_blk2 > k_pivot_2) + LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE + BUFFER_ROW = BUFFER + pid * VOCAB_SIZE + + final_pivot = -float("inf") + duplicate_logit = float("inf") + num_duplicate_logit = tl.zeros((), dtype=tl.uint32) + num_keep = tl.zeros((), dtype=tl.uint32) + num_kept = tl.zeros((), dtype=tl.uint32) + + max_logit = -float('inf') + min_logit = float('inf') + + if TOPK_ENABLED: + k = tl.load(K + row_id) + if k != VOCAB_SIZE: + + # Zeroth pass: Compute avg and std from a sample block + offs = tl.arange(0, BLOCK_SIZE) + mask_n = offs < VOCAB_SIZE + num_valid = tl.sum(mask_n) + logits_blk0 = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) + avg_logit = tl.sum(logits_blk0) / num_valid + sq_avg_logit = tl.sum(logits_blk0 * logits_blk0) / num_valid + std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) + std_logit = tl.maximum(std_logit, 0.0) + + # Calculate outlier pivot t for Gaussian sigma-truncation + percentile = tl.cast(k / VOCAB_SIZE * 200, tl.uint32) + percentile = tl.minimum(percentile, 199) + sigma = tl.load(PERCENTILE_TO_STD_TABLE + percentile) + sigma = sigma + tl.abs(sigma) * -0.25 + outlier_pivot = avg_logit + std_logit * sigma + num_outliers = tl.zeros((), dtype=tl.uint32) + + # First pass: compute max and min logits and gather outliers + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=avg_logit) - min_larger_0 = tl.minimum(min_larger_0, tl.min(logits_blk2)) - min_larger_1 = tl.minimum(min_larger_1, tl.min(logits_blk2)) - min_larger_2 = tl.minimum(min_larger_2, tl.min(logits_blk2)) + max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + min_logit = tl.minimum(min_logit, tl.min(logits_blk)) - # Second pass: 
Calculate num_min_larger - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) - mask_n_2 = offs_n < search_range - logits_blk2 = tl.load( - BUFFER_ROW + offs_n, mask=mask_n_2, other=-float("inf") + outlier_mask = (logits_blk > outlier_pivot) & mask_n + cumulative_pos = tl.cast( + tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + num_outliers += tl.sum(outlier_mask) + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(BUFFER_ROW + write_pos, logits_blk, mask=outlier_mask) + + # Second passes: Quaternary search for pivots (nlog_4(batch_size)) + num_iters = 0 + k_pivot = float("inf") + k_pivots_num = tl.zeros((), dtype=tl.uint32) + min_larger = float("inf") + num_min_larger = tl.zeros((), dtype=tl.uint32) + if num_outliers > k: + max_range = max_logit + min_range = outlier_pivot + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32 + ) + while k_pivot == float("inf"): + k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) + min_larger_0 = float("inf") + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) + min_larger_1 = float("inf") + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range + k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) + min_larger_2 = float("inf") + num_min_larger_2 = tl.zeros((), dtype=tl.uint32) + + # First pass: Calculate k_pivots_num and min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + logits_blk2 = tl.load( + BUFFER_ROW + offs_n, mask=mask_n_2, other=-float("inf") + ) + + k_pivots_num_0 += tl.sum(logits_blk2 > k_pivot_0) + k_pivots_num_1 += tl.sum(logits_blk2 > 
k_pivot_1) + k_pivots_num_2 += tl.sum(logits_blk2 > k_pivot_2) + + min_larger_0 = tl.minimum(min_larger_0, tl.min(logits_blk2)) + min_larger_1 = tl.minimum(min_larger_1, tl.min(logits_blk2)) + min_larger_2 = tl.minimum(min_larger_2, tl.min(logits_blk2)) + + # Second pass: Calculate num_min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + logits_blk2 = tl.load( + BUFFER_ROW + offs_n, mask=mask_n_2, other=-float("inf") + ) + + num_min_larger_0 += tl.sum(tl.abs(logits_blk2 - min_larger_0) < 1e-15) + num_min_larger_1 += tl.sum(tl.abs(logits_blk2 - min_larger_1) < 1e-15) + num_min_larger_2 += tl.sum(tl.abs(logits_blk2 - min_larger_2) < 1e-15) + + # Check if any of the pivots satisfy termination condition + if k_pivots_num_0 >= k: + if k_pivots_num_0 - (min_larger_0 * num_min_larger_0) < k: + k_pivot = k_pivot_0 + k_pivots_num = k_pivots_num_0 + min_larger = min_larger_0 + num_min_larger = num_min_larger_0 + if k_pivots_num_1 >= k: + if k_pivots_num_1 - (min_larger_1 * num_min_larger_1) < k: + k_pivot = k_pivot_1 + k_pivots_num = k_pivots_num_1 + min_larger = min_larger_1 + num_min_larger = num_min_larger_1 + if k_pivots_num_2 >= k: + if k_pivots_num_2 - (min_larger_2 * num_min_larger_2) < k: + k_pivot = k_pivot_2 + k_pivots_num = k_pivots_num_2 + min_larger = min_larger_2 + num_min_larger = num_min_larger_2 + + # Update range + if k_pivots_num_2 > k: + min_range = k_pivot_2 + elif k_pivots_num_1 > k: + min_range = k_pivot_1 + elif k_pivots_num_0 > k: + min_range = k_pivot_0 + + if k_pivots_num_0 < k: + max_range = k_pivot_0 + elif k_pivots_num_1 < k: + max_range = k_pivot_1 + elif k_pivots_num_2 < k: + max_range = k_pivot_2 + + num_iters += 1 + if num_iters >= 24 or tl.abs(min_range - max_range) < 1e-15: + k_pivot = (max_range + min_range) / 2.0 + else: + # If top-k outlier gathering failed, search whole logit space + max_range = max_logit + min_range = min_logit + while 
k_pivot == float("inf"): + k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) + min_larger_0 = float("inf") + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) + min_larger_1 = float("inf") + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range + k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) + min_larger_2 = float("inf") + num_min_larger_2 = tl.zeros((), dtype=tl.uint32) + + # First pass: Calculate k_pivots_num and min_larger + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk2 = tl.load( + LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") + ) + + k_pivots_num_0 += tl.sum(logits_blk2 > k_pivot_0) + k_pivots_num_1 += tl.sum(logits_blk2 > k_pivot_1) + k_pivots_num_2 += tl.sum(logits_blk2 > k_pivot_2) + + min_larger_0 = tl.minimum(min_larger_0, tl.min(logits_blk2)) + min_larger_1 = tl.minimum(min_larger_1, tl.min(logits_blk2)) + min_larger_2 = tl.minimum(min_larger_2, tl.min(logits_blk2)) + + # Second pass: Calculate num_min_larger + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk2 = tl.load( + LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") + ) + + num_min_larger_0 += tl.sum(tl.abs(logits_blk2 - min_larger_0) < 1e-15) + num_min_larger_1 += tl.sum(tl.abs(logits_blk2 - min_larger_1) < 1e-15) + num_min_larger_2 += tl.sum(tl.abs(logits_blk2 - min_larger_2) < 1e-15) + + # Check if any of the pivots satisfy termination condition + if k_pivots_num_0 >= k: + if k_pivots_num_0 - (min_larger_0 * num_min_larger_0) < k: + k_pivot = k_pivot_0 + k_pivots_num = k_pivots_num_0 + min_larger = min_larger_0 + num_min_larger = num_min_larger_0 + if k_pivots_num_1 >= k: + if k_pivots_num_1 - 
(min_larger_1 * num_min_larger_1) < k: + k_pivot = k_pivot_1 + k_pivots_num = k_pivots_num_1 + min_larger = min_larger_1 + num_min_larger = num_min_larger_1 + if k_pivots_num_2 >= k: + if k_pivots_num_2 - (min_larger_2 * num_min_larger_2) < k: + k_pivot = k_pivot_2 + k_pivots_num = k_pivots_num_2 + min_larger = min_larger_2 + num_min_larger = num_min_larger_2 + + # Update range + if k_pivots_num_2 > k: + min_range = k_pivot_2 + elif k_pivots_num_1 > k: + min_range = k_pivot_1 + elif k_pivots_num_0 > k: + min_range = k_pivot_0 + + if k_pivots_num_0 < k: + max_range = k_pivot_0 + elif k_pivots_num_1 < k: + max_range = k_pivot_1 + elif k_pivots_num_2 < k: + max_range = k_pivot_2 + + num_iters += 1 + if num_iters >= 24 or tl.abs(min_range - max_range) < 1e-15: + k_pivot = (max_range + min_range) / 2.0 + + duplicate_logit = min_larger + num_duplicate_logit = num_min_larger + num_keep = num_duplicate_logit - (k_pivots_num - k) + num_kept = tl.zeros((), dtype=tl.uint32) + + if TOPP_ENABLED: + #### TOP-P SAMPLING AFTER TOP-K #### + p = tl.load(P + row_id) + if p < 1.0: + min_logit = k_pivot + sum_exp_logits = 0.0 + num_outliers_2 = tl.zeros((), dtype=tl.uint32) + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32 ) - num_min_larger_0 += tl.sum(tl.abs(logits_blk2 - min_larger_0) < 1e-12) - num_min_larger_1 += tl.sum(tl.abs(logits_blk2 - min_larger_1) < 1e-12) - num_min_larger_2 += tl.sum(tl.abs(logits_blk2 - min_larger_2) < 1e-12) - - # Check if any of the pivots satisfy termination condition - if k_pivots_num_0 >= k: - if k_pivots_num_0 - (min_larger_0 * num_min_larger_0) < k: - k_pivot = k_pivot_0 - k_pivots_num = k_pivots_num_0 - min_larger = min_larger_0 - num_min_larger = num_min_larger_0 - if k_pivots_num_1 >= k: - if k_pivots_num_1 - (min_larger_1 * num_min_larger_1) < k: - k_pivot = k_pivot_1 - k_pivots_num = k_pivots_num_1 - min_larger = min_larger_1 - num_min_larger = 
num_min_larger_1 - if k_pivots_num_2 >= k: - if k_pivots_num_2 - (min_larger_2 * num_min_larger_2) < k: - k_pivot = k_pivot_2 - k_pivots_num = k_pivots_num_2 - min_larger = min_larger_2 - num_min_larger = num_min_larger_2 - - # Update range - if k_pivots_num_2 > k: - min_range = k_pivot_2 - elif k_pivots_num_1 > k: - min_range = k_pivot_1 - elif k_pivots_num_0 > k: - min_range = k_pivot_0 - - if k_pivots_num_0 < k: - max_range = k_pivot_0 - elif k_pivots_num_1 < k: - max_range = k_pivot_1 - elif k_pivots_num_2 < k: - max_range = k_pivot_2 - - num_iters += 1 - if num_iters >= 20 or tl.abs(min_range - max_range) < 1e-12: - k_pivot = (max_range + min_range) / 2.0 - else: - # If top-k outlier gathering failed, search whole logit space - max_range = max_logit - min_range = min_logit - while k_pivot == float("inf"): - k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range - k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) - min_larger_0 = float("inf") - num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - - k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range - k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) - min_larger_1 = float("inf") - num_min_larger_1 = tl.zeros((), dtype=tl.uint32) - - k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range - k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - min_larger_2 = float("inf") - num_min_larger_2 = tl.zeros((), dtype=tl.uint32) - - # First pass: Calculate k_pivots_num and min_larger + # Third pass: Calculate exp logits and sum, gather top-k outliers + if num_outliers > k: + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + + probs_blk = tl.load(BUFFER_ROW + offs_n, + mask=mask_n_2, + other=-float('inf')) + + outlier_mask = (probs_blk > min_logit) & mask_n_2 + + # Duplicate logit handling + if num_keep < num_duplicate_logit: + duplicate_mask = tl.abs(probs_blk - duplicate_logit) < 1e-15 + duplicate_count = tl.cumsum(duplicate_mask) 
+ num_kept + duplicate_keep_mask = (duplicate_count <= num_keep) & duplicate_mask + duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask + outlier_mask = outlier_mask & (~duplicate_remove_mask) + num_kept += tl.sum(duplicate_keep_mask) + + probs_blk = tl.where(outlier_mask, probs_blk, -float('inf')) + probs_blk = probs_blk - max_logit + probs_blk = tl.exp(probs_blk) + sum_exp_logits += tl.sum(probs_blk) + + # Fourth pass: Calculate BUFFER and get outliers + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + + probs_blk = tl.load(BUFFER_ROW + offs_n, + mask=mask_n_2, + other=-float('inf')) + + outlier_mask = (probs_blk > min_logit) & mask_n_2 + probs_blk = probs_blk - max_logit + probs_blk = tl.exp(probs_blk) + probs_blk = probs_blk / sum_exp_logits + tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n_2) + else: + # If top-k outlier gathering failed, retry gathering using top-k pivot + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + + probs_blk = tl.load(LOGITS_ROW + offs_n, + mask=mask_n, + other=-float('inf')) + + outlier_mask = (probs_blk > min_logit) & mask_n + + # Duplicate logit handling + duplicate_mask = tl.abs(probs_blk - duplicate_logit) < 1e-15 + duplicate_count = tl.cumsum(duplicate_mask) + num_kept + duplicate_keep_mask = (duplicate_count <= num_keep) & duplicate_mask + duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask + outlier_mask = outlier_mask & (~duplicate_remove_mask) + num_kept += tl.sum(duplicate_keep_mask) + + probs_blk = tl.where(outlier_mask, probs_blk, -float('inf')) + probs_blk = probs_blk - max_logit + probs_blk = tl.exp(probs_blk) + sum_exp_logits += tl.sum(probs_blk) + + cumulative_pos = tl.cast( + tl.cumsum(outlier_mask) - 1 + num_outliers_2, tl.int32) + num_outliers_2 += tl.sum(outlier_mask) + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + 
tl.store(BUFFER_ROW + write_pos, probs_blk, mask=outlier_mask) + + search_range = tl.cast(num_outliers_2, tl.int32) + search_iters = tl.cast( + (num_outliers_2 + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32) + + # Fourth pass: Calculate BUFFER and get outliers + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + probs_blk = probs_blk / sum_exp_logits + tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n_2) + + + max_range = tl.exp(max_logit - max_logit) / sum_exp_logits + min_range = tl.exp(min_logit - max_logit) / sum_exp_logits + + p_pivot = 1.0 + num_iters = 0 + min_larger_prob = 1.0 + num_min_larger = tl.zeros((), dtype=tl.uint32) + p_pivots_sum = 0.0 + + # Fifth passes: Search for p_pivot + while p_pivot == 1.0: + p_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + p_pivots_sum_0 = 0.0 + min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + p_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + p_pivots_sum_1 = 0.0 + min_larger_1 = 1.0 + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + p_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range + p_pivots_sum_2 = 0.0 + min_larger_2 = 1.0 + num_min_larger_2 = tl.zeros((), dtype=tl.uint32) + + + # First pass: Calculate p_pivots_sum and min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + + p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) + masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) + + p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) + masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) + min_larger_1 = 
tl.minimum(min_larger_1, tl.min(masked_larger_1)) + + p_pivots_sum_2 += tl.sum(probs_blk * (probs_blk > p_pivot_2)) + masked_larger_2 = tl.where(probs_blk > p_pivot_2, probs_blk, 1.0) + min_larger_2 = tl.minimum(min_larger_2, tl.min(masked_larger_2)) + + + # Second pass: Calculate num_min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-15) + num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-15) + num_min_larger_2 += tl.sum(tl.abs(probs_blk - min_larger_2) < 1e-15) + + # Check if any of the pivots satisfy termination condition + if p_pivots_sum_2 >= p: + if p_pivots_sum_2 - (min_larger_2 * num_min_larger_2) < p: + p_pivot = p_pivot_2 + min_larger_prob = min_larger_2 + num_min_larger = num_min_larger_2 + p_pivots_sum = p_pivots_sum_2 + if p_pivots_sum_1 >= p: + if p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: + p_pivot = p_pivot_1 + min_larger_prob = min_larger_1 + num_min_larger = num_min_larger_1 + p_pivots_sum = p_pivots_sum_1 + if p_pivots_sum_0 >= p: + if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: + p_pivot = p_pivot_0 + min_larger_prob = min_larger_0 + num_min_larger = num_min_larger_0 + p_pivots_sum = p_pivots_sum_0 + + # Update range + if p_pivots_sum_2 > p: + min_range = p_pivot_2 + elif p_pivots_sum_1 > p: + min_range = p_pivot_1 + elif p_pivots_sum_0 > p: + min_range = p_pivot_0 + + if p_pivots_sum_0 < p: + max_range = p_pivot_0 + elif p_pivots_sum_1 < p: + max_range = p_pivot_1 + elif p_pivots_sum_2 < p: + max_range = p_pivot_2 + + num_iters += 1 + if (max_range - min_range) < 1e-15 or num_iters >= 24: + p_pivot = (max_range + min_range) / 2.0 + + duplicate_logit = tl.log(min_larger_prob * sum_exp_logits) + max_logit + num_duplicate_logit = num_min_larger + num_keep = num_duplicate_logit - 
tl.cast((p_pivots_sum - p) / min_larger_prob, tl.uint32) + num_kept = tl.zeros((), dtype=tl.uint32) + + # Top-k + Top-p path + final_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit + else: + # Top-k only path + final_pivot = k_pivot + elif TOPP_ENABLED: + #### STANDALONE TOP-P SAMPLING #### + p = tl.load(P + row_id) + if p < 1.0: + # Zeroth pass: Compute avg and std from a sample block + offs = tl.arange(0, BLOCK_SIZE) + mask_n = offs < VOCAB_SIZE + num_valid = tl.sum(mask_n) + logits_blk0 = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) + avg_logit = tl.sum(logits_blk0) / num_valid + sq_avg_logit = tl.sum(logits_blk0 * logits_blk0) / num_valid + std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) + std_logit = tl.maximum(std_logit, 0.0) + max_sample = avg_logit + std_logit * 10.0 + sum_exp_logits = 0.0 + + # First pass: compute max and min logits and sum_exp_logits for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE - logits_blk2 = tl.load( - LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") - ) + logits_blk = tl.load(LOGITS_ROW + offs_n, + mask=mask_n, + other=avg_logit) + max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + + probs_blk = tl.exp(logits_blk - max_sample) + probs_blk = tl.where(mask_n, probs_blk, 0.0) + sum_exp_logits += tl.sum(probs_blk) - k_pivots_num_0 += tl.sum(logits_blk2 > k_pivot_0) - k_pivots_num_1 += tl.sum(logits_blk2 > k_pivot_1) - k_pivots_num_2 += tl.sum(logits_blk2 > k_pivot_2) + idx = tl.cast(p * 200, tl.int32) + idx = tl.maximum(0, tl.minimum(idx, 199)) + sigma = tl.load(NORMAL_CDF_TO_SIGMA_TABLE + idx) + sigma = sigma + tl.abs(sigma) * -0.25 + outlier_pivot = avg_logit + std_logit * sigma - min_larger_0 = tl.minimum(min_larger_0, tl.min(logits_blk2)) - min_larger_1 = tl.minimum(min_larger_1, tl.min(logits_blk2)) - min_larger_2 = tl.minimum(min_larger_2, tl.min(logits_blk2)) + outlier_prob = 
tl.exp(outlier_pivot - max_sample) / sum_exp_logits + sum_outlier_probs = 0.0 + num_outliers = tl.zeros((), dtype=tl.uint32) - # Second pass: Calculate num_min_larger + # Second pass: Calculate softmax and gather outliers for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE - logits_blk2 = tl.load( - LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") - ) - num_min_larger_0 += tl.sum(tl.abs(logits_blk2 - min_larger_0) < 1e-12) - num_min_larger_1 += tl.sum(tl.abs(logits_blk2 - min_larger_1) < 1e-12) - num_min_larger_2 += tl.sum(tl.abs(logits_blk2 - min_larger_2) < 1e-12) - - # Check if any of the pivots satisfy termination condition - if k_pivots_num_0 >= k: - if k_pivots_num_0 - (min_larger_0 * num_min_larger_0) < k: - k_pivot = k_pivot_0 - k_pivots_num = k_pivots_num_0 - min_larger = min_larger_0 - num_min_larger = num_min_larger_0 - if k_pivots_num_1 >= k: - if k_pivots_num_1 - (min_larger_1 * num_min_larger_1) < k: - k_pivot = k_pivot_1 - k_pivots_num = k_pivots_num_1 - min_larger = min_larger_1 - num_min_larger = num_min_larger_1 - if k_pivots_num_2 >= k: - if k_pivots_num_2 - (min_larger_2 * num_min_larger_2) < k: - k_pivot = k_pivot_2 - k_pivots_num = k_pivots_num_2 - min_larger = min_larger_2 - num_min_larger = num_min_larger_2 - - # Update range - if k_pivots_num_2 > k: - min_range = k_pivot_2 - elif k_pivots_num_1 > k: - min_range = k_pivot_1 - elif k_pivots_num_0 > k: - min_range = k_pivot_0 + probs_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) + probs_blk = tl.exp(probs_blk - max_sample) + probs_blk = probs_blk / sum_exp_logits + + outlier_mask = (probs_blk > outlier_prob) & mask_n + sum_outlier_probs += tl.sum(outlier_mask * probs_blk) + cumulative_pos = tl.cast( + tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + num_outliers += tl.sum(outlier_mask) + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(BUFFER_ROW + write_pos, probs_blk, 
mask=outlier_mask) - if k_pivots_num_0 < k: - max_range = k_pivot_0 - elif k_pivots_num_1 < k: - max_range = k_pivot_1 - elif k_pivots_num_2 < k: - max_range = k_pivot_2 - - num_iters += 1 - if num_iters >= 20 or tl.abs(min_range - max_range) < 1e-12: - k_pivot = (max_range + min_range) / 2.0 - - duplicate_logit = min_larger - num_duplicate_logit = num_min_larger - num_keep = num_duplicate_logit - (k_pivots_num - k) - num_kept = tl.zeros((), dtype=tl.uint32) - - #### TOP-P SAMPLING #### - - min_logit = k_pivot - sum_exp_logits = 0.0 - num_outliers_2 = tl.zeros((), dtype=tl.uint32) - search_range = tl.cast(num_outliers, tl.int32) - search_iters = tl.cast( - (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32 - ) - - # Third pass: Calculate exp logits and sum, gather top-k outliers - if num_outliers > k: - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) - mask_n_2 = offs_n < search_range - - probs_blk = tl.load(BUFFER_ROW + offs_n, - mask=mask_n_2, - other=-float('inf')) - - outlier_mask = (probs_blk > min_logit) & mask_n_2 - # Duplicate logit handling - if num_keep < num_duplicate_logit: - duplicate_mask = tl.abs(probs_blk - duplicate_logit) < 1e-12 - duplicate_count = tl.cumsum(duplicate_mask) + num_kept - duplicate_keep_mask = (duplicate_count <= num_keep) & duplicate_mask - duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask - outlier_mask = outlier_mask & (~duplicate_remove_mask) - num_kept += tl.sum(duplicate_keep_mask) - - probs_blk = tl.where(outlier_mask, probs_blk, -float('inf')) - probs_blk = probs_blk - max_logit - probs_blk = tl.exp(probs_blk) - sum_exp_logits += tl.sum(probs_blk) + max_range = tl.exp(max_logit - max_sample) / sum_exp_logits + min_range = tl.exp(min_logit - max_sample) / sum_exp_logits - # Fourth pass: Calculate BUFFER and get outliers - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) - mask_n_2 = offs_n < 
search_range - - probs_blk = tl.load(BUFFER_ROW + offs_n, - mask=mask_n_2, - other=-float('inf')) + p_pivot = 1.0 + num_iters = 0 + min_larger_prob = 1.0 + num_min_larger = tl.zeros((), dtype=tl.uint32) + p_pivots_sum = 0.0 - outlier_mask = (probs_blk > min_logit) & mask_n_2 - probs_blk = probs_blk - max_logit - probs_blk = tl.exp(probs_blk) - probs_blk = probs_blk / sum_exp_logits - tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n_2) - else: - # If top-k outlier gathering failed, retry gathering using top-k pivot + # Third pass: Search for p_pivot + if sum_outlier_probs > p: + min_range = outlier_prob + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32) + + while p_pivot == 1.0: + p_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + p_pivots_sum_0 = 0.0 + min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + p_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + p_pivots_sum_1 = 0.0 + min_larger_1 = 1.0 + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + p_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range + p_pivots_sum_2 = 0.0 + min_larger_2 = 1.0 + num_min_larger_2 = tl.zeros((), dtype=tl.uint32) + + + # First pass: Calculate p_pivots_sum and min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + + p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) + masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) + + p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) + masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) + min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) + + p_pivots_sum_2 += tl.sum(probs_blk * (probs_blk > p_pivot_2)) + 
masked_larger_2 = tl.where(probs_blk > p_pivot_2, probs_blk, 1.0) + min_larger_2 = tl.minimum(min_larger_2, tl.min(masked_larger_2)) + + + # Second pass: Calculate num_min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-15) + num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-15) + num_min_larger_2 += tl.sum(tl.abs(probs_blk - min_larger_2) < 1e-15) + + # Check if any of the pivots satisfy termination condition + if p_pivots_sum_2 >= p: + if p_pivots_sum_2 - (min_larger_2 * num_min_larger_2) < p: + p_pivot = p_pivot_2 + min_larger_prob = min_larger_2 + num_min_larger = num_min_larger_2 + p_pivots_sum = p_pivots_sum_2 + if p_pivots_sum_1 >= p: + if p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: + p_pivot = p_pivot_1 + min_larger_prob = min_larger_1 + num_min_larger = num_min_larger_1 + p_pivots_sum = p_pivots_sum_1 + if p_pivots_sum_0 >= p: + if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: + p_pivot = p_pivot_0 + min_larger_prob = min_larger_0 + num_min_larger = num_min_larger_0 + p_pivots_sum = p_pivots_sum_0 + + # Update range + if p_pivots_sum_2 > p: + min_range = p_pivot_2 + elif p_pivots_sum_1 > p: + min_range = p_pivot_1 + elif p_pivots_sum_0 > p: + min_range = p_pivot_0 + + if p_pivots_sum_0 < p: + max_range = p_pivot_0 + elif p_pivots_sum_1 < p: + max_range = p_pivot_1 + elif p_pivots_sum_2 < p: + max_range = p_pivot_2 + + num_iters += 1 + if (max_range - min_range) < 1e-15 or num_iters >= 24: + p_pivot = (max_range + min_range) / 2.0 + else: + # Re-populate the buffer with full softmax probabilities + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + + probs_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) + 
probs_blk = tl.exp(probs_blk - max_sample) + probs_blk = probs_blk / sum_exp_logits + tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) + + while p_pivot == 1.0: + p_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + p_pivots_sum_0 = 0.0 + min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + p_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + p_pivots_sum_1 = 0.0 + min_larger_1 = 1.0 + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + p_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range + p_pivots_sum_2 = 0.0 + min_larger_2 = 1.0 + num_min_larger_2 = tl.zeros((), dtype=tl.uint32) + + + # First pass: Calculate p_pivots_sum and min_larger + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) + + p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) + masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) + + p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) + masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) + min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) + + p_pivots_sum_2 += tl.sum(probs_blk * (probs_blk > p_pivot_2)) + masked_larger_2 = tl.where(probs_blk > p_pivot_2, probs_blk, 1.0) + min_larger_2 = tl.minimum(min_larger_2, tl.min(masked_larger_2)) + + + # Second pass: Calculate num_min_larger + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) + + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-15) + num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-15) + num_min_larger_2 += tl.sum(tl.abs(probs_blk - min_larger_2) < 1e-15) + + # Check if any of the pivots satisfy termination condition + if 
p_pivots_sum_2 >= p: + if p_pivots_sum_2 - (min_larger_2 * num_min_larger_2) < p: + p_pivot = p_pivot_2 + min_larger_prob = min_larger_2 + num_min_larger = num_min_larger_2 + p_pivots_sum = p_pivots_sum_2 + if p_pivots_sum_1 >= p: + if p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: + p_pivot = p_pivot_1 + min_larger_prob = min_larger_1 + num_min_larger = num_min_larger_1 + p_pivots_sum = p_pivots_sum_1 + if p_pivots_sum_0 >= p: + if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: + p_pivot = p_pivot_0 + min_larger_prob = min_larger_0 + num_min_larger = num_min_larger_0 + p_pivots_sum = p_pivots_sum_0 + + # Update range + if p_pivots_sum_2 > p: + min_range = p_pivot_2 + elif p_pivots_sum_1 > p: + min_range = p_pivot_1 + elif p_pivots_sum_0 > p: + min_range = p_pivot_0 + + if p_pivots_sum_0 < p: + max_range = p_pivot_0 + elif p_pivots_sum_1 < p: + max_range = p_pivot_1 + elif p_pivots_sum_2 < p: + max_range = p_pivot_2 + + num_iters += 1 + if (max_range - min_range) < 1e-15 or num_iters >= 24: + p_pivot = (max_range + min_range) / 2.0 + + duplicate_logit = tl.log(min_larger_prob * sum_exp_logits) + max_logit + num_duplicate_logit = num_min_larger + num_keep = num_duplicate_logit - tl.cast((p_pivots_sum - p) / min_larger_prob, tl.uint32) + num_kept = tl.zeros((), dtype=tl.uint32) + + # Top-p only path + final_pivot = tl.log(p_pivot * sum_exp_logits) + max_sample + + # Sixth pass: Apply mask + if final_pivot != -float("inf"): for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE - - probs_blk = tl.load(LOGITS_ROW + offs_n, - mask=mask_n, - other=-float('inf')) - - outlier_mask = (probs_blk > min_logit) & mask_n + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) + keep_mask = (logits_blk > final_pivot) & mask_n # Duplicate logit handling - duplicate_mask = tl.abs(probs_blk - duplicate_logit) < 1e-12 - duplicate_count = tl.cumsum(duplicate_mask) + num_kept - 
duplicate_keep_mask = (duplicate_count <= num_keep) & duplicate_mask - duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask - outlier_mask = outlier_mask & (~duplicate_remove_mask) - num_kept += tl.sum(duplicate_keep_mask) - - probs_blk = tl.where(outlier_mask, probs_blk, -float('inf')) - probs_blk = probs_blk - max_logit - probs_blk = tl.exp(probs_blk) - sum_exp_logits += tl.sum(probs_blk) - - cumulative_pos = tl.cast( - tl.cumsum(outlier_mask) - 1 + num_outliers_2, tl.int32) - num_outliers_2 += tl.sum(outlier_mask) - write_pos = tl.where(outlier_mask, cumulative_pos, -1) - tl.store(BUFFER_ROW + write_pos, probs_blk, mask=outlier_mask) + if num_keep < num_duplicate_logit: + duplicate_mask = (tl.abs(logits_blk - duplicate_logit) < 1e-15) & mask_n + duplicate_count = tl.cumsum(duplicate_mask) + num_kept + duplicate_keep_mask = (duplicate_count <= num_duplicate_logit) & duplicate_mask + duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask + num_kept += tl.sum(duplicate_keep_mask) + keep_mask = keep_mask & (~duplicate_remove_mask) - search_range = tl.cast(num_outliers_2, tl.int32) - search_iters = tl.cast( - (num_outliers_2 + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32) - - # Fourth pass: Calculate BUFFER and get outliers - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) - mask_n_2 = offs_n < search_range - - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) - probs_blk = probs_blk / sum_exp_logits - tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n_2) - - - max_range = tl.exp(max_logit - max_logit) / sum_exp_logits - min_range = tl.exp(min_logit - max_logit) / sum_exp_logits - - p_pivot = 1.0 - num_iters = 0 - min_larger_prob = 1.0 - num_min_larger = tl.zeros((), dtype=tl.uint32) - p_pivots_sum = 0.0 - - # Fifth passes: Search for p_pivot - while p_pivot == 1.0: - p_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range - p_pivots_sum_0 = 0.0 - min_larger_0 = 1.0 - 
num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - - p_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range - p_pivots_sum_1 = 0.0 - min_larger_1 = 1.0 - num_min_larger_1 = tl.zeros((), dtype=tl.uint32) - - p_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range - p_pivots_sum_2 = 0.0 - min_larger_2 = 1.0 - num_min_larger_2 = tl.zeros((), dtype=tl.uint32) - - - # First pass: Calculate p_pivots_sum and min_larger - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) - mask_n_2 = offs_n < search_range - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) - - p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) - masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) - min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) - - p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) - masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) - min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) - - p_pivots_sum_2 += tl.sum(probs_blk * (probs_blk > p_pivot_2)) - masked_larger_2 = tl.where(probs_blk > p_pivot_2, probs_blk, 1.0) - min_larger_2 = tl.minimum(min_larger_2, tl.min(masked_larger_2)) - - - # Second pass: Calculate num_min_larger - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) - mask_n_2 = offs_n < search_range - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) - - num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-12) - num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-12) - num_min_larger_2 += tl.sum(tl.abs(probs_blk - min_larger_2) < 1e-12) - - # Check if any of the pivots satisfy termination condition - if p_pivots_sum_2 >= p: - if p_pivots_sum_2 - (min_larger_2 * num_min_larger_2) < p: - p_pivot = p_pivot_2 - min_larger_prob = min_larger_2 - num_min_larger = num_min_larger_2 - p_pivots_sum = p_pivots_sum_2 - if p_pivots_sum_1 >= p: - 
if p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: - p_pivot = p_pivot_1 - min_larger_prob = min_larger_1 - num_min_larger = num_min_larger_1 - p_pivots_sum = p_pivots_sum_1 - if p_pivots_sum_0 >= p: - if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: - p_pivot = p_pivot_0 - min_larger_prob = min_larger_0 - num_min_larger = num_min_larger_0 - p_pivots_sum = p_pivots_sum_0 - - # Update range - if p_pivots_sum_2 > p: - min_range = p_pivot_2 - elif p_pivots_sum_1 > p: - min_range = p_pivot_1 - elif p_pivots_sum_0 > p: - min_range = p_pivot_0 - - if p_pivots_sum_0 < p: - max_range = p_pivot_0 - elif p_pivots_sum_1 < p: - max_range = p_pivot_1 - elif p_pivots_sum_2 < p: - max_range = p_pivot_2 - - num_iters += 1 - if (max_range - min_range) < 1e-12 or num_iters >= 20: - p_pivot = (max_range + min_range) / 2.0 - - duplicate_logit = tl.log(min_larger_prob * sum_exp_logits) + max_logit - num_duplicate_logit = num_min_larger - num_keep = num_duplicate_logit - tl.cast((p_pivots_sum - p) / min_larger_prob, tl.uint32) - num_kept = tl.zeros((), dtype=tl.uint32) - - p_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit - - # Sixth pass: Apply mask - - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) - keep_mask = (logits_blk > p_pivot) & mask_n - - # Duplicate logit handling - if num_keep < num_duplicate_logit: - duplicate_mask = (tl.abs(logits_blk - duplicate_logit) < 1e-12) & mask_n - duplicate_count = tl.cumsum(duplicate_mask) + num_kept - duplicate_keep_mask = (duplicate_count <= num_duplicate_logit) & duplicate_mask - duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask - num_kept += tl.sum(duplicate_keep_mask) - keep_mask = keep_mask & (~duplicate_remove_mask) - - logits_blk = tl.where(keep_mask, logits_blk, MASK_VALUE) - tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) + logits_blk = 
tl.where(keep_mask, logits_blk, MASK_VALUE) + tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) def apply_top_k_top_p_triton( logits: torch.Tensor, From a671a0988f7469dd35a83c30ffed5b8be0481ca2 Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 1 Feb 2026 18:58:16 -0800 Subject: [PATCH 82/99] Topp working, topp only Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_triton.py | 958 +++++++------------------ 1 file changed, 265 insertions(+), 693 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_triton.py b/vllm/v1/sample/ops/topk_topp_triton.py index 8f60febfec17..b22a53a296a2 100644 --- a/vllm/v1/sample/ops/topk_topp_triton.py +++ b/vllm/v1/sample/ops/topk_topp_triton.py @@ -100,714 +100,286 @@ def _topk_topp_kernel( max_logit = -float('inf') min_logit = float('inf') - if TOPK_ENABLED: - k = tl.load(K + row_id) - if k != VOCAB_SIZE: - - # Zeroth pass: Compute avg and std from a sample block - offs = tl.arange(0, BLOCK_SIZE) - mask_n = offs < VOCAB_SIZE - num_valid = tl.sum(mask_n) - logits_blk0 = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) - avg_logit = tl.sum(logits_blk0) / num_valid - sq_avg_logit = tl.sum(logits_blk0 * logits_blk0) / num_valid - std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - std_logit = tl.maximum(std_logit, 0.0) - - # Calculate outlier pivot t for Gaussian sigma-truncation - percentile = tl.cast(k / VOCAB_SIZE * 200, tl.uint32) - percentile = tl.minimum(percentile, 199) - sigma = tl.load(PERCENTILE_TO_STD_TABLE + percentile) - sigma = sigma + tl.abs(sigma) * -0.25 - outlier_pivot = avg_logit + std_logit * sigma - num_outliers = tl.zeros((), dtype=tl.uint32) - - # First pass: compute max and min logits and gather outliers + #### STANDALONE TOP-P SAMPLING #### + p = tl.load(P + row_id) + if p < 1.0: + # Zeroth pass: Compute avg and std from a sample block + offs = tl.arange(0, BLOCK_SIZE) + mask_n = offs < VOCAB_SIZE + num_valid = tl.sum(mask_n) + logits_blk0 = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) 
+ avg_logit = tl.sum(logits_blk0) / num_valid + sq_avg_logit = tl.sum(logits_blk0 * logits_blk0) / num_valid + std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) + std_logit = tl.maximum(std_logit, 0.0) + max_sample = avg_logit + std_logit * 10.0 + sum_exp_logits = 0.0 + + # First pass: compute max and min logits and sum_exp_logits + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk = tl.load(LOGITS_ROW + offs_n, + mask=mask_n, + other=avg_logit) + max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + + probs_blk = tl.exp(logits_blk - max_sample) + probs_blk = tl.where(mask_n, probs_blk, 0.0) + sum_exp_logits += tl.sum(probs_blk) + + idx = tl.cast(p * 200, tl.int32) + idx = tl.maximum(0, tl.minimum(idx, 199)) + sigma = tl.load(NORMAL_CDF_TO_SIGMA_TABLE + idx) + sigma = sigma + tl.abs(sigma) * -0.25 + outlier_pivot = avg_logit + std_logit * sigma + + outlier_prob = tl.exp(outlier_pivot - max_sample) / sum_exp_logits + sum_outlier_probs = 0.0 + num_outliers = tl.zeros((), dtype=tl.uint32) + + # Second pass: Calculate softmax and gather outliers + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + + probs_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) + probs_blk = tl.exp(probs_blk - max_sample) + probs_blk = probs_blk / sum_exp_logits + + outlier_mask = (probs_blk > outlier_prob) & mask_n + sum_outlier_probs += tl.sum(outlier_mask * probs_blk) + cumulative_pos = tl.cast( + tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + num_outliers += tl.sum(outlier_mask) + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(BUFFER_ROW + write_pos, probs_blk, mask=outlier_mask) + + + max_range = tl.exp(max_logit - max_sample) / sum_exp_logits + min_range = tl.exp(min_logit - max_sample) / sum_exp_logits + + p_pivot = 1.0 + num_iters = 0 
+ min_larger_prob = 1.0 + num_min_larger = tl.zeros((), dtype=tl.uint32) + p_pivots_sum = 0.0 + + # Third pass: Search for p_pivot + if sum_outlier_probs > p: + min_range = outlier_prob + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32) + + while p_pivot == 1.0: + p_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + p_pivots_sum_0 = 0.0 + min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + p_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + p_pivots_sum_1 = 0.0 + min_larger_1 = 1.0 + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + p_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range + p_pivots_sum_2 = 0.0 + min_larger_2 = 1.0 + num_min_larger_2 = tl.zeros((), dtype=tl.uint32) + + + # First pass: Calculate p_pivots_sum and min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + + p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) + masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) + + p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) + masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) + min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) + + p_pivots_sum_2 += tl.sum(probs_blk * (probs_blk > p_pivot_2)) + masked_larger_2 = tl.where(probs_blk > p_pivot_2, probs_blk, 1.0) + min_larger_2 = tl.minimum(min_larger_2, tl.min(masked_larger_2)) + + + # Second pass: Calculate num_min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + + num_min_larger_0 += tl.sum(tl.abs(probs_blk - 
min_larger_0) < 1e-15) + num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-15) + num_min_larger_2 += tl.sum(tl.abs(probs_blk - min_larger_2) < 1e-15) + + # Check if any of the pivots satisfy termination condition + if p_pivots_sum_2 >= p: + if p_pivots_sum_2 - (min_larger_2 * num_min_larger_2) < p: + p_pivot = p_pivot_2 + min_larger_prob = min_larger_2 + num_min_larger = num_min_larger_2 + p_pivots_sum = p_pivots_sum_2 + if p_pivots_sum_1 >= p: + if p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: + p_pivot = p_pivot_1 + min_larger_prob = min_larger_1 + num_min_larger = num_min_larger_1 + p_pivots_sum = p_pivots_sum_1 + if p_pivots_sum_0 >= p: + if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: + p_pivot = p_pivot_0 + min_larger_prob = min_larger_0 + num_min_larger = num_min_larger_0 + p_pivots_sum = p_pivots_sum_0 + + # Update range + if p_pivots_sum_2 > p: + min_range = p_pivot_2 + elif p_pivots_sum_1 > p: + min_range = p_pivot_1 + elif p_pivots_sum_0 > p: + min_range = p_pivot_0 + + if p_pivots_sum_0 < p: + max_range = p_pivot_0 + elif p_pivots_sum_1 < p: + max_range = p_pivot_1 + elif p_pivots_sum_2 < p: + max_range = p_pivot_2 + + num_iters += 1 + if (max_range - min_range) < 1e-15 or num_iters >= 24: + p_pivot = (max_range + min_range) / 2.0 + else: + # Re-populate the buffer with full softmax probabilities for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=avg_logit) - - max_logit = tl.maximum(max_logit, tl.max(logits_blk)) - min_logit = tl.minimum(min_logit, tl.min(logits_blk)) - - outlier_mask = (logits_blk > outlier_pivot) & mask_n - cumulative_pos = tl.cast( - tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) - num_outliers += tl.sum(outlier_mask) - write_pos = tl.where(outlier_mask, cumulative_pos, -1) - tl.store(BUFFER_ROW + write_pos, logits_blk, mask=outlier_mask) - - # Second passes: Quaternary 
search for pivots (nlog_4(batch_size)) - num_iters = 0 - k_pivot = float("inf") - k_pivots_num = tl.zeros((), dtype=tl.uint32) - min_larger = float("inf") - num_min_larger = tl.zeros((), dtype=tl.uint32) - if num_outliers > k: - max_range = max_logit - min_range = outlier_pivot - search_range = tl.cast(num_outliers, tl.int32) - search_iters = tl.cast( - (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32 - ) - while k_pivot == float("inf"): - k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range - k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) - min_larger_0 = float("inf") - num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - - k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range - k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) - min_larger_1 = float("inf") - num_min_larger_1 = tl.zeros((), dtype=tl.uint32) - - k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range - k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - min_larger_2 = float("inf") - num_min_larger_2 = tl.zeros((), dtype=tl.uint32) - - # First pass: Calculate k_pivots_num and min_larger - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) - mask_n_2 = offs_n < search_range - logits_blk2 = tl.load( - BUFFER_ROW + offs_n, mask=mask_n_2, other=-float("inf") - ) - - k_pivots_num_0 += tl.sum(logits_blk2 > k_pivot_0) - k_pivots_num_1 += tl.sum(logits_blk2 > k_pivot_1) - k_pivots_num_2 += tl.sum(logits_blk2 > k_pivot_2) - - min_larger_0 = tl.minimum(min_larger_0, tl.min(logits_blk2)) - min_larger_1 = tl.minimum(min_larger_1, tl.min(logits_blk2)) - min_larger_2 = tl.minimum(min_larger_2, tl.min(logits_blk2)) - - # Second pass: Calculate num_min_larger - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) - mask_n_2 = offs_n < search_range - logits_blk2 = tl.load( - BUFFER_ROW + offs_n, mask=mask_n_2, other=-float("inf") - ) - - num_min_larger_0 += tl.sum(tl.abs(logits_blk2 - min_larger_0) 
< 1e-15) - num_min_larger_1 += tl.sum(tl.abs(logits_blk2 - min_larger_1) < 1e-15) - num_min_larger_2 += tl.sum(tl.abs(logits_blk2 - min_larger_2) < 1e-15) - - # Check if any of the pivots satisfy termination condition - if k_pivots_num_0 >= k: - if k_pivots_num_0 - (min_larger_0 * num_min_larger_0) < k: - k_pivot = k_pivot_0 - k_pivots_num = k_pivots_num_0 - min_larger = min_larger_0 - num_min_larger = num_min_larger_0 - if k_pivots_num_1 >= k: - if k_pivots_num_1 - (min_larger_1 * num_min_larger_1) < k: - k_pivot = k_pivot_1 - k_pivots_num = k_pivots_num_1 - min_larger = min_larger_1 - num_min_larger = num_min_larger_1 - if k_pivots_num_2 >= k: - if k_pivots_num_2 - (min_larger_2 * num_min_larger_2) < k: - k_pivot = k_pivot_2 - k_pivots_num = k_pivots_num_2 - min_larger = min_larger_2 - num_min_larger = num_min_larger_2 - - # Update range - if k_pivots_num_2 > k: - min_range = k_pivot_2 - elif k_pivots_num_1 > k: - min_range = k_pivot_1 - elif k_pivots_num_0 > k: - min_range = k_pivot_0 - - if k_pivots_num_0 < k: - max_range = k_pivot_0 - elif k_pivots_num_1 < k: - max_range = k_pivot_1 - elif k_pivots_num_2 < k: - max_range = k_pivot_2 - - num_iters += 1 - if num_iters >= 24 or tl.abs(min_range - max_range) < 1e-15: - k_pivot = (max_range + min_range) / 2.0 - else: - # If top-k outlier gathering failed, search whole logit space - max_range = max_logit - min_range = min_logit - while k_pivot == float("inf"): - k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range - k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) - min_larger_0 = float("inf") - num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - - k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range - k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) - min_larger_1 = float("inf") - num_min_larger_1 = tl.zeros((), dtype=tl.uint32) - - k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range - k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - min_larger_2 = float("inf") - num_min_larger_2 = tl.zeros((), 
dtype=tl.uint32) - - # First pass: Calculate k_pivots_num and min_larger - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - logits_blk2 = tl.load( - LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") - ) - - k_pivots_num_0 += tl.sum(logits_blk2 > k_pivot_0) - k_pivots_num_1 += tl.sum(logits_blk2 > k_pivot_1) - k_pivots_num_2 += tl.sum(logits_blk2 > k_pivot_2) - - min_larger_0 = tl.minimum(min_larger_0, tl.min(logits_blk2)) - min_larger_1 = tl.minimum(min_larger_1, tl.min(logits_blk2)) - min_larger_2 = tl.minimum(min_larger_2, tl.min(logits_blk2)) - - # Second pass: Calculate num_min_larger - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - logits_blk2 = tl.load( - LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") - ) - - num_min_larger_0 += tl.sum(tl.abs(logits_blk2 - min_larger_0) < 1e-15) - num_min_larger_1 += tl.sum(tl.abs(logits_blk2 - min_larger_1) < 1e-15) - num_min_larger_2 += tl.sum(tl.abs(logits_blk2 - min_larger_2) < 1e-15) - - # Check if any of the pivots satisfy termination condition - if k_pivots_num_0 >= k: - if k_pivots_num_0 - (min_larger_0 * num_min_larger_0) < k: - k_pivot = k_pivot_0 - k_pivots_num = k_pivots_num_0 - min_larger = min_larger_0 - num_min_larger = num_min_larger_0 - if k_pivots_num_1 >= k: - if k_pivots_num_1 - (min_larger_1 * num_min_larger_1) < k: - k_pivot = k_pivot_1 - k_pivots_num = k_pivots_num_1 - min_larger = min_larger_1 - num_min_larger = num_min_larger_1 - if k_pivots_num_2 >= k: - if k_pivots_num_2 - (min_larger_2 * num_min_larger_2) < k: - k_pivot = k_pivot_2 - k_pivots_num = k_pivots_num_2 - min_larger = min_larger_2 - num_min_larger = num_min_larger_2 - - # Update range - if k_pivots_num_2 > k: - min_range = k_pivot_2 - elif k_pivots_num_1 > k: - min_range = k_pivot_1 - elif k_pivots_num_0 > k: - min_range = k_pivot_0 - - if k_pivots_num_0 < k: - max_range = k_pivot_0 - elif 
k_pivots_num_1 < k: - max_range = k_pivot_1 - elif k_pivots_num_2 < k: - max_range = k_pivot_2 - - num_iters += 1 - if num_iters >= 24 or tl.abs(min_range - max_range) < 1e-15: - k_pivot = (max_range + min_range) / 2.0 - - duplicate_logit = min_larger - num_duplicate_logit = num_min_larger - num_keep = num_duplicate_logit - (k_pivots_num - k) - num_kept = tl.zeros((), dtype=tl.uint32) - - if TOPP_ENABLED: - #### TOP-P SAMPLING AFTER TOP-K #### - p = tl.load(P + row_id) - if p < 1.0: - min_logit = k_pivot - sum_exp_logits = 0.0 - num_outliers_2 = tl.zeros((), dtype=tl.uint32) - search_range = tl.cast(num_outliers, tl.int32) - search_iters = tl.cast( - (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32 - ) - - # Third pass: Calculate exp logits and sum, gather top-k outliers - if num_outliers > k: - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) - mask_n_2 = offs_n < search_range - - probs_blk = tl.load(BUFFER_ROW + offs_n, - mask=mask_n_2, - other=-float('inf')) - - outlier_mask = (probs_blk > min_logit) & mask_n_2 - - # Duplicate logit handling - if num_keep < num_duplicate_logit: - duplicate_mask = tl.abs(probs_blk - duplicate_logit) < 1e-15 - duplicate_count = tl.cumsum(duplicate_mask) + num_kept - duplicate_keep_mask = (duplicate_count <= num_keep) & duplicate_mask - duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask - outlier_mask = outlier_mask & (~duplicate_remove_mask) - num_kept += tl.sum(duplicate_keep_mask) - - probs_blk = tl.where(outlier_mask, probs_blk, -float('inf')) - probs_blk = probs_blk - max_logit - probs_blk = tl.exp(probs_blk) - sum_exp_logits += tl.sum(probs_blk) - - # Fourth pass: Calculate BUFFER and get outliers - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) - mask_n_2 = offs_n < search_range - - probs_blk = tl.load(BUFFER_ROW + offs_n, - mask=mask_n_2, - other=-float('inf')) - - outlier_mask = (probs_blk > 
min_logit) & mask_n_2 - probs_blk = probs_blk - max_logit - probs_blk = tl.exp(probs_blk) - probs_blk = probs_blk / sum_exp_logits - tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n_2) - else: - # If top-k outlier gathering failed, retry gathering using top-k pivot - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - - probs_blk = tl.load(LOGITS_ROW + offs_n, - mask=mask_n, - other=-float('inf')) - - outlier_mask = (probs_blk > min_logit) & mask_n - - # Duplicate logit handling - duplicate_mask = tl.abs(probs_blk - duplicate_logit) < 1e-15 - duplicate_count = tl.cumsum(duplicate_mask) + num_kept - duplicate_keep_mask = (duplicate_count <= num_keep) & duplicate_mask - duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask - outlier_mask = outlier_mask & (~duplicate_remove_mask) - num_kept += tl.sum(duplicate_keep_mask) - - probs_blk = tl.where(outlier_mask, probs_blk, -float('inf')) - probs_blk = probs_blk - max_logit - probs_blk = tl.exp(probs_blk) - sum_exp_logits += tl.sum(probs_blk) - - cumulative_pos = tl.cast( - tl.cumsum(outlier_mask) - 1 + num_outliers_2, tl.int32) - num_outliers_2 += tl.sum(outlier_mask) - write_pos = tl.where(outlier_mask, cumulative_pos, -1) - tl.store(BUFFER_ROW + write_pos, probs_blk, mask=outlier_mask) - - search_range = tl.cast(num_outliers_2, tl.int32) - search_iters = tl.cast( - (num_outliers_2 + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32) - - # Fourth pass: Calculate BUFFER and get outliers - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) - mask_n_2 = offs_n < search_range - - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) - probs_blk = probs_blk / sum_exp_logits - tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n_2) - - - max_range = tl.exp(max_logit - max_logit) / sum_exp_logits - min_range = tl.exp(min_logit - max_logit) / sum_exp_logits - - p_pivot = 1.0 - num_iters = 0 - 
min_larger_prob = 1.0 - num_min_larger = tl.zeros((), dtype=tl.uint32) - p_pivots_sum = 0.0 - - # Fifth passes: Search for p_pivot - while p_pivot == 1.0: - p_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range - p_pivots_sum_0 = 0.0 - min_larger_0 = 1.0 - num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - - p_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range - p_pivots_sum_1 = 0.0 - min_larger_1 = 1.0 - num_min_larger_1 = tl.zeros((), dtype=tl.uint32) - - p_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range - p_pivots_sum_2 = 0.0 - min_larger_2 = 1.0 - num_min_larger_2 = tl.zeros((), dtype=tl.uint32) - - - # First pass: Calculate p_pivots_sum and min_larger - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) - mask_n_2 = offs_n < search_range - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) - - p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) - masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) - min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) - - p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) - masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) - min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) - - p_pivots_sum_2 += tl.sum(probs_blk * (probs_blk > p_pivot_2)) - masked_larger_2 = tl.where(probs_blk > p_pivot_2, probs_blk, 1.0) - min_larger_2 = tl.minimum(min_larger_2, tl.min(masked_larger_2)) - - - # Second pass: Calculate num_min_larger - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) - mask_n_2 = offs_n < search_range - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) - - num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-15) - num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-15) - num_min_larger_2 += tl.sum(tl.abs(probs_blk - min_larger_2) < 1e-15) - - # Check if any of the pivots 
satisfy termination condition - if p_pivots_sum_2 >= p: - if p_pivots_sum_2 - (min_larger_2 * num_min_larger_2) < p: - p_pivot = p_pivot_2 - min_larger_prob = min_larger_2 - num_min_larger = num_min_larger_2 - p_pivots_sum = p_pivots_sum_2 - if p_pivots_sum_1 >= p: - if p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: - p_pivot = p_pivot_1 - min_larger_prob = min_larger_1 - num_min_larger = num_min_larger_1 - p_pivots_sum = p_pivots_sum_1 - if p_pivots_sum_0 >= p: - if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: - p_pivot = p_pivot_0 - min_larger_prob = min_larger_0 - num_min_larger = num_min_larger_0 - p_pivots_sum = p_pivots_sum_0 - - # Update range - if p_pivots_sum_2 > p: - min_range = p_pivot_2 - elif p_pivots_sum_1 > p: - min_range = p_pivot_1 - elif p_pivots_sum_0 > p: - min_range = p_pivot_0 - - if p_pivots_sum_0 < p: - max_range = p_pivot_0 - elif p_pivots_sum_1 < p: - max_range = p_pivot_1 - elif p_pivots_sum_2 < p: - max_range = p_pivot_2 - - num_iters += 1 - if (max_range - min_range) < 1e-15 or num_iters >= 24: - p_pivot = (max_range + min_range) / 2.0 - - duplicate_logit = tl.log(min_larger_prob * sum_exp_logits) + max_logit - num_duplicate_logit = num_min_larger - num_keep = num_duplicate_logit - tl.cast((p_pivots_sum - p) / min_larger_prob, tl.uint32) - num_kept = tl.zeros((), dtype=tl.uint32) - - # Top-k + Top-p path - final_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit - else: - # Top-k only path - final_pivot = k_pivot - elif TOPP_ENABLED: - #### STANDALONE TOP-P SAMPLING #### - p = tl.load(P + row_id) - if p < 1.0: - # Zeroth pass: Compute avg and std from a sample block - offs = tl.arange(0, BLOCK_SIZE) - mask_n = offs < VOCAB_SIZE - num_valid = tl.sum(mask_n) - logits_blk0 = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) - avg_logit = tl.sum(logits_blk0) / num_valid - sq_avg_logit = tl.sum(logits_blk0 * logits_blk0) / num_valid - std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - std_logit = 
tl.maximum(std_logit, 0.0) - max_sample = avg_logit + std_logit * 10.0 - sum_exp_logits = 0.0 - - # First pass: compute max and min logits and sum_exp_logits + + probs_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) + probs_blk = tl.exp(probs_blk - max_sample) + probs_blk = probs_blk / sum_exp_logits + tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) + + while p_pivot == 1.0: + p_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + p_pivots_sum_0 = 0.0 + min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + p_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + p_pivots_sum_1 = 0.0 + min_larger_1 = 1.0 + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + p_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range + p_pivots_sum_2 = 0.0 + min_larger_2 = 1.0 + num_min_larger_2 = tl.zeros((), dtype=tl.uint32) + + + # First pass: Calculate p_pivots_sum and min_larger for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load(LOGITS_ROW + offs_n, - mask=mask_n, - other=avg_logit) - max_logit = tl.maximum(max_logit, tl.max(logits_blk)) - min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) + + p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) + masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) - probs_blk = tl.exp(logits_blk - max_sample) - probs_blk = tl.where(mask_n, probs_blk, 0.0) - sum_exp_logits += tl.sum(probs_blk) + p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) + masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) + min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) - idx = tl.cast(p * 200, tl.int32) - idx = tl.maximum(0, tl.minimum(idx, 199)) - sigma = tl.load(NORMAL_CDF_TO_SIGMA_TABLE + idx) - sigma = sigma + tl.abs(sigma) * 
-0.25 - outlier_pivot = avg_logit + std_logit * sigma + p_pivots_sum_2 += tl.sum(probs_blk * (probs_blk > p_pivot_2)) + masked_larger_2 = tl.where(probs_blk > p_pivot_2, probs_blk, 1.0) + min_larger_2 = tl.minimum(min_larger_2, tl.min(masked_larger_2)) - outlier_prob = tl.exp(outlier_pivot - max_sample) / sum_exp_logits - sum_outlier_probs = 0.0 - num_outliers = tl.zeros((), dtype=tl.uint32) - # Second pass: Calculate softmax and gather outliers + # Second pass: Calculate num_min_larger for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE - - probs_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) - probs_blk = tl.exp(probs_blk - max_sample) - probs_blk = probs_blk / sum_exp_logits - - outlier_mask = (probs_blk > outlier_prob) & mask_n - sum_outlier_probs += tl.sum(outlier_mask * probs_blk) - cumulative_pos = tl.cast( - tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) - num_outliers += tl.sum(outlier_mask) - write_pos = tl.where(outlier_mask, cumulative_pos, -1) - tl.store(BUFFER_ROW + write_pos, probs_blk, mask=outlier_mask) - - - max_range = tl.exp(max_logit - max_sample) / sum_exp_logits - min_range = tl.exp(min_logit - max_sample) / sum_exp_logits + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) + + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-15) + num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-15) + num_min_larger_2 += tl.sum(tl.abs(probs_blk - min_larger_2) < 1e-15) + + # Check if any of the pivots satisfy termination condition + if p_pivots_sum_2 >= p: + if p_pivots_sum_2 - (min_larger_2 * num_min_larger_2) < p: + p_pivot = p_pivot_2 + min_larger_prob = min_larger_2 + num_min_larger = num_min_larger_2 + p_pivots_sum = p_pivots_sum_2 + if p_pivots_sum_1 >= p: + if p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: + p_pivot = p_pivot_1 + min_larger_prob = min_larger_1 + num_min_larger = num_min_larger_1 + p_pivots_sum 
= p_pivots_sum_1 + if p_pivots_sum_0 >= p: + if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: + p_pivot = p_pivot_0 + min_larger_prob = min_larger_0 + num_min_larger = num_min_larger_0 + p_pivots_sum = p_pivots_sum_0 - p_pivot = 1.0 - num_iters = 0 - min_larger_prob = 1.0 - num_min_larger = tl.zeros((), dtype=tl.uint32) - p_pivots_sum = 0.0 + # Update range + if p_pivots_sum_2 > p: + min_range = p_pivot_2 + elif p_pivots_sum_1 > p: + min_range = p_pivot_1 + elif p_pivots_sum_0 > p: + min_range = p_pivot_0 + + if p_pivots_sum_0 < p: + max_range = p_pivot_0 + elif p_pivots_sum_1 < p: + max_range = p_pivot_1 + elif p_pivots_sum_2 < p: + max_range = p_pivot_2 + + num_iters += 1 + if (max_range - min_range) < 1e-15 or num_iters >= 24: + p_pivot = (max_range + min_range) / 2.0 + + duplicate_logit = tl.log(min_larger_prob * sum_exp_logits) + max_logit + num_duplicate_logit = num_min_larger + num_keep = num_duplicate_logit - tl.cast((p_pivots_sum - p) / min_larger_prob, tl.uint32) + num_kept = tl.zeros((), dtype=tl.uint32) + + # Top-p only path + final_pivot = tl.log(p_pivot * sum_exp_logits) + max_sample + + # Sixth pass: Apply mask + if final_pivot != -float("inf"): + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) + keep_mask = (logits_blk > final_pivot) & mask_n + + # Duplicate logit handling + if num_keep < num_duplicate_logit: + duplicate_mask = (tl.abs(logits_blk - duplicate_logit) < 1e-15) & mask_n + duplicate_count = tl.cumsum(duplicate_mask) + num_kept + duplicate_keep_mask = (duplicate_count <= num_duplicate_logit) & duplicate_mask + duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask + num_kept += tl.sum(duplicate_keep_mask) + keep_mask = keep_mask & (~duplicate_remove_mask) - # Third pass: Search for p_pivot - if sum_outlier_probs > p: - min_range = outlier_prob - search_range = 
tl.cast(num_outliers, tl.int32) - search_iters = tl.cast( - (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32) - - while p_pivot == 1.0: - p_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range - p_pivots_sum_0 = 0.0 - min_larger_0 = 1.0 - num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - - p_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range - p_pivots_sum_1 = 0.0 - min_larger_1 = 1.0 - num_min_larger_1 = tl.zeros((), dtype=tl.uint32) - - p_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range - p_pivots_sum_2 = 0.0 - min_larger_2 = 1.0 - num_min_larger_2 = tl.zeros((), dtype=tl.uint32) - - - # First pass: Calculate p_pivots_sum and min_larger - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) - mask_n_2 = offs_n < search_range - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) - - p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) - masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) - min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) - - p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) - masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) - min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) - - p_pivots_sum_2 += tl.sum(probs_blk * (probs_blk > p_pivot_2)) - masked_larger_2 = tl.where(probs_blk > p_pivot_2, probs_blk, 1.0) - min_larger_2 = tl.minimum(min_larger_2, tl.min(masked_larger_2)) - - - # Second pass: Calculate num_min_larger - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) - mask_n_2 = offs_n < search_range - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) - - num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-15) - num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-15) - num_min_larger_2 += tl.sum(tl.abs(probs_blk - min_larger_2) < 1e-15) - - # Check if any of the pivots 
satisfy termination condition - if p_pivots_sum_2 >= p: - if p_pivots_sum_2 - (min_larger_2 * num_min_larger_2) < p: - p_pivot = p_pivot_2 - min_larger_prob = min_larger_2 - num_min_larger = num_min_larger_2 - p_pivots_sum = p_pivots_sum_2 - if p_pivots_sum_1 >= p: - if p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: - p_pivot = p_pivot_1 - min_larger_prob = min_larger_1 - num_min_larger = num_min_larger_1 - p_pivots_sum = p_pivots_sum_1 - if p_pivots_sum_0 >= p: - if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: - p_pivot = p_pivot_0 - min_larger_prob = min_larger_0 - num_min_larger = num_min_larger_0 - p_pivots_sum = p_pivots_sum_0 - - # Update range - if p_pivots_sum_2 > p: - min_range = p_pivot_2 - elif p_pivots_sum_1 > p: - min_range = p_pivot_1 - elif p_pivots_sum_0 > p: - min_range = p_pivot_0 - - if p_pivots_sum_0 < p: - max_range = p_pivot_0 - elif p_pivots_sum_1 < p: - max_range = p_pivot_1 - elif p_pivots_sum_2 < p: - max_range = p_pivot_2 - - num_iters += 1 - if (max_range - min_range) < 1e-15 or num_iters >= 24: - p_pivot = (max_range + min_range) / 2.0 - else: - # Re-populate the buffer with full softmax probabilities - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - - probs_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) - probs_blk = tl.exp(probs_blk - max_sample) - probs_blk = probs_blk / sum_exp_logits - tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) - - while p_pivot == 1.0: - p_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range - p_pivots_sum_0 = 0.0 - min_larger_0 = 1.0 - num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - - p_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range - p_pivots_sum_1 = 0.0 - min_larger_1 = 1.0 - num_min_larger_1 = tl.zeros((), dtype=tl.uint32) - - p_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range - p_pivots_sum_2 = 0.0 - min_larger_2 = 1.0 - num_min_larger_2 = tl.zeros((), 
dtype=tl.uint32) - - - # First pass: Calculate p_pivots_sum and min_larger - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) - - p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) - masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) - min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) - - p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) - masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) - min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) - - p_pivots_sum_2 += tl.sum(probs_blk * (probs_blk > p_pivot_2)) - masked_larger_2 = tl.where(probs_blk > p_pivot_2, probs_blk, 1.0) - min_larger_2 = tl.minimum(min_larger_2, tl.min(masked_larger_2)) - - - # Second pass: Calculate num_min_larger - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) - - num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-15) - num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-15) - num_min_larger_2 += tl.sum(tl.abs(probs_blk - min_larger_2) < 1e-15) - - # Check if any of the pivots satisfy termination condition - if p_pivots_sum_2 >= p: - if p_pivots_sum_2 - (min_larger_2 * num_min_larger_2) < p: - p_pivot = p_pivot_2 - min_larger_prob = min_larger_2 - num_min_larger = num_min_larger_2 - p_pivots_sum = p_pivots_sum_2 - if p_pivots_sum_1 >= p: - if p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: - p_pivot = p_pivot_1 - min_larger_prob = min_larger_1 - num_min_larger = num_min_larger_1 - p_pivots_sum = p_pivots_sum_1 - if p_pivots_sum_0 >= p: - if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: - p_pivot = p_pivot_0 - min_larger_prob = min_larger_0 - num_min_larger = num_min_larger_0 - p_pivots_sum = p_pivots_sum_0 - 
- # Update range - if p_pivots_sum_2 > p: - min_range = p_pivot_2 - elif p_pivots_sum_1 > p: - min_range = p_pivot_1 - elif p_pivots_sum_0 > p: - min_range = p_pivot_0 - - if p_pivots_sum_0 < p: - max_range = p_pivot_0 - elif p_pivots_sum_1 < p: - max_range = p_pivot_1 - elif p_pivots_sum_2 < p: - max_range = p_pivot_2 - - num_iters += 1 - if (max_range - min_range) < 1e-15 or num_iters >= 24: - p_pivot = (max_range + min_range) / 2.0 - - duplicate_logit = tl.log(min_larger_prob * sum_exp_logits) + max_logit - num_duplicate_logit = num_min_larger - num_keep = num_duplicate_logit - tl.cast((p_pivots_sum - p) / min_larger_prob, tl.uint32) - num_kept = tl.zeros((), dtype=tl.uint32) - - # Top-p only path - final_pivot = tl.log(p_pivot * sum_exp_logits) + max_sample - - # Sixth pass: Apply mask - if final_pivot != -float("inf"): - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) - keep_mask = (logits_blk > final_pivot) & mask_n - - # Duplicate logit handling - if num_keep < num_duplicate_logit: - duplicate_mask = (tl.abs(logits_blk - duplicate_logit) < 1e-15) & mask_n - duplicate_count = tl.cumsum(duplicate_mask) + num_kept - duplicate_keep_mask = (duplicate_count <= num_duplicate_logit) & duplicate_mask - duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask - num_kept += tl.sum(duplicate_keep_mask) - keep_mask = keep_mask & (~duplicate_remove_mask) - - logits_blk = tl.where(keep_mask, logits_blk, MASK_VALUE) - tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) + logits_blk = tl.where(keep_mask, logits_blk, MASK_VALUE) + tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) def apply_top_k_top_p_triton( logits: torch.Tensor, From cf6ab55a4221081d266898da59acbfde06eb324f Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 1 Feb 2026 19:01:17 -0800 Subject: [PATCH 83/99] Both Topk Topp working Signed-off-by: js_park --- 
vllm/v1/sample/ops/topk_topp_triton.py | 914 ++++++++++++++++++------- 1 file changed, 671 insertions(+), 243 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_triton.py b/vllm/v1/sample/ops/topk_topp_triton.py index b22a53a296a2..b52404fddd53 100644 --- a/vllm/v1/sample/ops/topk_topp_triton.py +++ b/vllm/v1/sample/ops/topk_topp_triton.py @@ -100,165 +100,474 @@ def _topk_topp_kernel( max_logit = -float('inf') min_logit = float('inf') - #### STANDALONE TOP-P SAMPLING #### - p = tl.load(P + row_id) - if p < 1.0: - # Zeroth pass: Compute avg and std from a sample block - offs = tl.arange(0, BLOCK_SIZE) - mask_n = offs < VOCAB_SIZE - num_valid = tl.sum(mask_n) - logits_blk0 = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) - avg_logit = tl.sum(logits_blk0) / num_valid - sq_avg_logit = tl.sum(logits_blk0 * logits_blk0) / num_valid - std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - std_logit = tl.maximum(std_logit, 0.0) - max_sample = avg_logit + std_logit * 10.0 - sum_exp_logits = 0.0 - - # First pass: compute max and min logits and sum_exp_logits - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load(LOGITS_ROW + offs_n, - mask=mask_n, - other=avg_logit) - max_logit = tl.maximum(max_logit, tl.max(logits_blk)) - min_logit = tl.minimum(min_logit, tl.min(logits_blk)) - - probs_blk = tl.exp(logits_blk - max_sample) - probs_blk = tl.where(mask_n, probs_blk, 0.0) - sum_exp_logits += tl.sum(probs_blk) - - idx = tl.cast(p * 200, tl.int32) - idx = tl.maximum(0, tl.minimum(idx, 199)) - sigma = tl.load(NORMAL_CDF_TO_SIGMA_TABLE + idx) - sigma = sigma + tl.abs(sigma) * -0.25 - outlier_pivot = avg_logit + std_logit * sigma - - outlier_prob = tl.exp(outlier_pivot - max_sample) / sum_exp_logits - sum_outlier_probs = 0.0 - num_outliers = tl.zeros((), dtype=tl.uint32) + if TOPK_ENABLED: + k = tl.load(K + row_id) + if k != VOCAB_SIZE: + + # Zeroth pass: Compute avg and std from a 
sample block + offs = tl.arange(0, BLOCK_SIZE) + mask_n = offs < VOCAB_SIZE + num_valid = tl.sum(mask_n) + logits_blk0 = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) + avg_logit = tl.sum(logits_blk0) / num_valid + sq_avg_logit = tl.sum(logits_blk0 * logits_blk0) / num_valid + std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) + std_logit = tl.maximum(std_logit, 0.0) + + # Calculate outlier pivot t for Gaussian sigma-truncation + percentile = tl.cast(k / VOCAB_SIZE * 200, tl.uint32) + percentile = tl.minimum(percentile, 199) + sigma = tl.load(PERCENTILE_TO_STD_TABLE + percentile) + sigma = sigma + tl.abs(sigma) * -0.25 + outlier_pivot = avg_logit + std_logit * sigma + num_outliers = tl.zeros((), dtype=tl.uint32) + + # First pass: compute max and min logits and gather outliers + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=avg_logit) + + max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + + outlier_mask = (logits_blk > outlier_pivot) & mask_n + cumulative_pos = tl.cast( + tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + num_outliers += tl.sum(outlier_mask) + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(BUFFER_ROW + write_pos, logits_blk, mask=outlier_mask) + + # Second passes: Quaternary search for pivots (nlog_4(batch_size)) + num_iters = 0 + k_pivot = float("inf") + k_pivots_num = tl.zeros((), dtype=tl.uint32) + min_larger = float("inf") + num_min_larger = tl.zeros((), dtype=tl.uint32) + if num_outliers > k: + max_range = max_logit + min_range = outlier_pivot + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32 + ) + while k_pivot == float("inf"): + k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + k_pivots_num_0 = tl.zeros((), 
dtype=tl.uint32) + min_larger_0 = float("inf") + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) + min_larger_1 = float("inf") + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range + k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) + min_larger_2 = float("inf") + num_min_larger_2 = tl.zeros((), dtype=tl.uint32) + + # First pass: Calculate k_pivots_num and min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + logits_blk2 = tl.load( + BUFFER_ROW + offs_n, mask=mask_n_2, other=-float("inf") + ) + + k_pivots_num_0 += tl.sum(logits_blk2 > k_pivot_0) + k_pivots_num_1 += tl.sum(logits_blk2 > k_pivot_1) + k_pivots_num_2 += tl.sum(logits_blk2 > k_pivot_2) + + min_larger_0 = tl.minimum(min_larger_0, tl.min(logits_blk2)) + min_larger_1 = tl.minimum(min_larger_1, tl.min(logits_blk2)) + min_larger_2 = tl.minimum(min_larger_2, tl.min(logits_blk2)) + + # Second pass: Calculate num_min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + logits_blk2 = tl.load( + BUFFER_ROW + offs_n, mask=mask_n_2, other=-float("inf") + ) + + num_min_larger_0 += tl.sum(tl.abs(logits_blk2 - min_larger_0) < 1e-15) + num_min_larger_1 += tl.sum(tl.abs(logits_blk2 - min_larger_1) < 1e-15) + num_min_larger_2 += tl.sum(tl.abs(logits_blk2 - min_larger_2) < 1e-15) + + # Check if any of the pivots satisfy termination condition + if k_pivots_num_0 >= k: + if k_pivots_num_0 - (min_larger_0 * num_min_larger_0) < k: + k_pivot = k_pivot_0 + k_pivots_num = k_pivots_num_0 + min_larger = min_larger_0 + num_min_larger = num_min_larger_0 + if k_pivots_num_1 >= k: + if k_pivots_num_1 - (min_larger_1 * num_min_larger_1) < k: + k_pivot = k_pivot_1 + k_pivots_num = 
k_pivots_num_1 + min_larger = min_larger_1 + num_min_larger = num_min_larger_1 + if k_pivots_num_2 >= k: + if k_pivots_num_2 - (min_larger_2 * num_min_larger_2) < k: + k_pivot = k_pivot_2 + k_pivots_num = k_pivots_num_2 + min_larger = min_larger_2 + num_min_larger = num_min_larger_2 + + # Update range + if k_pivots_num_2 > k: + min_range = k_pivot_2 + elif k_pivots_num_1 > k: + min_range = k_pivot_1 + elif k_pivots_num_0 > k: + min_range = k_pivot_0 + + if k_pivots_num_0 < k: + max_range = k_pivot_0 + elif k_pivots_num_1 < k: + max_range = k_pivot_1 + elif k_pivots_num_2 < k: + max_range = k_pivot_2 + + num_iters += 1 + if num_iters >= 24 or tl.abs(min_range - max_range) < 1e-15: + k_pivot = (max_range + min_range) / 2.0 + else: + # If top-k outlier gathering failed, search whole logit space + max_range = max_logit + min_range = min_logit + while k_pivot == float("inf"): + k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) + min_larger_0 = float("inf") + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) + min_larger_1 = float("inf") + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range + k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) + min_larger_2 = float("inf") + num_min_larger_2 = tl.zeros((), dtype=tl.uint32) + + # First pass: Calculate k_pivots_num and min_larger + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk2 = tl.load( + LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") + ) + + k_pivots_num_0 += tl.sum(logits_blk2 > k_pivot_0) + k_pivots_num_1 += tl.sum(logits_blk2 > k_pivot_1) + k_pivots_num_2 += tl.sum(logits_blk2 > k_pivot_2) + + min_larger_0 = tl.minimum(min_larger_0, tl.min(logits_blk2)) + min_larger_1 = tl.minimum(min_larger_1, 
tl.min(logits_blk2)) + min_larger_2 = tl.minimum(min_larger_2, tl.min(logits_blk2)) + + # Second pass: Calculate num_min_larger + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk2 = tl.load( + LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") + ) + + num_min_larger_0 += tl.sum(tl.abs(logits_blk2 - min_larger_0) < 1e-15) + num_min_larger_1 += tl.sum(tl.abs(logits_blk2 - min_larger_1) < 1e-15) + num_min_larger_2 += tl.sum(tl.abs(logits_blk2 - min_larger_2) < 1e-15) + + # Check if any of the pivots satisfy termination condition + if k_pivots_num_0 >= k: + if k_pivots_num_0 - (min_larger_0 * num_min_larger_0) < k: + k_pivot = k_pivot_0 + k_pivots_num = k_pivots_num_0 + min_larger = min_larger_0 + num_min_larger = num_min_larger_0 + if k_pivots_num_1 >= k: + if k_pivots_num_1 - (min_larger_1 * num_min_larger_1) < k: + k_pivot = k_pivot_1 + k_pivots_num = k_pivots_num_1 + min_larger = min_larger_1 + num_min_larger = num_min_larger_1 + if k_pivots_num_2 >= k: + if k_pivots_num_2 - (min_larger_2 * num_min_larger_2) < k: + k_pivot = k_pivot_2 + k_pivots_num = k_pivots_num_2 + min_larger = min_larger_2 + num_min_larger = num_min_larger_2 + + # Update range + if k_pivots_num_2 > k: + min_range = k_pivot_2 + elif k_pivots_num_1 > k: + min_range = k_pivot_1 + elif k_pivots_num_0 > k: + min_range = k_pivot_0 + + if k_pivots_num_0 < k: + max_range = k_pivot_0 + elif k_pivots_num_1 < k: + max_range = k_pivot_1 + elif k_pivots_num_2 < k: + max_range = k_pivot_2 + + num_iters += 1 + if num_iters >= 24 or tl.abs(min_range - max_range) < 1e-15: + k_pivot = (max_range + min_range) / 2.0 + + duplicate_logit = min_larger + num_duplicate_logit = num_min_larger + num_keep = num_duplicate_logit - (k_pivots_num - k) + num_kept = tl.zeros((), dtype=tl.uint32) + + if TOPP_ENABLED: + #### TOP-P SAMPLING AFTER TOP-K #### + p = tl.load(P + row_id) + if p < 1.0: + min_logit = k_pivot + sum_exp_logits = 0.0 + 
num_outliers_2 = tl.zeros((), dtype=tl.uint32) + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32 + ) + + # Third pass: Calculate exp logits and sum, gather top-k outliers + if num_outliers > k: + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + + probs_blk = tl.load(BUFFER_ROW + offs_n, + mask=mask_n_2, + other=-float('inf')) + + outlier_mask = (probs_blk > min_logit) & mask_n_2 + + # Duplicate logit handling + if num_keep < num_duplicate_logit: + duplicate_mask = tl.abs(probs_blk - duplicate_logit) < 1e-15 + duplicate_count = tl.cumsum(duplicate_mask) + num_kept + duplicate_keep_mask = (duplicate_count <= num_keep) & duplicate_mask + duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask + outlier_mask = outlier_mask & (~duplicate_remove_mask) + num_kept += tl.sum(duplicate_keep_mask) + + probs_blk = tl.where(outlier_mask, probs_blk, -float('inf')) + probs_blk = probs_blk - max_logit + probs_blk = tl.exp(probs_blk) + sum_exp_logits += tl.sum(probs_blk) + + # Fourth pass: Calculate BUFFER and get outliers + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + + probs_blk = tl.load(BUFFER_ROW + offs_n, + mask=mask_n_2, + other=-float('inf')) + + outlier_mask = (probs_blk > min_logit) & mask_n_2 + probs_blk = probs_blk - max_logit + probs_blk = tl.exp(probs_blk) + probs_blk = probs_blk / sum_exp_logits + tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n_2) + else: + # If top-k outlier gathering failed, retry gathering using top-k pivot + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + + probs_blk = tl.load(LOGITS_ROW + offs_n, + mask=mask_n, + other=-float('inf')) + + outlier_mask = (probs_blk > min_logit) & mask_n + + # Duplicate 
logit handling + duplicate_mask = tl.abs(probs_blk - duplicate_logit) < 1e-15 + duplicate_count = tl.cumsum(duplicate_mask) + num_kept + duplicate_keep_mask = (duplicate_count <= num_keep) & duplicate_mask + duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask + outlier_mask = outlier_mask & (~duplicate_remove_mask) + num_kept += tl.sum(duplicate_keep_mask) + + probs_blk = tl.where(outlier_mask, probs_blk, -float('inf')) + probs_blk = probs_blk - max_logit + probs_blk = tl.exp(probs_blk) + sum_exp_logits += tl.sum(probs_blk) + + cumulative_pos = tl.cast( + tl.cumsum(outlier_mask) - 1 + num_outliers_2, tl.int32) + num_outliers_2 += tl.sum(outlier_mask) + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(BUFFER_ROW + write_pos, probs_blk, mask=outlier_mask) + + search_range = tl.cast(num_outliers_2, tl.int32) + search_iters = tl.cast( + (num_outliers_2 + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32) + + # Fourth pass: Calculate BUFFER and get outliers + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + probs_blk = probs_blk / sum_exp_logits + tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n_2) + + + max_range = tl.exp(max_logit - max_logit) / sum_exp_logits + min_range = tl.exp(min_logit - max_logit) / sum_exp_logits + + p_pivot = 1.0 + num_iters = 0 + min_larger_prob = 1.0 + num_min_larger = tl.zeros((), dtype=tl.uint32) + p_pivots_sum = 0.0 + + # Fifth passes: Search for p_pivot + while p_pivot == 1.0: + p_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + p_pivots_sum_0 = 0.0 + min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + p_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + p_pivots_sum_1 = 0.0 + min_larger_1 = 1.0 + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + p_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range + 
p_pivots_sum_2 = 0.0 + min_larger_2 = 1.0 + num_min_larger_2 = tl.zeros((), dtype=tl.uint32) + + + # First pass: Calculate p_pivots_sum and min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + + p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) + masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) + + p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) + masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) + min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) + + p_pivots_sum_2 += tl.sum(probs_blk * (probs_blk > p_pivot_2)) + masked_larger_2 = tl.where(probs_blk > p_pivot_2, probs_blk, 1.0) + min_larger_2 = tl.minimum(min_larger_2, tl.min(masked_larger_2)) + + + # Second pass: Calculate num_min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-15) + num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-15) + num_min_larger_2 += tl.sum(tl.abs(probs_blk - min_larger_2) < 1e-15) + + # Check if any of the pivots satisfy termination condition + if p_pivots_sum_2 >= p: + if p_pivots_sum_2 - (min_larger_2 * num_min_larger_2) < p: + p_pivot = p_pivot_2 + min_larger_prob = min_larger_2 + num_min_larger = num_min_larger_2 + p_pivots_sum = p_pivots_sum_2 + if p_pivots_sum_1 >= p: + if p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: + p_pivot = p_pivot_1 + min_larger_prob = min_larger_1 + num_min_larger = num_min_larger_1 + p_pivots_sum = p_pivots_sum_1 + if p_pivots_sum_0 >= p: + if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: + 
p_pivot = p_pivot_0 + min_larger_prob = min_larger_0 + num_min_larger = num_min_larger_0 + p_pivots_sum = p_pivots_sum_0 + + # Update range + if p_pivots_sum_2 > p: + min_range = p_pivot_2 + elif p_pivots_sum_1 > p: + min_range = p_pivot_1 + elif p_pivots_sum_0 > p: + min_range = p_pivot_0 + + if p_pivots_sum_0 < p: + max_range = p_pivot_0 + elif p_pivots_sum_1 < p: + max_range = p_pivot_1 + elif p_pivots_sum_2 < p: + max_range = p_pivot_2 + + num_iters += 1 + if (max_range - min_range) < 1e-15 or num_iters >= 24: + p_pivot = (max_range + min_range) / 2.0 + + duplicate_logit = tl.log(min_larger_prob * sum_exp_logits) + max_logit + num_duplicate_logit = num_min_larger + num_keep = num_duplicate_logit - tl.cast((p_pivots_sum - p) / min_larger_prob, tl.uint32) + num_kept = tl.zeros((), dtype=tl.uint32) + + # Top-k + Top-p path + final_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit + else: + # Top-k only path + final_pivot = k_pivot + elif TOPP_ENABLED: + #### STANDALONE TOP-P SAMPLING #### + p = tl.load(P + row_id) + if p < 1.0: + # Zeroth pass: Compute avg and std from a sample block + offs = tl.arange(0, BLOCK_SIZE) + mask_n = offs < VOCAB_SIZE + num_valid = tl.sum(mask_n) + logits_blk0 = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) + avg_logit = tl.sum(logits_blk0) / num_valid + sq_avg_logit = tl.sum(logits_blk0 * logits_blk0) / num_valid + std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) + std_logit = tl.maximum(std_logit, 0.0) + max_sample = avg_logit + std_logit * 10.0 + sum_exp_logits = 0.0 + + # First pass: compute max and min logits and sum_exp_logits + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk = tl.load(LOGITS_ROW + offs_n, + mask=mask_n, + other=avg_logit) + max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + + probs_blk = tl.exp(logits_blk - max_sample) + probs_blk = tl.where(mask_n, 
probs_blk, 0.0) + sum_exp_logits += tl.sum(probs_blk) - # Second pass: Calculate softmax and gather outliers - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE + idx = tl.cast(p * 200, tl.int32) + idx = tl.maximum(0, tl.minimum(idx, 199)) + sigma = tl.load(NORMAL_CDF_TO_SIGMA_TABLE + idx) + sigma = sigma + tl.abs(sigma) * -0.25 + outlier_pivot = avg_logit + std_logit * sigma - probs_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) - probs_blk = tl.exp(probs_blk - max_sample) - probs_blk = probs_blk / sum_exp_logits - - outlier_mask = (probs_blk > outlier_prob) & mask_n - sum_outlier_probs += tl.sum(outlier_mask * probs_blk) - cumulative_pos = tl.cast( - tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) - num_outliers += tl.sum(outlier_mask) - write_pos = tl.where(outlier_mask, cumulative_pos, -1) - tl.store(BUFFER_ROW + write_pos, probs_blk, mask=outlier_mask) - + outlier_prob = tl.exp(outlier_pivot - max_sample) / sum_exp_logits + sum_outlier_probs = 0.0 + num_outliers = tl.zeros((), dtype=tl.uint32) - max_range = tl.exp(max_logit - max_sample) / sum_exp_logits - min_range = tl.exp(min_logit - max_sample) / sum_exp_logits - - p_pivot = 1.0 - num_iters = 0 - min_larger_prob = 1.0 - num_min_larger = tl.zeros((), dtype=tl.uint32) - p_pivots_sum = 0.0 - - # Third pass: Search for p_pivot - if sum_outlier_probs > p: - min_range = outlier_prob - search_range = tl.cast(num_outliers, tl.int32) - search_iters = tl.cast( - (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32) - - while p_pivot == 1.0: - p_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range - p_pivots_sum_0 = 0.0 - min_larger_0 = 1.0 - num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - - p_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range - p_pivots_sum_1 = 0.0 - min_larger_1 = 1.0 - num_min_larger_1 = tl.zeros((), dtype=tl.uint32) - - p_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + 
min_range - p_pivots_sum_2 = 0.0 - min_larger_2 = 1.0 - num_min_larger_2 = tl.zeros((), dtype=tl.uint32) - - - # First pass: Calculate p_pivots_sum and min_larger - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) - mask_n_2 = offs_n < search_range - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) - - p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) - masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) - min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) - - p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) - masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) - min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) - - p_pivots_sum_2 += tl.sum(probs_blk * (probs_blk > p_pivot_2)) - masked_larger_2 = tl.where(probs_blk > p_pivot_2, probs_blk, 1.0) - min_larger_2 = tl.minimum(min_larger_2, tl.min(masked_larger_2)) - - - # Second pass: Calculate num_min_larger - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) - mask_n_2 = offs_n < search_range - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) - - num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-15) - num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-15) - num_min_larger_2 += tl.sum(tl.abs(probs_blk - min_larger_2) < 1e-15) - - # Check if any of the pivots satisfy termination condition - if p_pivots_sum_2 >= p: - if p_pivots_sum_2 - (min_larger_2 * num_min_larger_2) < p: - p_pivot = p_pivot_2 - min_larger_prob = min_larger_2 - num_min_larger = num_min_larger_2 - p_pivots_sum = p_pivots_sum_2 - if p_pivots_sum_1 >= p: - if p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: - p_pivot = p_pivot_1 - min_larger_prob = min_larger_1 - num_min_larger = num_min_larger_1 - p_pivots_sum = p_pivots_sum_1 - if p_pivots_sum_0 >= p: - if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) 
< p: - p_pivot = p_pivot_0 - min_larger_prob = min_larger_0 - num_min_larger = num_min_larger_0 - p_pivots_sum = p_pivots_sum_0 - - # Update range - if p_pivots_sum_2 > p: - min_range = p_pivot_2 - elif p_pivots_sum_1 > p: - min_range = p_pivot_1 - elif p_pivots_sum_0 > p: - min_range = p_pivot_0 - - if p_pivots_sum_0 < p: - max_range = p_pivot_0 - elif p_pivots_sum_1 < p: - max_range = p_pivot_1 - elif p_pivots_sum_2 < p: - max_range = p_pivot_2 - - num_iters += 1 - if (max_range - min_range) < 1e-15 or num_iters >= 24: - p_pivot = (max_range + min_range) / 2.0 - else: - # Re-populate the buffer with full softmax probabilities + # Second pass: Calculate softmax and gather outliers for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE @@ -266,100 +575,219 @@ def _topk_topp_kernel( probs_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) probs_blk = tl.exp(probs_blk - max_sample) probs_blk = probs_blk / sum_exp_logits - tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) - - while p_pivot == 1.0: - p_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range - p_pivots_sum_0 = 0.0 - min_larger_0 = 1.0 - num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - - p_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range - p_pivots_sum_1 = 0.0 - min_larger_1 = 1.0 - num_min_larger_1 = tl.zeros((), dtype=tl.uint32) - - p_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range - p_pivots_sum_2 = 0.0 - min_larger_2 = 1.0 - num_min_larger_2 = tl.zeros((), dtype=tl.uint32) + outlier_mask = (probs_blk > outlier_prob) & mask_n + sum_outlier_probs += tl.sum(outlier_mask * probs_blk) + cumulative_pos = tl.cast( + tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + num_outliers += tl.sum(outlier_mask) + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(BUFFER_ROW + write_pos, probs_blk, mask=outlier_mask) + - # First pass: Calculate p_pivots_sum and min_larger + max_range = 
tl.exp(max_logit - max_sample) / sum_exp_logits + min_range = tl.exp(min_logit - max_sample) / sum_exp_logits + + p_pivot = 1.0 + num_iters = 0 + min_larger_prob = 1.0 + num_min_larger = tl.zeros((), dtype=tl.uint32) + p_pivots_sum = 0.0 + + # Third pass: Search for p_pivot + if sum_outlier_probs > p: + min_range = outlier_prob + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32) + + while p_pivot == 1.0: + p_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + p_pivots_sum_0 = 0.0 + min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + p_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + p_pivots_sum_1 = 0.0 + min_larger_1 = 1.0 + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + p_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range + p_pivots_sum_2 = 0.0 + min_larger_2 = 1.0 + num_min_larger_2 = tl.zeros((), dtype=tl.uint32) + + + # First pass: Calculate p_pivots_sum and min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + + p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) + masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) + + p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) + masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) + min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) + + p_pivots_sum_2 += tl.sum(probs_blk * (probs_blk > p_pivot_2)) + masked_larger_2 = tl.where(probs_blk > p_pivot_2, probs_blk, 1.0) + min_larger_2 = tl.minimum(min_larger_2, tl.min(masked_larger_2)) + + + # Second pass: Calculate num_min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 
= offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-15) + num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-15) + num_min_larger_2 += tl.sum(tl.abs(probs_blk - min_larger_2) < 1e-15) + + # Check if any of the pivots satisfy termination condition + if p_pivots_sum_2 >= p: + if p_pivots_sum_2 - (min_larger_2 * num_min_larger_2) < p: + p_pivot = p_pivot_2 + min_larger_prob = min_larger_2 + num_min_larger = num_min_larger_2 + p_pivots_sum = p_pivots_sum_2 + if p_pivots_sum_1 >= p: + if p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: + p_pivot = p_pivot_1 + min_larger_prob = min_larger_1 + num_min_larger = num_min_larger_1 + p_pivots_sum = p_pivots_sum_1 + if p_pivots_sum_0 >= p: + if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: + p_pivot = p_pivot_0 + min_larger_prob = min_larger_0 + num_min_larger = num_min_larger_0 + p_pivots_sum = p_pivots_sum_0 + + # Update range + if p_pivots_sum_2 > p: + min_range = p_pivot_2 + elif p_pivots_sum_1 > p: + min_range = p_pivot_1 + elif p_pivots_sum_0 > p: + min_range = p_pivot_0 + + if p_pivots_sum_0 < p: + max_range = p_pivot_0 + elif p_pivots_sum_1 < p: + max_range = p_pivot_1 + elif p_pivots_sum_2 < p: + max_range = p_pivot_2 + + num_iters += 1 + if (max_range - min_range) < 1e-15 or num_iters >= 24: + p_pivot = (max_range + min_range) / 2.0 + else: + # Re-populate the buffer with full softmax probabilities for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) - p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) - masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) - min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) + probs_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) + probs_blk = tl.exp(probs_blk - 
max_sample) + probs_blk = probs_blk / sum_exp_logits + tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) - p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) - masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) - min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) - - p_pivots_sum_2 += tl.sum(probs_blk * (probs_blk > p_pivot_2)) - masked_larger_2 = tl.where(probs_blk > p_pivot_2, probs_blk, 1.0) - min_larger_2 = tl.minimum(min_larger_2, tl.min(masked_larger_2)) - - - # Second pass: Calculate num_min_larger - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) - - num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-15) - num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-15) - num_min_larger_2 += tl.sum(tl.abs(probs_blk - min_larger_2) < 1e-15) - - # Check if any of the pivots satisfy termination condition - if p_pivots_sum_2 >= p: - if p_pivots_sum_2 - (min_larger_2 * num_min_larger_2) < p: - p_pivot = p_pivot_2 - min_larger_prob = min_larger_2 - num_min_larger = num_min_larger_2 - p_pivots_sum = p_pivots_sum_2 - if p_pivots_sum_1 >= p: - if p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: - p_pivot = p_pivot_1 - min_larger_prob = min_larger_1 - num_min_larger = num_min_larger_1 - p_pivots_sum = p_pivots_sum_1 - if p_pivots_sum_0 >= p: - if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: - p_pivot = p_pivot_0 - min_larger_prob = min_larger_0 - num_min_larger = num_min_larger_0 - p_pivots_sum = p_pivots_sum_0 - - # Update range - if p_pivots_sum_2 > p: - min_range = p_pivot_2 - elif p_pivots_sum_1 > p: - min_range = p_pivot_1 - elif p_pivots_sum_0 > p: - min_range = p_pivot_0 - - if p_pivots_sum_0 < p: - max_range = p_pivot_0 - elif p_pivots_sum_1 < p: - max_range = p_pivot_1 - elif p_pivots_sum_2 < p: - max_range = p_pivot_2 - - num_iters += 1 - if 
(max_range - min_range) < 1e-15 or num_iters >= 24: - p_pivot = (max_range + min_range) / 2.0 - - duplicate_logit = tl.log(min_larger_prob * sum_exp_logits) + max_logit - num_duplicate_logit = num_min_larger - num_keep = num_duplicate_logit - tl.cast((p_pivots_sum - p) / min_larger_prob, tl.uint32) - num_kept = tl.zeros((), dtype=tl.uint32) - - # Top-p only path - final_pivot = tl.log(p_pivot * sum_exp_logits) + max_sample + while p_pivot == 1.0: + p_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + p_pivots_sum_0 = 0.0 + min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + p_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + p_pivots_sum_1 = 0.0 + min_larger_1 = 1.0 + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + p_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range + p_pivots_sum_2 = 0.0 + min_larger_2 = 1.0 + num_min_larger_2 = tl.zeros((), dtype=tl.uint32) + + + # First pass: Calculate p_pivots_sum and min_larger + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) + + p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) + masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) + + p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) + masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) + min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) + + p_pivots_sum_2 += tl.sum(probs_blk * (probs_blk > p_pivot_2)) + masked_larger_2 = tl.where(probs_blk > p_pivot_2, probs_blk, 1.0) + min_larger_2 = tl.minimum(min_larger_2, tl.min(masked_larger_2)) + + + # Second pass: Calculate num_min_larger + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) 
+ + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-15) + num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-15) + num_min_larger_2 += tl.sum(tl.abs(probs_blk - min_larger_2) < 1e-15) + + # Check if any of the pivots satisfy termination condition + if p_pivots_sum_2 >= p: + if p_pivots_sum_2 - (min_larger_2 * num_min_larger_2) < p: + p_pivot = p_pivot_2 + min_larger_prob = min_larger_2 + num_min_larger = num_min_larger_2 + p_pivots_sum = p_pivots_sum_2 + if p_pivots_sum_1 >= p: + if p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: + p_pivot = p_pivot_1 + min_larger_prob = min_larger_1 + num_min_larger = num_min_larger_1 + p_pivots_sum = p_pivots_sum_1 + if p_pivots_sum_0 >= p: + if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: + p_pivot = p_pivot_0 + min_larger_prob = min_larger_0 + num_min_larger = num_min_larger_0 + p_pivots_sum = p_pivots_sum_0 + + # Update range + if p_pivots_sum_2 > p: + min_range = p_pivot_2 + elif p_pivots_sum_1 > p: + min_range = p_pivot_1 + elif p_pivots_sum_0 > p: + min_range = p_pivot_0 + + if p_pivots_sum_0 < p: + max_range = p_pivot_0 + elif p_pivots_sum_1 < p: + max_range = p_pivot_1 + elif p_pivots_sum_2 < p: + max_range = p_pivot_2 + + num_iters += 1 + if (max_range - min_range) < 1e-15 or num_iters >= 24: + p_pivot = (max_range + min_range) / 2.0 + + duplicate_logit = tl.log(min_larger_prob * sum_exp_logits) + max_logit + num_duplicate_logit = num_min_larger + num_keep = num_duplicate_logit - tl.cast((p_pivots_sum - p) / min_larger_prob, tl.uint32) + num_kept = tl.zeros((), dtype=tl.uint32) + + # Top-p only path + final_pivot = tl.log(p_pivot * sum_exp_logits) + max_sample # Sixth pass: Apply mask if final_pivot != -float("inf"): From 150ccc6d0daa1e3b1ef8831e15c344ab29639961 Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 1 Feb 2026 19:02:14 -0800 Subject: [PATCH 84/99] Restored tests Signed-off-by: js_park --- tests/v1/sample/test_topk_topp_sampler.py | 144 +++++++++++----------- 1 
file changed, 72 insertions(+), 72 deletions(-) diff --git a/tests/v1/sample/test_topk_topp_sampler.py b/tests/v1/sample/test_topk_topp_sampler.py index 5476cf1cf961..fc6412949dd9 100644 --- a/tests/v1/sample/test_topk_topp_sampler.py +++ b/tests/v1/sample/test_topk_topp_sampler.py @@ -202,19 +202,19 @@ def test_topk_only(self, batch_size: int, vocab_size: int): self._compare_results(logits, k, p=None) - # @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024]) - # @pytest.mark.parametrize("vocab_size", [1024, 32000, 128256]) - # def test_topp_only(self, batch_size: int, vocab_size: int): - # """Test top-p only (k=None).""" - # logits = torch.randn( - # batch_size, vocab_size, generator=self.generator, dtype=torch.float32 - # ) - # p = torch.rand(batch_size, generator=self.generator) * 0.9 + 0.1 # [0.1, 1.0] - # # Randomly disable top-p for some rows (~25%) - # disable_mask = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0 - # p.masked_fill_(disable_mask, 1.0) - - # self._compare_results(logits, k=None, p=p) + @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024]) + @pytest.mark.parametrize("vocab_size", [1024, 32000, 128256]) + def test_topp_only(self, batch_size: int, vocab_size: int): + """Test top-p only (k=None).""" + logits = torch.randn( + batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + ) + p = torch.rand(batch_size, generator=self.generator) * 0.9 + 0.1 # [0.1, 1.0] + # Randomly disable top-p for some rows (~25%) + disable_mask = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0 + p.masked_fill_(disable_mask, 1.0) + + self._compare_results(logits, k=None, p=p) @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024]) @pytest.mark.parametrize("vocab_size", [1024, 32000, 128256]) @@ -237,62 +237,62 @@ def test_topk_and_topp(self, batch_size: int, vocab_size: int): self._compare_results(logits, k, p) - # def test_both_disabled(self): - # """Test when both k and p are None 
(should be no-op).""" - # from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton - - # logits = torch.randn(32, 1024, generator=self.generator, dtype=torch.float32) - # logits_clone = logits.clone() - - # result = apply_top_k_top_p_triton(logits_clone, k=None, p=None) - - # assert torch.equal(result, logits), "Should be no-op when both k and p are None" - - # def test_extreme_k_values(self): - # """Test edge cases for k values.""" - # batch_size, vocab_size = 16, 1024 - # logits = torch.randn( - # batch_size, vocab_size, generator=self.generator, dtype=torch.float32 - # ) - - # # k=1 (keep only top 1) - # k = torch.ones(batch_size, dtype=torch.int32) - # self._compare_results(logits.clone(), k, p=None) - - # # k=vocab_size (keep all) - # k = torch.full((batch_size,), vocab_size, dtype=torch.int32) - # self._compare_results(logits.clone(), k, p=None) - - # # Mixed extreme values - # k = torch.tensor([1, vocab_size, 2, vocab_size - 1] * 4, dtype=torch.int32) - # self._compare_results(logits.clone(), k, p=None) - - # def test_extreme_p_values(self): - # """Test edge cases for p values.""" - # batch_size, vocab_size = 16, 1024 - # logits = torch.randn( - # batch_size, vocab_size, generator=self.generator, dtype=torch.float32 - # ) - - # # p close to 0 (very restrictive) - # p = torch.full((batch_size,), 0.01, dtype=torch.float32) - # self._compare_results(logits.clone(), k=None, p=p) - - # # p=1.0 (keep all) - # p = torch.ones(batch_size, dtype=torch.float32) - # self._compare_results(logits.clone(), k=None, p=p) - - # # Mixed values - # p = torch.tensor([0.1, 0.5, 0.9, 1.0] * 4, dtype=torch.float32) - # self._compare_results(logits.clone(), k=None, p=p) - - # def test_large_batch(self): - # """Test with a large batch size.""" - # batch_size, vocab_size = 512, 32000 - # logits = torch.randn( - # batch_size, vocab_size, generator=self.generator, dtype=torch.float32 - # ) - # k = torch.randint(1, 50, (batch_size,), generator=self.generator) - # p = 
torch.rand(batch_size, generator=self.generator) * 0.5 + 0.5 - - # self._compare_results(logits, k, p) + def test_both_disabled(self): + """Test when both k and p are None (should be no-op).""" + from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton + + logits = torch.randn(32, 1024, generator=self.generator, dtype=torch.float32) + logits_clone = logits.clone() + + result = apply_top_k_top_p_triton(logits_clone, k=None, p=None) + + assert torch.equal(result, logits), "Should be no-op when both k and p are None" + + def test_extreme_k_values(self): + """Test edge cases for k values.""" + batch_size, vocab_size = 16, 1024 + logits = torch.randn( + batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + ) + + # k=1 (keep only top 1) + k = torch.ones(batch_size, dtype=torch.int32) + self._compare_results(logits.clone(), k, p=None) + + # k=vocab_size (keep all) + k = torch.full((batch_size,), vocab_size, dtype=torch.int32) + self._compare_results(logits.clone(), k, p=None) + + # Mixed extreme values + k = torch.tensor([1, vocab_size, 2, vocab_size - 1] * 4, dtype=torch.int32) + self._compare_results(logits.clone(), k, p=None) + + def test_extreme_p_values(self): + """Test edge cases for p values.""" + batch_size, vocab_size = 16, 1024 + logits = torch.randn( + batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + ) + + # p close to 0 (very restrictive) + p = torch.full((batch_size,), 0.01, dtype=torch.float32) + self._compare_results(logits.clone(), k=None, p=p) + + # p=1.0 (keep all) + p = torch.ones(batch_size, dtype=torch.float32) + self._compare_results(logits.clone(), k=None, p=p) + + # Mixed values + p = torch.tensor([0.1, 0.5, 0.9, 1.0] * 4, dtype=torch.float32) + self._compare_results(logits.clone(), k=None, p=p) + + def test_large_batch(self): + """Test with a large batch size.""" + batch_size, vocab_size = 512, 32000 + logits = torch.randn( + batch_size, vocab_size, generator=self.generator, 
dtype=torch.float32 + ) + k = torch.randint(1, 50, (batch_size,), generator=self.generator) + p = torch.rand(batch_size, generator=self.generator) * 0.5 + 0.5 + + self._compare_results(logits, k, p) From ae08705ad19323afb686d5ccaf7b1b90f0a1d728 Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 1 Feb 2026 19:15:19 -0800 Subject: [PATCH 85/99] Bugfix Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_triton.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_triton.py b/vllm/v1/sample/ops/topk_topp_triton.py index b52404fddd53..5871b67f5226 100644 --- a/vllm/v1/sample/ops/topk_topp_triton.py +++ b/vllm/v1/sample/ops/topk_topp_triton.py @@ -102,8 +102,7 @@ def _topk_topp_kernel( if TOPK_ENABLED: k = tl.load(K + row_id) - if k != VOCAB_SIZE: - + if k < VOCAB_SIZE: # Zeroth pass: Compute avg and std from a sample block offs = tl.arange(0, BLOCK_SIZE) mask_n = offs < VOCAB_SIZE @@ -324,6 +323,9 @@ def _topk_topp_kernel( num_duplicate_logit = num_min_larger num_keep = num_duplicate_logit - (k_pivots_num - k) num_kept = tl.zeros((), dtype=tl.uint32) + + # Top-k only path + final_pivot = k_pivot if TOPP_ENABLED: #### TOP-P SAMPLING AFTER TOP-K #### @@ -524,10 +526,7 @@ def _topk_topp_kernel( # Top-k + Top-p path final_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit - else: - # Top-k only path - final_pivot = k_pivot - elif TOPP_ENABLED: + if TOPP_ENABLED and final_pivot == -float("inf"): #### STANDALONE TOP-P SAMPLING #### p = tl.load(P + row_id) if p < 1.0: From 49c3c39bc3383b0afc4e39ab79b722a46d5e32cf Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 1 Feb 2026 19:38:11 -0800 Subject: [PATCH 86/99] Loosened hyperparameters Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_triton.py | 168 +++++-------------------- 1 file changed, 33 insertions(+), 135 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_triton.py b/vllm/v1/sample/ops/topk_topp_triton.py index 5871b67f5226..675198e71cd4 100644 --- 
a/vllm/v1/sample/ops/topk_topp_triton.py +++ b/vllm/v1/sample/ops/topk_topp_triton.py @@ -117,7 +117,7 @@ def _topk_topp_kernel( percentile = tl.cast(k / VOCAB_SIZE * 200, tl.uint32) percentile = tl.minimum(percentile, 199) sigma = tl.load(PERCENTILE_TO_STD_TABLE + percentile) - sigma = sigma + tl.abs(sigma) * -0.25 + sigma = sigma + tl.abs(sigma) * -0.2 outlier_pivot = avg_logit + std_logit * sigma num_outliers = tl.zeros((), dtype=tl.uint32) @@ -151,21 +151,16 @@ def _topk_topp_kernel( (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32 ) while k_pivot == float("inf"): - k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + k_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) min_larger_0 = float("inf") num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + k_pivot_1 = (max_range - min_range) * 2.0 / 3.0 + min_range k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) min_larger_1 = float("inf") num_min_larger_1 = tl.zeros((), dtype=tl.uint32) - k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range - k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - min_larger_2 = float("inf") - num_min_larger_2 = tl.zeros((), dtype=tl.uint32) - # First pass: Calculate k_pivots_num and min_larger for i in range(0, search_iters): offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) @@ -176,11 +171,9 @@ def _topk_topp_kernel( k_pivots_num_0 += tl.sum(logits_blk2 > k_pivot_0) k_pivots_num_1 += tl.sum(logits_blk2 > k_pivot_1) - k_pivots_num_2 += tl.sum(logits_blk2 > k_pivot_2) min_larger_0 = tl.minimum(min_larger_0, tl.min(logits_blk2)) min_larger_1 = tl.minimum(min_larger_1, tl.min(logits_blk2)) - min_larger_2 = tl.minimum(min_larger_2, tl.min(logits_blk2)) # Second pass: Calculate num_min_larger for i in range(0, search_iters): @@ -190,9 +183,8 @@ def _topk_topp_kernel( BUFFER_ROW + offs_n, mask=mask_n_2, other=-float("inf") ) - 
num_min_larger_0 += tl.sum(tl.abs(logits_blk2 - min_larger_0) < 1e-15) - num_min_larger_1 += tl.sum(tl.abs(logits_blk2 - min_larger_1) < 1e-15) - num_min_larger_2 += tl.sum(tl.abs(logits_blk2 - min_larger_2) < 1e-15) + num_min_larger_0 += tl.sum(tl.abs(logits_blk2 - min_larger_0) < 1e-9) + num_min_larger_1 += tl.sum(tl.abs(logits_blk2 - min_larger_1) < 1e-9) # Check if any of the pivots satisfy termination condition if k_pivots_num_0 >= k: @@ -207,17 +199,9 @@ def _topk_topp_kernel( k_pivots_num = k_pivots_num_1 min_larger = min_larger_1 num_min_larger = num_min_larger_1 - if k_pivots_num_2 >= k: - if k_pivots_num_2 - (min_larger_2 * num_min_larger_2) < k: - k_pivot = k_pivot_2 - k_pivots_num = k_pivots_num_2 - min_larger = min_larger_2 - num_min_larger = num_min_larger_2 # Update range - if k_pivots_num_2 > k: - min_range = k_pivot_2 - elif k_pivots_num_1 > k: + if k_pivots_num_1 > k: min_range = k_pivot_1 elif k_pivots_num_0 > k: min_range = k_pivot_0 @@ -226,11 +210,8 @@ def _topk_topp_kernel( max_range = k_pivot_0 elif k_pivots_num_1 < k: max_range = k_pivot_1 - elif k_pivots_num_2 < k: - max_range = k_pivot_2 - num_iters += 1 - if num_iters >= 24 or tl.abs(min_range - max_range) < 1e-15: + if num_iters >= 18 or tl.abs(min_range - max_range) < 1e-9: k_pivot = (max_range + min_range) / 2.0 else: # If top-k outlier gathering failed, search whole logit space @@ -247,11 +228,6 @@ def _topk_topp_kernel( min_larger_1 = float("inf") num_min_larger_1 = tl.zeros((), dtype=tl.uint32) - k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range - k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - min_larger_2 = float("inf") - num_min_larger_2 = tl.zeros((), dtype=tl.uint32) - # First pass: Calculate k_pivots_num and min_larger for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) @@ -262,11 +238,9 @@ def _topk_topp_kernel( k_pivots_num_0 += tl.sum(logits_blk2 > k_pivot_0) k_pivots_num_1 += tl.sum(logits_blk2 > k_pivot_1) - k_pivots_num_2 += 
tl.sum(logits_blk2 > k_pivot_2) min_larger_0 = tl.minimum(min_larger_0, tl.min(logits_blk2)) min_larger_1 = tl.minimum(min_larger_1, tl.min(logits_blk2)) - min_larger_2 = tl.minimum(min_larger_2, tl.min(logits_blk2)) # Second pass: Calculate num_min_larger for i in range(0, NUM_TILES): @@ -276,9 +250,8 @@ def _topk_topp_kernel( LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") ) - num_min_larger_0 += tl.sum(tl.abs(logits_blk2 - min_larger_0) < 1e-15) - num_min_larger_1 += tl.sum(tl.abs(logits_blk2 - min_larger_1) < 1e-15) - num_min_larger_2 += tl.sum(tl.abs(logits_blk2 - min_larger_2) < 1e-15) + num_min_larger_0 += tl.sum(tl.abs(logits_blk2 - min_larger_0) < 1e-9) + num_min_larger_1 += tl.sum(tl.abs(logits_blk2 - min_larger_1) < 1e-9) # Check if any of the pivots satisfy termination condition if k_pivots_num_0 >= k: @@ -293,17 +266,9 @@ def _topk_topp_kernel( k_pivots_num = k_pivots_num_1 min_larger = min_larger_1 num_min_larger = num_min_larger_1 - if k_pivots_num_2 >= k: - if k_pivots_num_2 - (min_larger_2 * num_min_larger_2) < k: - k_pivot = k_pivot_2 - k_pivots_num = k_pivots_num_2 - min_larger = min_larger_2 - num_min_larger = num_min_larger_2 # Update range - if k_pivots_num_2 > k: - min_range = k_pivot_2 - elif k_pivots_num_1 > k: + if k_pivots_num_1 > k: min_range = k_pivot_1 elif k_pivots_num_0 > k: min_range = k_pivot_0 @@ -312,11 +277,9 @@ def _topk_topp_kernel( max_range = k_pivot_0 elif k_pivots_num_1 < k: max_range = k_pivot_1 - elif k_pivots_num_2 < k: - max_range = k_pivot_2 num_iters += 1 - if num_iters >= 24 or tl.abs(min_range - max_range) < 1e-15: + if num_iters >= 18 or tl.abs(min_range - max_range) < 1e-9: k_pivot = (max_range + min_range) / 2.0 duplicate_logit = min_larger @@ -353,7 +316,7 @@ def _topk_topp_kernel( # Duplicate logit handling if num_keep < num_duplicate_logit: - duplicate_mask = tl.abs(probs_blk - duplicate_logit) < 1e-15 + duplicate_mask = tl.abs(probs_blk - duplicate_logit) < 1e-9 duplicate_count = 
tl.cumsum(duplicate_mask) + num_kept duplicate_keep_mask = (duplicate_count <= num_keep) & duplicate_mask duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask @@ -392,7 +355,7 @@ def _topk_topp_kernel( outlier_mask = (probs_blk > min_logit) & mask_n # Duplicate logit handling - duplicate_mask = tl.abs(probs_blk - duplicate_logit) < 1e-15 + duplicate_mask = tl.abs(probs_blk - duplicate_logit) < 1e-9 duplicate_count = tl.cumsum(duplicate_mask) + num_kept duplicate_keep_mask = (duplicate_count <= num_keep) & duplicate_mask duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask @@ -435,22 +398,16 @@ def _topk_topp_kernel( # Fifth passes: Search for p_pivot while p_pivot == 1.0: - p_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + p_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range p_pivots_sum_0 = 0.0 min_larger_0 = 1.0 num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - p_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + p_pivot_1 = (max_range - min_range) * 2.0 / 3.0 + min_range p_pivots_sum_1 = 0.0 min_larger_1 = 1.0 num_min_larger_1 = tl.zeros((), dtype=tl.uint32) - p_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range - p_pivots_sum_2 = 0.0 - min_larger_2 = 1.0 - num_min_larger_2 = tl.zeros((), dtype=tl.uint32) - - # First pass: Calculate p_pivots_sum and min_larger for i in range(0, search_iters): offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) @@ -465,28 +422,16 @@ def _topk_topp_kernel( masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) - p_pivots_sum_2 += tl.sum(probs_blk * (probs_blk > p_pivot_2)) - masked_larger_2 = tl.where(probs_blk > p_pivot_2, probs_blk, 1.0) - min_larger_2 = tl.minimum(min_larger_2, tl.min(masked_larger_2)) - - # Second pass: Calculate num_min_larger for i in range(0, search_iters): offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) mask_n_2 = offs_n < search_range probs_blk = 
tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) - num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-15) - num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-15) - num_min_larger_2 += tl.sum(tl.abs(probs_blk - min_larger_2) < 1e-15) + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-9) + num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-9) # Check if any of the pivots satisfy termination condition - if p_pivots_sum_2 >= p: - if p_pivots_sum_2 - (min_larger_2 * num_min_larger_2) < p: - p_pivot = p_pivot_2 - min_larger_prob = min_larger_2 - num_min_larger = num_min_larger_2 - p_pivots_sum = p_pivots_sum_2 if p_pivots_sum_1 >= p: if p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: p_pivot = p_pivot_1 @@ -501,9 +446,7 @@ def _topk_topp_kernel( p_pivots_sum = p_pivots_sum_0 # Update range - if p_pivots_sum_2 > p: - min_range = p_pivot_2 - elif p_pivots_sum_1 > p: + if p_pivots_sum_1 > p: min_range = p_pivot_1 elif p_pivots_sum_0 > p: min_range = p_pivot_0 @@ -512,11 +455,9 @@ def _topk_topp_kernel( max_range = p_pivot_0 elif p_pivots_sum_1 < p: max_range = p_pivot_1 - elif p_pivots_sum_2 < p: - max_range = p_pivot_2 num_iters += 1 - if (max_range - min_range) < 1e-15 or num_iters >= 24: + if (max_range - min_range) < 1e-9 or num_iters >= 18: p_pivot = (max_range + min_range) / 2.0 duplicate_logit = tl.log(min_larger_prob * sum_exp_logits) + max_logit @@ -559,7 +500,7 @@ def _topk_topp_kernel( idx = tl.cast(p * 200, tl.int32) idx = tl.maximum(0, tl.minimum(idx, 199)) sigma = tl.load(NORMAL_CDF_TO_SIGMA_TABLE + idx) - sigma = sigma + tl.abs(sigma) * -0.25 + sigma = sigma + tl.abs(sigma) * -0.2 outlier_pivot = avg_logit + std_logit * sigma outlier_prob = tl.exp(outlier_pivot - max_sample) / sum_exp_logits @@ -601,22 +542,16 @@ def _topk_topp_kernel( (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32) while p_pivot == 1.0: - p_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + 
p_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range p_pivots_sum_0 = 0.0 min_larger_0 = 1.0 num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - p_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + p_pivot_1 = (max_range - min_range) * 2.0 / 3.0 + min_range p_pivots_sum_1 = 0.0 min_larger_1 = 1.0 num_min_larger_1 = tl.zeros((), dtype=tl.uint32) - p_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range - p_pivots_sum_2 = 0.0 - min_larger_2 = 1.0 - num_min_larger_2 = tl.zeros((), dtype=tl.uint32) - - # First pass: Calculate p_pivots_sum and min_larger for i in range(0, search_iters): offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) @@ -631,28 +566,16 @@ def _topk_topp_kernel( masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) - p_pivots_sum_2 += tl.sum(probs_blk * (probs_blk > p_pivot_2)) - masked_larger_2 = tl.where(probs_blk > p_pivot_2, probs_blk, 1.0) - min_larger_2 = tl.minimum(min_larger_2, tl.min(masked_larger_2)) - - # Second pass: Calculate num_min_larger for i in range(0, search_iters): offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) mask_n_2 = offs_n < search_range probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) - num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-15) - num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-15) - num_min_larger_2 += tl.sum(tl.abs(probs_blk - min_larger_2) < 1e-15) + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-9) + num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-9) # Check if any of the pivots satisfy termination condition - if p_pivots_sum_2 >= p: - if p_pivots_sum_2 - (min_larger_2 * num_min_larger_2) < p: - p_pivot = p_pivot_2 - min_larger_prob = min_larger_2 - num_min_larger = num_min_larger_2 - p_pivots_sum = p_pivots_sum_2 if p_pivots_sum_1 >= p: if p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: p_pivot 
= p_pivot_1 @@ -667,9 +590,7 @@ def _topk_topp_kernel( p_pivots_sum = p_pivots_sum_0 # Update range - if p_pivots_sum_2 > p: - min_range = p_pivot_2 - elif p_pivots_sum_1 > p: + if p_pivots_sum_1 > p: min_range = p_pivot_1 elif p_pivots_sum_0 > p: min_range = p_pivot_0 @@ -678,11 +599,9 @@ def _topk_topp_kernel( max_range = p_pivot_0 elif p_pivots_sum_1 < p: max_range = p_pivot_1 - elif p_pivots_sum_2 < p: - max_range = p_pivot_2 num_iters += 1 - if (max_range - min_range) < 1e-15 or num_iters >= 24: + if (max_range - min_range) < 1e-9 or num_iters >= 18: p_pivot = (max_range + min_range) / 2.0 else: # Re-populate the buffer with full softmax probabilities @@ -696,21 +615,16 @@ def _topk_topp_kernel( tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) while p_pivot == 1.0: - p_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + p_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range p_pivots_sum_0 = 0.0 min_larger_0 = 1.0 num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - p_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + p_pivot_1 = (max_range - min_range) * 2.0 / 3.0 + min_range p_pivots_sum_1 = 0.0 min_larger_1 = 1.0 num_min_larger_1 = tl.zeros((), dtype=tl.uint32) - p_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range - p_pivots_sum_2 = 0.0 - min_larger_2 = 1.0 - num_min_larger_2 = tl.zeros((), dtype=tl.uint32) - # First pass: Calculate p_pivots_sum and min_larger for i in range(0, NUM_TILES): @@ -726,28 +640,16 @@ def _topk_topp_kernel( masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) - p_pivots_sum_2 += tl.sum(probs_blk * (probs_blk > p_pivot_2)) - masked_larger_2 = tl.where(probs_blk > p_pivot_2, probs_blk, 1.0) - min_larger_2 = tl.minimum(min_larger_2, tl.min(masked_larger_2)) - - # Second pass: Calculate num_min_larger for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE probs_blk = 
tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) - num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-15) - num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-15) - num_min_larger_2 += tl.sum(tl.abs(probs_blk - min_larger_2) < 1e-15) + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-9) + num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-9) # Check if any of the pivots satisfy termination condition - if p_pivots_sum_2 >= p: - if p_pivots_sum_2 - (min_larger_2 * num_min_larger_2) < p: - p_pivot = p_pivot_2 - min_larger_prob = min_larger_2 - num_min_larger = num_min_larger_2 - p_pivots_sum = p_pivots_sum_2 if p_pivots_sum_1 >= p: if p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: p_pivot = p_pivot_1 @@ -762,9 +664,7 @@ def _topk_topp_kernel( p_pivots_sum = p_pivots_sum_0 # Update range - if p_pivots_sum_2 > p: - min_range = p_pivot_2 - elif p_pivots_sum_1 > p: + if p_pivots_sum_1 > p: min_range = p_pivot_1 elif p_pivots_sum_0 > p: min_range = p_pivot_0 @@ -773,11 +673,9 @@ def _topk_topp_kernel( max_range = p_pivot_0 elif p_pivots_sum_1 < p: max_range = p_pivot_1 - elif p_pivots_sum_2 < p: - max_range = p_pivot_2 num_iters += 1 - if (max_range - min_range) < 1e-15 or num_iters >= 24: + if (max_range - min_range) < 1e-9 or num_iters >= 18: p_pivot = (max_range + min_range) / 2.0 duplicate_logit = tl.log(min_larger_prob * sum_exp_logits) + max_logit @@ -798,7 +696,7 @@ def _topk_topp_kernel( # Duplicate logit handling if num_keep < num_duplicate_logit: - duplicate_mask = (tl.abs(logits_blk - duplicate_logit) < 1e-15) & mask_n + duplicate_mask = (tl.abs(logits_blk - duplicate_logit) < 1e-9) & mask_n duplicate_count = tl.cumsum(duplicate_mask) + num_kept duplicate_keep_mask = (duplicate_count <= num_duplicate_logit) & duplicate_mask duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask From 06565dfd8066416ff07138c528136bb6de5ef5d9 Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 1 Feb 2026 
20:08:20 -0800 Subject: [PATCH 87/99] Linter Signed-off-by: js_park --- benchmarks/benchmark_topk_topp.py | 6 +- vllm/v1/sample/ops/topk_topp_triton.py | 556 ++++++++++++++++--------- 2 files changed, 368 insertions(+), 194 deletions(-) diff --git a/benchmarks/benchmark_topk_topp.py b/benchmarks/benchmark_topk_topp.py index dae0458d01ee..93a4bd485316 100644 --- a/benchmarks/benchmark_topk_topp.py +++ b/benchmarks/benchmark_topk_topp.py @@ -21,7 +21,10 @@ import torch from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch -from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton +from vllm.v1.sample.ops.topk_topp_triton import ( + apply_top_k_top_p_triton, + reset_buffer_cache, +) @dataclass @@ -78,6 +81,7 @@ def measure_memory() -> tuple[int, int]: def reset_memory_stats(): """Reset peak memory statistics.""" + reset_buffer_cache() torch.cuda.reset_peak_memory_stats() torch.cuda.empty_cache() gc.collect() diff --git a/vllm/v1/sample/ops/topk_topp_triton.py b/vllm/v1/sample/ops/topk_topp_triton.py index 675198e71cd4..cb807f4489bf 100644 --- a/vllm/v1/sample/ops/topk_topp_triton.py +++ b/vllm/v1/sample/ops/topk_topp_triton.py @@ -3,7 +3,7 @@ """ Combined Top-K and Top-P Triton kernels. -Based on the paper "Qrita: High-performance Top-k and Top-p Algorithm for GPUs +Based on the paper "Qrita: High-performance Top-k and Top-p Algorithm for GPUs using Pivot-based Truncation and Selection" By Park et al. 
""" @@ -12,14 +12,14 @@ from vllm.triton_utils import tl, triton - _TRITON_TABLE_CACHE: dict[ - tuple[torch.device, torch.dtype], (torch.Tensor, torch.Tensor) + tuple[torch.device, torch.dtype], tuple[torch.Tensor, torch.Tensor] ] = {} _TRITON_BUFFER_CACHE: dict[ tuple[torch.device, torch.dtype, int, int], torch.Tensor -] = {} +] = {} +# fmt: off _NORMAL_CDF_TO_SIGMA_TABLE = [ 3.656, 3.650, 3.650, 3.650, 3.626, 3.626, 3.626, 3.514, 3.514, 3.503, 3.503, 3.434, 3.434, 3.428, 3.428, 3.387, 3.380, 3.380, 3.376, 3.373, @@ -67,16 +67,17 @@ ] # fmt: on + @triton.jit def _topk_topp_kernel( - LOGITS, - BUFFER, + LOGITS, + BUFFER, PERCENTILE_TO_STD_TABLE, NORMAL_CDF_TO_SIGMA_TABLE, K, - P, + P, BATCH_SIZE, - VOCAB_SIZE: tl.constexpr, + VOCAB_SIZE: tl.constexpr, MASK_VALUE: tl.constexpr, BLOCK_SIZE: tl.constexpr, BLOCK_SIZE_TRUNC: tl.constexpr, @@ -87,7 +88,6 @@ def _topk_topp_kernel( pid = tl.program_id(0) num_programs = tl.num_programs(0) for row_id in tl.range(pid, BATCH_SIZE, num_programs): - LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE BUFFER_ROW = BUFFER + pid * VOCAB_SIZE @@ -97,12 +97,12 @@ def _topk_topp_kernel( num_keep = tl.zeros((), dtype=tl.uint32) num_kept = tl.zeros((), dtype=tl.uint32) - max_logit = -float('inf') - min_logit = float('inf') + max_logit = -float("inf") + min_logit = float("inf") if TOPK_ENABLED: k = tl.load(K + row_id) - if k < VOCAB_SIZE: + if k < VOCAB_SIZE: # Zeroth pass: Compute avg and std from a sample block offs = tl.arange(0, BLOCK_SIZE) mask_n = offs < VOCAB_SIZE @@ -117,7 +117,7 @@ def _topk_topp_kernel( percentile = tl.cast(k / VOCAB_SIZE * 200, tl.uint32) percentile = tl.minimum(percentile, 199) sigma = tl.load(PERCENTILE_TO_STD_TABLE + percentile) - sigma = sigma + tl.abs(sigma) * -0.2 + sigma = sigma + tl.abs(sigma) * -0.15 outlier_pivot = avg_logit + std_logit * sigma num_outliers = tl.zeros((), dtype=tl.uint32) @@ -125,14 +125,17 @@ def _topk_topp_kernel( for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, 
BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=avg_logit) + logits_blk = tl.load( + LOGITS_ROW + offs_n, mask=mask_n, other=avg_logit + ) max_logit = tl.maximum(max_logit, tl.max(logits_blk)) min_logit = tl.minimum(min_logit, tl.min(logits_blk)) outlier_mask = (logits_blk > outlier_pivot) & mask_n cumulative_pos = tl.cast( - tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32 + ) num_outliers += tl.sum(outlier_mask) write_pos = tl.where(outlier_mask, cumulative_pos, -1) tl.store(BUFFER_ROW + write_pos, logits_blk, mask=outlier_mask) @@ -148,7 +151,8 @@ def _topk_topp_kernel( min_range = outlier_pivot search_range = tl.cast(num_outliers, tl.int32) search_iters = tl.cast( - (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32 + (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, + tl.int32, ) while k_pivot == float("inf"): k_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range @@ -163,7 +167,9 @@ def _topk_topp_kernel( # First pass: Calculate k_pivots_num and min_larger for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange( + 0, BLOCK_SIZE_TRUNC + ) mask_n_2 = offs_n < search_range logits_blk2 = tl.load( BUFFER_ROW + offs_n, mask=mask_n_2, other=-float("inf") @@ -177,29 +183,39 @@ def _topk_topp_kernel( # Second pass: Calculate num_min_larger for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange( + 0, BLOCK_SIZE_TRUNC + ) mask_n_2 = offs_n < search_range logits_blk2 = tl.load( BUFFER_ROW + offs_n, mask=mask_n_2, other=-float("inf") ) - num_min_larger_0 += tl.sum(tl.abs(logits_blk2 - min_larger_0) < 1e-9) - num_min_larger_1 += tl.sum(tl.abs(logits_blk2 - min_larger_1) < 1e-9) + num_min_larger_0 += tl.sum( + tl.abs(logits_blk2 - min_larger_0) < 1e-9 + ) + 
num_min_larger_1 += tl.sum( + tl.abs(logits_blk2 - min_larger_1) < 1e-9 + ) # Check if any of the pivots satisfy termination condition - if k_pivots_num_0 >= k: - if k_pivots_num_0 - (min_larger_0 * num_min_larger_0) < k: - k_pivot = k_pivot_0 - k_pivots_num = k_pivots_num_0 - min_larger = min_larger_0 - num_min_larger = num_min_larger_0 - if k_pivots_num_1 >= k: - if k_pivots_num_1 - (min_larger_1 * num_min_larger_1) < k: - k_pivot = k_pivot_1 - k_pivots_num = k_pivots_num_1 - min_larger = min_larger_1 - num_min_larger = num_min_larger_1 - + if ( + k_pivots_num_0 >= k + and k_pivots_num_0 - (min_larger_0 * num_min_larger_0) < k + ): + k_pivot = k_pivot_0 + k_pivots_num = k_pivots_num_0 + min_larger = min_larger_0 + num_min_larger = num_min_larger_0 + if ( + k_pivots_num_1 >= k + and k_pivots_num_1 - (min_larger_1 * num_min_larger_1) < k + ): + k_pivot = k_pivot_1 + k_pivots_num = k_pivots_num_1 + min_larger = min_larger_1 + num_min_larger = num_min_larger_1 + # Update range if k_pivots_num_1 > k: min_range = k_pivot_1 @@ -250,29 +266,37 @@ def _topk_topp_kernel( LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") ) - num_min_larger_0 += tl.sum(tl.abs(logits_blk2 - min_larger_0) < 1e-9) - num_min_larger_1 += tl.sum(tl.abs(logits_blk2 - min_larger_1) < 1e-9) + num_min_larger_0 += tl.sum( + tl.abs(logits_blk2 - min_larger_0) < 1e-9 + ) + num_min_larger_1 += tl.sum( + tl.abs(logits_blk2 - min_larger_1) < 1e-9 + ) # Check if any of the pivots satisfy termination condition - if k_pivots_num_0 >= k: - if k_pivots_num_0 - (min_larger_0 * num_min_larger_0) < k: - k_pivot = k_pivot_0 - k_pivots_num = k_pivots_num_0 - min_larger = min_larger_0 - num_min_larger = num_min_larger_0 - if k_pivots_num_1 >= k: - if k_pivots_num_1 - (min_larger_1 * num_min_larger_1) < k: - k_pivot = k_pivot_1 - k_pivots_num = k_pivots_num_1 - min_larger = min_larger_1 - num_min_larger = num_min_larger_1 + if ( + k_pivots_num_0 >= k + and k_pivots_num_0 - (min_larger_0 * num_min_larger_0) < k + ): 
+ k_pivot = k_pivot_0 + k_pivots_num = k_pivots_num_0 + min_larger = min_larger_0 + num_min_larger = num_min_larger_0 + if ( + k_pivots_num_1 >= k + and k_pivots_num_1 - (min_larger_1 * num_min_larger_1) < k + ): + k_pivot = k_pivot_1 + k_pivots_num = k_pivots_num_1 + min_larger = min_larger_1 + num_min_larger = num_min_larger_1 - # Update range + # Update range if k_pivots_num_1 > k: min_range = k_pivot_1 elif k_pivots_num_0 > k: min_range = k_pivot_0 - + if k_pivots_num_0 < k: max_range = k_pivot_0 elif k_pivots_num_1 < k: @@ -289,120 +313,162 @@ def _topk_topp_kernel( # Top-k only path final_pivot = k_pivot - + if TOPP_ENABLED: #### TOP-P SAMPLING AFTER TOP-K #### p = tl.load(P + row_id) if p < 1.0: - min_logit = k_pivot + min_logit = k_pivot sum_exp_logits = 0.0 num_outliers_2 = tl.zeros((), dtype=tl.uint32) search_range = tl.cast(num_outliers, tl.int32) search_iters = tl.cast( - (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32 + (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, + tl.int32, ) - # Third pass: Calculate exp logits and sum, gather top-k outliers + # Third pass: Calculate exp logits and sum, gather outliers if num_outliers > k: for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange( + 0, BLOCK_SIZE_TRUNC + ) mask_n_2 = offs_n < search_range - probs_blk = tl.load(BUFFER_ROW + offs_n, - mask=mask_n_2, - other=-float('inf')) - + probs_blk = tl.load( + BUFFER_ROW + offs_n, + mask=mask_n_2, + other=-float("inf"), + ) + outlier_mask = (probs_blk > min_logit) & mask_n_2 # Duplicate logit handling if num_keep < num_duplicate_logit: - duplicate_mask = tl.abs(probs_blk - duplicate_logit) < 1e-9 - duplicate_count = tl.cumsum(duplicate_mask) + num_kept - duplicate_keep_mask = (duplicate_count <= num_keep) & duplicate_mask - duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask - outlier_mask = outlier_mask & (~duplicate_remove_mask) + 
duplicate_mask = ( + tl.abs(probs_blk - duplicate_logit) < 1e-9 + ) + duplicate_count = ( + tl.cumsum(duplicate_mask) + num_kept + ) + duplicate_keep_mask = ( + duplicate_count <= num_keep + ) & duplicate_mask + duplicate_remove_mask = ( + duplicate_mask & ~duplicate_keep_mask + ) + outlier_mask = outlier_mask & ( + ~duplicate_remove_mask + ) num_kept += tl.sum(duplicate_keep_mask) - probs_blk = tl.where(outlier_mask, probs_blk, -float('inf')) + probs_blk = tl.where( + outlier_mask, probs_blk, -float("inf") + ) probs_blk = probs_blk - max_logit probs_blk = tl.exp(probs_blk) sum_exp_logits += tl.sum(probs_blk) - + # Fourth pass: Calculate BUFFER and get outliers for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange( + 0, BLOCK_SIZE_TRUNC + ) mask_n_2 = offs_n < search_range - probs_blk = tl.load(BUFFER_ROW + offs_n, - mask=mask_n_2, - other=-float('inf')) - + probs_blk = tl.load( + BUFFER_ROW + offs_n, + mask=mask_n_2, + other=-float("inf"), + ) + outlier_mask = (probs_blk > min_logit) & mask_n_2 probs_blk = probs_blk - max_logit probs_blk = tl.exp(probs_blk) probs_blk = probs_blk / sum_exp_logits tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n_2) else: - # If top-k outlier gathering failed, retry gathering using top-k pivot + # If top-k outlier gathering failed, + # retry gathering using top-k pivot for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE - probs_blk = tl.load(LOGITS_ROW + offs_n, - mask=mask_n, - other=-float('inf')) - + probs_blk = tl.load( + LOGITS_ROW + offs_n, + mask=mask_n, + other=-float("inf"), + ) + outlier_mask = (probs_blk > min_logit) & mask_n # Duplicate logit handling - duplicate_mask = tl.abs(probs_blk - duplicate_logit) < 1e-9 + duplicate_mask = ( + tl.abs(probs_blk - duplicate_logit) < 1e-9 + ) duplicate_count = tl.cumsum(duplicate_mask) + num_kept - duplicate_keep_mask = (duplicate_count 
<= num_keep) & duplicate_mask - duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask + duplicate_keep_mask = ( + duplicate_count <= num_keep + ) & duplicate_mask + duplicate_remove_mask = ( + duplicate_mask & ~duplicate_keep_mask + ) outlier_mask = outlier_mask & (~duplicate_remove_mask) num_kept += tl.sum(duplicate_keep_mask) - probs_blk = tl.where(outlier_mask, probs_blk, -float('inf')) + probs_blk = tl.where( + outlier_mask, probs_blk, -float("inf") + ) probs_blk = probs_blk - max_logit probs_blk = tl.exp(probs_blk) sum_exp_logits += tl.sum(probs_blk) - + cumulative_pos = tl.cast( - tl.cumsum(outlier_mask) - 1 + num_outliers_2, tl.int32) + tl.cumsum(outlier_mask) - 1 + num_outliers_2, + tl.int32, + ) num_outliers_2 += tl.sum(outlier_mask) write_pos = tl.where(outlier_mask, cumulative_pos, -1) - tl.store(BUFFER_ROW + write_pos, probs_blk, mask=outlier_mask) - + tl.store( + BUFFER_ROW + write_pos, probs_blk, mask=outlier_mask + ) + search_range = tl.cast(num_outliers_2, tl.int32) search_iters = tl.cast( - (num_outliers_2 + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32) + (num_outliers_2 + BLOCK_SIZE_TRUNC - 1) + // BLOCK_SIZE_TRUNC, + tl.int32, + ) # Fourth pass: Calculate BUFFER and get outliers for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange( + 0, BLOCK_SIZE_TRUNC + ) mask_n_2 = offs_n < search_range - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + probs_blk = tl.load( + BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0 + ) probs_blk = probs_blk / sum_exp_logits tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n_2) - max_range = tl.exp(max_logit - max_logit) / sum_exp_logits min_range = tl.exp(min_logit - max_logit) / sum_exp_logits - + p_pivot = 1.0 num_iters = 0 min_larger_prob = 1.0 num_min_larger = tl.zeros((), dtype=tl.uint32) p_pivots_sum = 0.0 - + # Fifth passes: Search for p_pivot while p_pivot == 1.0: p_pivot_0 = (max_range - 
min_range) * 1.0 / 3.0 + min_range p_pivots_sum_0 = 0.0 min_larger_0 = 1.0 num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - + p_pivot_1 = (max_range - min_range) * 2.0 / 3.0 + min_range p_pivots_sum_1 = 0.0 min_larger_1 = 1.0 @@ -410,41 +476,67 @@ def _topk_topp_kernel( # First pass: Calculate p_pivots_sum and min_larger for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange( + 0, BLOCK_SIZE_TRUNC + ) mask_n_2 = offs_n < search_range - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) - - p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) - masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) - min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) - - p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) - masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) - min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) + probs_blk = tl.load( + BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0 + ) + + p_pivots_sum_0 += tl.sum( + probs_blk * (probs_blk > p_pivot_0) + ) + masked_larger_0 = tl.where( + probs_blk > p_pivot_0, probs_blk, 1.0 + ) + min_larger_0 = tl.minimum( + min_larger_0, tl.min(masked_larger_0) + ) + + p_pivots_sum_1 += tl.sum( + probs_blk * (probs_blk > p_pivot_1) + ) + masked_larger_1 = tl.where( + probs_blk > p_pivot_1, probs_blk, 1.0 + ) + min_larger_1 = tl.minimum( + min_larger_1, tl.min(masked_larger_1) + ) # Second pass: Calculate num_min_larger for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange( + 0, BLOCK_SIZE_TRUNC + ) mask_n_2 = offs_n < search_range - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + probs_blk = tl.load( + BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0 + ) - num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-9) - num_min_larger_1 += 
tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-9) + num_min_larger_0 += tl.sum( + tl.abs(probs_blk - min_larger_0) < 1e-9 + ) + num_min_larger_1 += tl.sum( + tl.abs(probs_blk - min_larger_1) < 1e-9 + ) # Check if any of the pivots satisfy termination condition - if p_pivots_sum_1 >= p: - if p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: - p_pivot = p_pivot_1 - min_larger_prob = min_larger_1 - num_min_larger = num_min_larger_1 - p_pivots_sum = p_pivots_sum_1 - if p_pivots_sum_0 >= p: - if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: - p_pivot = p_pivot_0 - min_larger_prob = min_larger_0 - num_min_larger = num_min_larger_0 - p_pivots_sum = p_pivots_sum_0 - + if p_pivots_sum_1 >= p and ( + p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p + ): + p_pivot = p_pivot_1 + min_larger_prob = min_larger_1 + num_min_larger = num_min_larger_1 + p_pivots_sum = p_pivots_sum_1 + if p_pivots_sum_0 >= p and ( + p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p + ): + p_pivot = p_pivot_0 + min_larger_prob = min_larger_0 + num_min_larger = num_min_larger_0 + p_pivots_sum = p_pivots_sum_0 + # Update range if p_pivots_sum_1 > p: min_range = p_pivot_1 @@ -460,9 +552,13 @@ def _topk_topp_kernel( if (max_range - min_range) < 1e-9 or num_iters >= 18: p_pivot = (max_range + min_range) / 2.0 - duplicate_logit = tl.log(min_larger_prob * sum_exp_logits) + max_logit + duplicate_logit = ( + tl.log(min_larger_prob * sum_exp_logits) + max_logit + ) num_duplicate_logit = num_min_larger - num_keep = num_duplicate_logit - tl.cast((p_pivots_sum - p) / min_larger_prob, tl.uint32) + num_keep = num_duplicate_logit - tl.cast( + (p_pivots_sum - p) / min_larger_prob, tl.uint32 + ) num_kept = tl.zeros((), dtype=tl.uint32) # Top-k + Top-p path @@ -487,12 +583,12 @@ def _topk_topp_kernel( for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load(LOGITS_ROW + offs_n, - mask=mask_n, - other=avg_logit) + 
logits_blk = tl.load( + LOGITS_ROW + offs_n, mask=mask_n, other=avg_logit + ) max_logit = tl.maximum(max_logit, tl.max(logits_blk)) min_logit = tl.minimum(min_logit, tl.min(logits_blk)) - + probs_blk = tl.exp(logits_blk - max_sample) probs_blk = tl.where(mask_n, probs_blk, 0.0) sum_exp_logits += tl.sum(probs_blk) @@ -500,7 +596,7 @@ def _topk_topp_kernel( idx = tl.cast(p * 200, tl.int32) idx = tl.maximum(0, tl.minimum(idx, 199)) sigma = tl.load(NORMAL_CDF_TO_SIGMA_TABLE + idx) - sigma = sigma + tl.abs(sigma) * -0.2 + sigma = sigma + tl.abs(sigma) * -0.25 outlier_pivot = avg_logit + std_logit * sigma outlier_prob = tl.exp(outlier_pivot - max_sample) / sum_exp_logits @@ -512,41 +608,45 @@ def _topk_topp_kernel( offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE - probs_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) + probs_blk = tl.load( + LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") + ) probs_blk = tl.exp(probs_blk - max_sample) probs_blk = probs_blk / sum_exp_logits outlier_mask = (probs_blk > outlier_prob) & mask_n sum_outlier_probs += tl.sum(outlier_mask * probs_blk) cumulative_pos = tl.cast( - tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32 + ) num_outliers += tl.sum(outlier_mask) write_pos = tl.where(outlier_mask, cumulative_pos, -1) tl.store(BUFFER_ROW + write_pos, probs_blk, mask=outlier_mask) - max_range = tl.exp(max_logit - max_sample) / sum_exp_logits min_range = tl.exp(min_logit - max_sample) / sum_exp_logits - + p_pivot = 1.0 num_iters = 0 min_larger_prob = 1.0 num_min_larger = tl.zeros((), dtype=tl.uint32) p_pivots_sum = 0.0 - + # Third pass: Search for p_pivot if sum_outlier_probs > p: min_range = outlier_prob search_range = tl.cast(num_outliers, tl.int32) search_iters = tl.cast( - (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32) + (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, + tl.int32, + ) while 
p_pivot == 1.0: p_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range p_pivots_sum_0 = 0.0 min_larger_0 = 1.0 num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - + p_pivot_1 = (max_range - min_range) * 2.0 / 3.0 + min_range p_pivots_sum_1 = 0.0 min_larger_1 = 1.0 @@ -554,41 +654,69 @@ def _topk_topp_kernel( # First pass: Calculate p_pivots_sum and min_larger for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange( + 0, BLOCK_SIZE_TRUNC + ) mask_n_2 = offs_n < search_range - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + probs_blk = tl.load( + BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0 + ) - p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) - masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) - min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) - - p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) - masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) - min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) + p_pivots_sum_0 += tl.sum( + probs_blk * (probs_blk > p_pivot_0) + ) + masked_larger_0 = tl.where( + probs_blk > p_pivot_0, probs_blk, 1.0 + ) + min_larger_0 = tl.minimum( + min_larger_0, tl.min(masked_larger_0) + ) + + p_pivots_sum_1 += tl.sum( + probs_blk * (probs_blk > p_pivot_1) + ) + masked_larger_1 = tl.where( + probs_blk > p_pivot_1, probs_blk, 1.0 + ) + min_larger_1 = tl.minimum( + min_larger_1, tl.min(masked_larger_1) + ) # Second pass: Calculate num_min_larger for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange( + 0, BLOCK_SIZE_TRUNC + ) mask_n_2 = offs_n < search_range - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + probs_blk = tl.load( + BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0 + ) - num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 
1e-9) - num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-9) + num_min_larger_0 += tl.sum( + tl.abs(probs_blk - min_larger_0) < 1e-9 + ) + num_min_larger_1 += tl.sum( + tl.abs(probs_blk - min_larger_1) < 1e-9 + ) # Check if any of the pivots satisfy termination condition - if p_pivots_sum_1 >= p: - if p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: - p_pivot = p_pivot_1 - min_larger_prob = min_larger_1 - num_min_larger = num_min_larger_1 - p_pivots_sum = p_pivots_sum_1 - if p_pivots_sum_0 >= p: - if p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: - p_pivot = p_pivot_0 - min_larger_prob = min_larger_0 - num_min_larger = num_min_larger_0 - p_pivots_sum = p_pivots_sum_0 - + if ( + p_pivots_sum_1 >= p + and p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p + ): + p_pivot = p_pivot_1 + min_larger_prob = min_larger_1 + num_min_larger = num_min_larger_1 + p_pivots_sum = p_pivots_sum_1 + if ( + p_pivots_sum_0 >= p + and p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p + ): + p_pivot = p_pivot_0 + min_larger_prob = min_larger_0 + num_min_larger = num_min_larger_0 + p_pivots_sum = p_pivots_sum_0 + # Update range if p_pivots_sum_1 > p: min_range = p_pivot_1 @@ -609,60 +737,85 @@ def _topk_topp_kernel( offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE - probs_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) + probs_blk = tl.load( + LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") + ) probs_blk = tl.exp(probs_blk - max_sample) probs_blk = probs_blk / sum_exp_logits tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) - + while p_pivot == 1.0: p_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range p_pivots_sum_0 = 0.0 min_larger_0 = 1.0 num_min_larger_0 = tl.zeros((), dtype=tl.uint32) - + p_pivot_1 = (max_range - min_range) * 2.0 / 3.0 + min_range p_pivots_sum_1 = 0.0 min_larger_1 = 1.0 num_min_larger_1 = tl.zeros((), dtype=tl.uint32) - # First pass: Calculate p_pivots_sum and 
min_larger for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) + probs_blk = tl.load( + BUFFER_ROW + offs_n, mask=mask_n, other=0.0 + ) - p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) - masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) - min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) - - p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) - masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) - min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) + p_pivots_sum_0 += tl.sum( + probs_blk * (probs_blk > p_pivot_0) + ) + masked_larger_0 = tl.where( + probs_blk > p_pivot_0, probs_blk, 1.0 + ) + min_larger_0 = tl.minimum( + min_larger_0, tl.min(masked_larger_0) + ) + + p_pivots_sum_1 += tl.sum( + probs_blk * (probs_blk > p_pivot_1) + ) + masked_larger_1 = tl.where( + probs_blk > p_pivot_1, probs_blk, 1.0 + ) + min_larger_1 = tl.minimum( + min_larger_1, tl.min(masked_larger_1) + ) # Second pass: Calculate num_min_larger for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) + probs_blk = tl.load( + BUFFER_ROW + offs_n, mask=mask_n, other=0.0 + ) - num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-9) - num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-9) + num_min_larger_0 += tl.sum( + tl.abs(probs_blk - min_larger_0) < 1e-9 + ) + num_min_larger_1 += tl.sum( + tl.abs(probs_blk - min_larger_1) < 1e-9 + ) # Check if any of the pivots satisfy termination condition - if p_pivots_sum_1 >= p: - if p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: - p_pivot = p_pivot_1 - min_larger_prob = min_larger_1 - num_min_larger = num_min_larger_1 - p_pivots_sum = p_pivots_sum_1 - if p_pivots_sum_0 >= p: - if p_pivots_sum_0 - 
(min_larger_0 * num_min_larger_0) < p: - p_pivot = p_pivot_0 - min_larger_prob = min_larger_0 - num_min_larger = num_min_larger_0 - p_pivots_sum = p_pivots_sum_0 - + if ( + p_pivots_sum_1 >= p + and p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p + ): + p_pivot = p_pivot_1 + min_larger_prob = min_larger_1 + num_min_larger = num_min_larger_1 + p_pivots_sum = p_pivots_sum_1 + if ( + p_pivots_sum_0 >= p + and p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p + ): + p_pivot = p_pivot_0 + min_larger_prob = min_larger_0 + num_min_larger = num_min_larger_0 + p_pivots_sum = p_pivots_sum_0 + # Update range if p_pivots_sum_1 > p: min_range = p_pivot_1 @@ -680,9 +833,11 @@ def _topk_topp_kernel( duplicate_logit = tl.log(min_larger_prob * sum_exp_logits) + max_logit num_duplicate_logit = num_min_larger - num_keep = num_duplicate_logit - tl.cast((p_pivots_sum - p) / min_larger_prob, tl.uint32) + num_keep = num_duplicate_logit - tl.cast( + (p_pivots_sum - p) / min_larger_prob, tl.uint32 + ) num_kept = tl.zeros((), dtype=tl.uint32) - + # Top-p only path final_pivot = tl.log(p_pivot * sum_exp_logits) + max_sample @@ -691,21 +846,28 @@ def _topk_topp_kernel( for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float('inf')) + logits_blk = tl.load( + LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") + ) keep_mask = (logits_blk > final_pivot) & mask_n # Duplicate logit handling if num_keep < num_duplicate_logit: - duplicate_mask = (tl.abs(logits_blk - duplicate_logit) < 1e-9) & mask_n + duplicate_mask = ( + tl.abs(logits_blk - duplicate_logit) < 1e-9 + ) & mask_n duplicate_count = tl.cumsum(duplicate_mask) + num_kept - duplicate_keep_mask = (duplicate_count <= num_duplicate_logit) & duplicate_mask + duplicate_keep_mask = ( + duplicate_count <= num_duplicate_logit + ) & duplicate_mask duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask 
num_kept += tl.sum(duplicate_keep_mask) keep_mask = keep_mask & (~duplicate_remove_mask) - + logits_blk = tl.where(keep_mask, logits_blk, MASK_VALUE) tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) + def apply_top_k_top_p_triton( logits: torch.Tensor, k: torch.Tensor | None, @@ -775,11 +937,13 @@ def apply_top_k_top_p_triton( percentile_to_std_table = torch.tensor( _PERCENTILE_TO_STD_TABLE, device=logits.device, dtype=torch.float32 ) - _TRITON_TABLE_CACHE[tbl_key] = (normal_cdf_to_sigma_table, percentile_to_std_table) + _TRITON_TABLE_CACHE[tbl_key] = ( + normal_cdf_to_sigma_table, + percentile_to_std_table, + ) else: normal_cdf_to_sigma_table, percentile_to_std_table = tables - _topk_topp_kernel[(NUM_PROGRAMS,)]( logits, buffer, @@ -797,3 +961,9 @@ def apply_top_k_top_p_triton( ) return logits + + +def reset_buffer_cache(): + _TRITON_BUFFER_CACHE.clear() + _TRITON_TABLE_CACHE.clear() + torch.cuda.empty_cache() From acd99d713097ebb1e4b7f9babe7c33b539b322d0 Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 1 Feb 2026 20:36:05 -0800 Subject: [PATCH 88/99] Restore Signed-off-by: js_park --- vllm/envs.py | 637 ++++++++++++--------- vllm/v1/sample/ops/topk_topp_sampler.py | 700 +----------------------- 2 files changed, 420 insertions(+), 917 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index d0e50e979396..741a2163c91f 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -2,11 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools -import hashlib import json +import logging import os import sys import tempfile +import uuid from collections.abc import Callable from typing import TYPE_CHECKING, Any, Literal @@ -19,11 +20,10 @@ VLLM_NCCL_SO_PATH: str | None = None LD_LIBRARY_PATH: str | None = None VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE: int = 256 - VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False - VLLM_FLASH_ATTN_VERSION: int | None = None LOCAL_RANK: int = 0 CUDA_VISIBLE_DEVICES: str | None = None 
VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60 + VLLM_ENGINE_READY_TIMEOUT_S: int = 600 VLLM_API_KEY: str | None = None VLLM_DEBUG_LOG_API_SERVER_RESPONSE: bool = False S3_ACCESS_KEY_ID: str | None = None @@ -34,28 +34,26 @@ VLLM_CONFIG_ROOT: str = os.path.expanduser("~/.config/vllm") VLLM_USAGE_STATS_SERVER: str = "https://stats.vllm.ai" VLLM_NO_USAGE_STATS: bool = False - VLLM_DISABLE_FLASHINFER_PREFILL: bool = False VLLM_DO_NOT_TRACK: bool = False VLLM_USAGE_SOURCE: str = "" - VLLM_CONFIGURE_LOGGING: int = 1 + VLLM_CONFIGURE_LOGGING: bool = True VLLM_LOGGING_LEVEL: str = "INFO" VLLM_LOGGING_PREFIX: str = "" VLLM_LOGGING_STREAM: str = "ext://sys.stdout" VLLM_LOGGING_CONFIG_PATH: str | None = None + VLLM_LOGGING_COLOR: str = "auto" + NO_COLOR: bool = False VLLM_LOG_STATS_INTERVAL: float = 10.0 VLLM_TRACE_FUNCTION: int = 0 - VLLM_ATTENTION_BACKEND: str | None = None VLLM_USE_FLASHINFER_SAMPLER: bool | None = None - VLLM_USE_TRITON_SAMPLER: bool | None = None VLLM_PP_LAYER_PARTITION: str | None = None VLLM_CPU_KVCACHE_SPACE: int | None = 0 VLLM_CPU_OMP_THREADS_BIND: str = "" VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None - VLLM_CPU_MOE_PREPACK: bool = True VLLM_CPU_SGL_KERNEL: bool = False VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") VLLM_XLA_CHECK_RECOMPILATION: bool = False - VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024 + VLLM_FUSED_MOE_CHUNK_SIZE: int = 16 * 1024 VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING: bool = True VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto" VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False @@ -72,14 +70,15 @@ VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25 VLLM_VIDEO_LOADER_BACKEND: str = "opencv" VLLM_MEDIA_CONNECTOR: str = "http" - VLLM_MM_INPUT_CACHE_GIB: int = 4 + VLLM_MM_HASHER_ALGORITHM: str = "blake3" VLLM_TARGET_DEVICE: str = "cuda" - VLLM_MAIN_CUDA_VERSION: str = "12.8" + VLLM_MAIN_CUDA_VERSION: str = "12.9" + VLLM_FLOAT32_MATMUL_PRECISION: Literal["highest", "high", "medium"] = 
"highest" MAX_JOBS: str | None = None NVCC_THREADS: str | None = None VLLM_USE_PRECOMPILED: bool = False + VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX: bool = False VLLM_DOCKER_BUILD_CONTEXT: bool = False - VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL: bool = False VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False CMAKE_BUILD_TYPE: Literal["Debug", "Release", "RelWithDebInfo"] | None = None VERBOSE: bool = False @@ -88,15 +87,25 @@ VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5 # seconds VLLM_PLUGINS: list[str] | None = None VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None - VLLM_TORCH_CUDA_PROFILE: bool = False + VLLM_LORA_RESOLVER_HF_REPO_LIST: str | None = None + # Deprecated env variables for profiling, kept for backward compatibility + # See also vllm/config/profiler.py and `--profiler-config` argument + VLLM_TORCH_CUDA_PROFILE: str | None = None VLLM_TORCH_PROFILER_DIR: str | None = None - VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False - VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False + VLLM_TORCH_PROFILER_RECORD_SHAPES: str | None = None + VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: str | None = None + VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM: str | None = None + VLLM_TORCH_PROFILER_WITH_STACK: str | None = None + VLLM_TORCH_PROFILER_WITH_FLOPS: str | None = None + VLLM_TORCH_PROFILER_USE_GZIP: str | None = None + VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL: str | None = None + VLLM_PROFILER_DELAY_ITERS: str | None = None + VLLM_PROFILER_MAX_ITERS: str | None = None + # End of deprecated env variables for profiling VLLM_USE_AOT_COMPILE: bool = False VLLM_USE_BYTECODE_HOOK: bool = False VLLM_FORCE_AOT_LOAD: bool = False - VLLM_TORCH_PROFILER_WITH_STACK: bool = True - VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False + VLLM_USE_MEGA_AOT_ARTIFACT: bool = False VLLM_USE_TRITON_AWQ: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_SKIP_P2P_CHECK: bool = False @@ -112,13 +121,15 @@ VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False 
VLLM_ROCM_USE_AITER_FP8BMM: bool = True + VLLM_ROCM_USE_AITER_FP4BMM: bool = True VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False - VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True + VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = False VLLM_ROCM_USE_AITER_TRITON_GEMM: bool = True VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True + VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT: bool = False VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_DISABLE_COMPILE_CACHE: bool = False @@ -128,7 +139,6 @@ VLLM_SERVER_DEV_MODE: bool = False VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128 VLLM_MLA_DISABLE: bool = False - VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: int = 32 VLLM_RAY_PER_WORKER_GPUS: float = 1.0 VLLM_RAY_BUNDLE_INDICES: str = "" VLLM_CUDART_SO_PATH: str | None = None @@ -139,10 +149,13 @@ VLLM_DP_MASTER_IP: str = "" VLLM_DP_MASTER_PORT: int = 0 VLLM_MOE_DP_CHUNK_SIZE: int = 256 + VLLM_ENABLE_MOE_DP_CHUNK: bool = True VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False VLLM_RAY_DP_PACK_STRATEGY: Literal["strict", "fill", "span"] = "strict" VLLM_MARLIN_USE_ATOMIC_ADD: bool = False + VLLM_MARLIN_INPUT_DTYPE: Literal["int8", "fp8"] | None = None VLLM_MXFP4_USE_MARLIN: bool | None = None + VLLM_DEEPEPLL_NVFP4_DISPATCH: bool = False VLLM_V1_USE_OUTLINES_CACHE: bool = False VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_TPU_MOST_MODEL_LEN: int | None = None @@ -150,27 +163,34 @@ VLLM_USE_DEEP_GEMM: bool = True VLLM_MOE_USE_DEEP_GEMM: bool = True VLLM_USE_DEEP_GEMM_E8M0: bool = True + VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES: bool = True VLLM_DEEP_GEMM_WARMUP: Literal[ "skip", "full", "relax", ] = "relax" VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True + VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER: bool = False VLLM_USE_FLASHINFER_MOE_FP16: bool = False VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False - 
VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "latency" + VLLM_USE_FLASHINFER_MOE_INT4: bool = False + VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = ( + "latency" + ) VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024 VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost" VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5600 + VLLM_MOONCAKE_BOOTSTRAP_PORT: int = 8998 VLLM_ALL2ALL_BACKEND: Literal[ "naive", "pplx", "deepep_high_throughput", "deepep_low_latency", + "mori", "allgather_reducescatter", "flashinfer_all2allv", ] = "allgather_reducescatter" @@ -188,15 +208,16 @@ VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480 - VLLM_USE_CUDNN_PREFILL: bool = False - VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False + VLLM_MORIIO_CONNECTOR_READ_MODE: bool = False + VLLM_MORIIO_QP_PER_TRANSFER: int = 1 + VLLM_MORIIO_POST_BATCH_SIZE: int = -1 + VLLM_MORIIO_NUM_WORKERS: int = 1 + VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT: int = 480 VLLM_ENABLE_CUDAGRAPH_GC: bool = False VLLM_LOOPBACK_IP: str = "" - VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False + VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = True VLLM_ENABLE_RESPONSES_API_STORE: bool = False - VLLM_USE_TRTLLM_ATTENTION: str | None = None VLLM_NVFP4_GEMM_BACKEND: str | None = None - VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION: bool = False VLLM_HAS_FLASHINFER_CUBIN: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False @@ -205,6 +226,7 @@ VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True VLLM_TUNED_CONFIG_FOLDER: str | None = None VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS: set[str] = set() + VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT: bool = False VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = 
False VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY: bool = False VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False @@ -223,10 +245,15 @@ VLLM_NCCL_INCLUDE_PATH: str | None = None VLLM_USE_FBGEMM: bool = False VLLM_GC_DEBUG: str = "" + VLLM_DEBUG_WORKSPACE: bool = False VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" - VLLM_FLAT_LOGPROBS: bool = False + VLLM_USE_V2_MODEL_RUNNER: bool = False + VLLM_LOG_MODEL_INSPECTION: bool = False + VLLM_DEBUG_MFU_METRICS: bool = False + VLLM_DISABLE_LOG_LOGO: bool = False + VLLM_LORA_DISABLE_PDL: bool = False def get_default_cache_root(): @@ -412,9 +439,9 @@ def get_vllm_port() -> int | None: try: return int(port) except ValueError as err: - from urllib.parse import urlparse + from urllib3.util import parse_url - parsed = urlparse(port) + parsed = parse_url(port) if parsed.scheme: raise ValueError( f"VLLM_PORT '{port}' appears to be a URI. " @@ -424,20 +451,50 @@ def get_vllm_port() -> int | None: raise ValueError(f"VLLM_PORT '{port}' must be a valid integer") from err +def get_env_or_set_default( + env_name: str, + default_factory: Callable[[], str], +) -> Callable[[], str]: + """ + Create a lambda that returns an environment variable value if set, + or generates and sets a default value using the provided factory function. + """ + + def _get_or_set_default() -> str: + value = os.getenv(env_name) + if value is not None: + return value + + default_value = default_factory() + os.environ[env_name] = default_value + return default_value + + return _get_or_set_default + + # The start-* and end* here are used by the documentation generator # to extract the used env vars. 
# --8<-- [start:env-vars-definition] +logger = logging.getLogger(__name__) + environment_variables: dict[str, Callable[[], Any]] = { # ================== Installation Time Env Vars ================== # Target device of vLLM, supporting [cuda (by default), # rocm, cpu] "VLLM_TARGET_DEVICE": lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda").lower(), - # Main CUDA version of vLLM, supporting [12.6, 12.8, 12.9], - # 12.8 is the default. This follows PyTorch but can be overridden. + # Main CUDA version of vLLM. This follows PyTorch but can be overridden. "VLLM_MAIN_CUDA_VERSION": lambda: os.getenv("VLLM_MAIN_CUDA_VERSION", "").lower() - or "12.8", + or "12.9", + # Controls PyTorch float32 matmul precision mode within vLLM workers. + # Valid options mirror torch.set_float32_matmul_precision + "VLLM_FLOAT32_MATMUL_PRECISION": env_with_choices( + "VLLM_FLOAT32_MATMUL_PRECISION", + "highest", + ["highest", "high", "medium"], + case_sensitive=False, + ), # Maximum number of compilation jobs to run in parallel. # By default this is the number of CPUs "MAX_JOBS": lambda: os.getenv("MAX_JOBS", None), @@ -451,17 +508,16 @@ def get_vllm_port() -> int | None: .lower() in ("1", "true") or bool(os.environ.get("VLLM_PRECOMPILED_WHEEL_LOCATION")), + # If set, skip adding +precompiled suffix to version string + "VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX": lambda: bool( + int(os.environ.get("VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX", "0")) + ), # Used to mark that setup.py is running in a Docker build context, # in order to force the use of precompiled binaries. "VLLM_DOCKER_BUILD_CONTEXT": lambda: os.environ.get("VLLM_DOCKER_BUILD_CONTEXT", "") .strip() .lower() in ("1", "true"), - # Whether to force using nightly wheel in python build. - # This is used for testing the nightly wheel in python build. 
- "VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL": lambda: bool( - int(os.getenv("VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL", "0")) - ), # CMake build type # If not set, defaults to "Debug" or "RelWithDebInfo" # Available options: "Debug", "Release", "RelWithDebInfo" @@ -528,17 +584,6 @@ def get_vllm_port() -> int | None: "VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE": lambda: int( os.environ.get("VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE", "256") ), - # Use separate prefill and decode kernels for V1 attention instead of - # the unified triton kernel. - "VLLM_V1_USE_PREFILL_DECODE_ATTENTION": lambda: ( - os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() - in ("true", "1") - ), - # Force vllm to use a specific flash-attention version (2 or 3), only valid - # when using the flash-attention backend. - "VLLM_FLASH_ATTN_VERSION": lambda: maybe_convert_int( - os.environ.get("VLLM_FLASH_ATTN_VERSION", None) - ), # Feature flag to enable/disable Inductor standalone compile. # In torch <= 2.7 we ignore this flag; in torch >= 2.9 this is # enabled by default. @@ -567,6 +612,13 @@ def get_vllm_port() -> int | None: # to load will result in a hard error when this is enabled. # Will be ignored when VLLM_USE_AOT_COMPILE is disabled. "VLLM_FORCE_AOT_LOAD": lambda: os.environ.get("VLLM_FORCE_AOT_LOAD", "0") == "1", + # Enable loading compiled models directly from cached standalone compile artifacts + # without re-splitting graph modules. This reduces overhead during model + # loading by using reconstruct_serializable_fn_from_mega_artifact. 
+ "VLLM_USE_MEGA_AOT_ARTIFACT": lambda: os.environ.get( + "VLLM_USE_MEGA_AOT_ARTIFACT", "0" + ) + == "1", # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")), @@ -576,6 +628,11 @@ def get_vllm_port() -> int | None: "VLLM_ENGINE_ITERATION_TIMEOUT_S": lambda: int( os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60") ), + # Timeout in seconds for waiting for engine cores to become ready + # during startup. Default is 600 seconds (10 minutes). + "VLLM_ENGINE_READY_TIMEOUT_S": lambda: int( + os.environ.get("VLLM_ENGINE_READY_TIMEOUT_S", "600") + ), # API key for vLLM API server "VLLM_API_KEY": lambda: os.environ.get("VLLM_API_KEY", None), # Whether to log responses from API Server for debugging @@ -592,10 +649,6 @@ def get_vllm_port() -> int | None: "VLLM_USAGE_STATS_SERVER", "https://stats.vllm.ai" ), "VLLM_NO_USAGE_STATS": lambda: os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1", - "VLLM_DISABLE_FLASHINFER_PREFILL": lambda: os.environ.get( - "VLLM_DISABLE_FLASHINFER_PREFILL", "0" - ) - == "1", "VLLM_DO_NOT_TRACK": lambda: ( os.environ.get("VLLM_DO_NOT_TRACK", None) or os.environ.get("DO_NOT_TRACK", None) @@ -607,7 +660,9 @@ def get_vllm_port() -> int | None: # If set to 0, vllm will not configure logging # If set to 1, vllm will configure logging using the default configuration # or the configuration file specified by VLLM_LOGGING_CONFIG_PATH - "VLLM_CONFIGURE_LOGGING": lambda: int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")), + "VLLM_CONFIGURE_LOGGING": lambda: bool( + int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")) + ), "VLLM_LOGGING_CONFIG_PATH": lambda: os.getenv("VLLM_LOGGING_CONFIG_PATH"), # this is used for configuring the default logging level "VLLM_LOGGING_LEVEL": lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO").upper(), @@ -615,6 +670,11 @@ def get_vllm_port() -> int | None: "VLLM_LOGGING_STREAM": lambda: os.getenv("VLLM_LOGGING_STREAM", "ext://sys.stdout"), 
# if set, VLLM_LOGGING_PREFIX will be prepended to all log messages "VLLM_LOGGING_PREFIX": lambda: os.getenv("VLLM_LOGGING_PREFIX", ""), + # Controls colored logging output. Options: "auto" (default, colors when terminal), + # "1" (always use colors), "0" (never use colors) + "VLLM_LOGGING_COLOR": lambda: os.getenv("VLLM_LOGGING_COLOR", "auto"), + # Standard unix flag for disabling ANSI color codes + "NO_COLOR": lambda: os.getenv("NO_COLOR", "0") != "0", # If set, vllm will log stats at this interval in seconds # If not set, vllm will log stats every 10 seconds. "VLLM_LOG_STATS_INTERVAL": lambda: val @@ -624,38 +684,12 @@ def get_vllm_port() -> int | None: # If set to 1, vllm will trace function calls # Useful for debugging "VLLM_TRACE_FUNCTION": lambda: int(os.getenv("VLLM_TRACE_FUNCTION", "0")), - # Backend for attention computation - # Example options: - # - "TORCH_SDPA": use torch.nn.MultiheadAttention - # - "FLASH_ATTN": use FlashAttention - # - "XFORMERS": use XFormers - # - "FLASHINFER": use flashinfer - # - "FLASHMLA": use FlashMLA - # - "FLASH_ATTN_MLA": use FlashAttention for MLA - # - "FLASHINFER_MLA": use FlashInfer for MLA - # - "CUTLASS_MLA": use CUTLASS for MLA - # All possible options loaded dynamically from AttentionBackendEnum - "VLLM_ATTENTION_BACKEND": env_with_choices( - "VLLM_ATTENTION_BACKEND", - None, - lambda: list( - __import__( - "vllm.attention.backends.registry", fromlist=["AttentionBackendEnum"] - ).AttentionBackendEnum.__members__.keys() - ), - ), # If set, vllm will use flashinfer sampler "VLLM_USE_FLASHINFER_SAMPLER": lambda: bool( int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]) ) if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None, - # If set, vllm will use triton sampler. This will override the flashinfer sampler. 
- "VLLM_USE_TRITON_SAMPLER": lambda: bool( - int(os.environ.get("VLLM_USE_TRITON_SAMPLER", "0")) - ) - if "VLLM_USE_TRITON_SAMPLER" in os.environ - else None, # Pipeline stage partition strategy "VLLM_PP_LAYER_PARTITION": lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), # (CPU backend only) CPU key-value cache space. @@ -673,10 +707,6 @@ def get_vllm_port() -> int | None: ) if "VLLM_CPU_NUM_OF_RESERVED_CPU" in os.environ else None, - # (CPU backend only) whether to use prepack for MoE layer. This will be - # passed to ipex.llm.modules.GatedMLPMOE. On unsupported CPUs, you might - # need to set this to "0" (False). - "VLLM_CPU_MOE_PREPACK": lambda: bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))), # (CPU backend only) whether to use SGL kernels, optimized for small batch. "VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))), # If the env var is set, Ray Compiled Graph uses the specified @@ -749,6 +779,7 @@ def get_vllm_port() -> int | None: ), # Backend for Video IO # - "opencv": Default backend that uses OpenCV stream buffered backend. + # - "identity": Returns raw video bytes for model processor to handle. # # Custom backend implementations can be registered # via `@VIDEO_LOADER_REGISTRY.register("my_custom_video_loader")` and @@ -765,9 +796,17 @@ def get_vllm_port() -> int | None: # imported at runtime. # If a non-existing backend is used, an AssertionError will be thrown. "VLLM_MEDIA_CONNECTOR": lambda: os.getenv("VLLM_MEDIA_CONNECTOR", "http"), - # [DEPRECATED] Cache size (in GiB per process) for multimodal input cache - # Default is 4 GiB per API process + 4 GiB per engine core process - "VLLM_MM_INPUT_CACHE_GIB": lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")), + # Hash algorithm for multimodal content hashing. 
+ # - "blake3": Default, fast cryptographic hash (not FIPS 140-3 compliant) + # - "sha256": FIPS 140-3 compliant, widely supported + # - "sha512": FIPS 140-3 compliant, faster on 64-bit systems + # Use sha256 or sha512 for FIPS compliance in government/enterprise deployments + "VLLM_MM_HASHER_ALGORITHM": env_with_choices( + "VLLM_MM_HASHER_ALGORITHM", + "blake3", + ["blake3", "sha256", "sha512"], + case_sensitive=False, + ), # Path to the XLA persistent cache directory. # Only used for XLA devices such as TPUs. "VLLM_XLA_CACHE_PATH": lambda: os.path.expanduser( @@ -783,7 +822,7 @@ def get_vllm_port() -> int | None: # Enable SPMD mode for TPU backend. "VLLM_XLA_USE_SPMD": lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))), "VLLM_FUSED_MOE_CHUNK_SIZE": lambda: int( - os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768") + os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(16 * 1024)) ), # Control whether to use fused MoE activation chunking. Current chunking # logic is incompatible with torch.compile and causes IMA. See issue @@ -832,47 +871,59 @@ def get_vllm_port() -> int | None: "VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv( "VLLM_LORA_RESOLVER_CACHE_DIR", None ), - # Enables torch CUDA profiling if set. - # On NVIDIA GPUs, this will start/stop cudaProfilerApi when triggered. - "VLLM_TORCH_CUDA_PROFILE": lambda: bool( - os.getenv("VLLM_TORCH_CUDA_PROFILE", "0") != "0" + # A remote HF repo(s) containing one or more LoRA adapters, which + # may be downloaded and leveraged as needed. Only works if plugins + # are enabled and VLLM_ALLOW_RUNTIME_LORA_UPDATING is enabled. + # Values should be comma separated. + "VLLM_LORA_RESOLVER_HF_REPO_LIST": lambda: os.getenv( + "VLLM_LORA_RESOLVER_HF_REPO_LIST", None ), + # Enables torch CUDA profiling if set to 1. + # Deprecated, see profiler_config. + "VLLM_TORCH_CUDA_PROFILE": lambda: os.getenv("VLLM_TORCH_CUDA_PROFILE"), # Enables torch profiler if set. 
- # Both AsyncLLM's CPU traces as well as workers' - # traces (CPU & GPU) will be saved under this directory. - # Note that it must be an absolute path. - "VLLM_TORCH_PROFILER_DIR": lambda: ( - None - if (val := os.getenv("VLLM_TORCH_PROFILER_DIR")) is None - else ( - val - if val.startswith("gs://") and val[5:] and val[5] != "/" - else os.path.abspath(os.path.expanduser(val)) - ) - ), - # Enable torch profiler to record shapes if set - # VLLM_TORCH_PROFILER_RECORD_SHAPES=1. If not set, torch profiler will - # not record shapes. - "VLLM_TORCH_PROFILER_RECORD_SHAPES": lambda: bool( - os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES", "0") != "0" - ), - # Enable torch profiler to profile memory if set - # VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1. If not set, torch profiler - # will not profile memory. - "VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY": lambda: bool( - os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY", "0") != "0" - ), - # Enable torch profiler to profile stack if set - # VLLM_TORCH_PROFILER_WITH_STACK=1. If not set, torch profiler WILL - # profile stack by default. - "VLLM_TORCH_PROFILER_WITH_STACK": lambda: bool( - os.getenv("VLLM_TORCH_PROFILER_WITH_STACK", "1") != "0" - ), - # Enable torch profiler to profile flops if set - # VLLM_TORCH_PROFILER_WITH_FLOPS=1. If not set, torch profiler will - # not profile flops. - "VLLM_TORCH_PROFILER_WITH_FLOPS": lambda: bool( - os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS", "0") != "0" + # Deprecated, see profiler_config. + "VLLM_TORCH_PROFILER_DIR": lambda: os.getenv("VLLM_TORCH_PROFILER_DIR"), + # Enable torch profiler to record shapes if set to 1. + # Deprecated, see profiler_config. + "VLLM_TORCH_PROFILER_RECORD_SHAPES": lambda: ( + os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES") + ), + # Enable torch profiler to profile memory if set to 1. + # Deprecated, see profiler_config. 
+ "VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY": lambda: ( + os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY") + ), + # Enable torch profiler to profile stack if set to 1. + # Deprecated, see profiler_config. + "VLLM_TORCH_PROFILER_WITH_STACK": lambda: ( + os.getenv("VLLM_TORCH_PROFILER_WITH_STACK") + ), + # Enable torch profiler to profile flops if set to 1. + # Deprecated, see profiler_config. + "VLLM_TORCH_PROFILER_WITH_FLOPS": lambda: ( + os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS") + ), + # Disable torch profiling of the AsyncLLMEngine process if set to 1. + # Deprecated, see profiler_config. + "VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM": lambda: ( + os.getenv("VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM") + ), + # Delay number of iterations before starting profiling when using + # the torch/torch CUDA profiler. If set to 0, will start profiling immediately. + # Deprecated, see profiler_config. + "VLLM_PROFILER_DELAY_ITERS": lambda: (os.getenv("VLLM_PROFILER_DELAY_ITERS")), + # Maximum number of iterations to profile when using the torch/torch CUDA profiler. + # If set to 0, will not limit the number of iterations. + "VLLM_PROFILER_MAX_ITERS": lambda: os.getenv("VLLM_PROFILER_MAX_ITERS"), + # Control whether torch profiler gzip-compresses profiling files. + # Deprecated, see profiler_config. + "VLLM_TORCH_PROFILER_USE_GZIP": lambda: os.getenv("VLLM_TORCH_PROFILER_USE_GZIP"), + # Control whether torch profiler dumps the self_cuda_time_total table. + # Set to 0 to disable dumping the table. + # Deprecated, see profiler_config. + "VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL": lambda: ( + os.getenv("VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL") ), # If set, vLLM will use Triton implementations of AWQ. 
"VLLM_USE_TRITON_AWQ": lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))), @@ -949,15 +1000,20 @@ def get_vllm_port() -> int | None: "VLLM_ROCM_USE_AITER_FP8BMM": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in ("true", "1") ), + # Whether to use aiter triton fp4 bmm kernel + # By default is enabled. + "VLLM_ROCM_USE_AITER_FP4BMM": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_FP4BMM", "True").lower() in ("true", "1") + ), # Use AITER triton unified attention for V1 attention "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "False").lower() in ("true", "1") ), # Whether to use aiter fusion shared experts ops. - # By default is enabled. + # By default is disabled. "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS": lambda: ( - os.getenv("VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "True").lower() + os.getenv("VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "False").lower() in ("true", "1") ), # Whether to use aiter triton kernels for gemm ops. @@ -977,6 +1033,10 @@ def get_vllm_port() -> int | None: "VLLM_ROCM_CUSTOM_PAGED_ATTN": lambda: ( os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in ("true", "1") ), + # Whether to use the shuffled kv cache layout + "VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT": lambda: ( + os.getenv("VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT", "False").lower() in ("true", "1") + ), # Custom quick allreduce kernel for MI3* cards # Choice of quantization level: FP, INT8, INT6, INT4 or NONE # Recommended for large models to get allreduce @@ -1033,10 +1093,6 @@ def get_vllm_port() -> int | None: # If set, vLLM will disable the MLA attention optimizations. 
"VLLM_MLA_DISABLE": lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))), # If set, vLLM will pick up the provided Flash Attention MLA - # max number splits for cuda graph decode - "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": lambda: int( - os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH", "32") - ), # Number of GPUs per worker in Ray, if it is set to be a fraction, # it allows ray to schedule multiple actors on a single GPU, # so that users can colocate other actors on the same GPUs as vLLM. @@ -1069,6 +1125,9 @@ def get_vllm_port() -> int | None: # rank. All DP ranks process the activations in VLLM_MOE_DP_CHUNK_SIZE # units. "VLLM_MOE_DP_CHUNK_SIZE": lambda: int(os.getenv("VLLM_MOE_DP_CHUNK_SIZE", "256")), + "VLLM_ENABLE_MOE_DP_CHUNK": lambda: bool( + int(os.getenv("VLLM_ENABLE_MOE_DP_CHUNK", "1")) + ), # Randomize inputs during dummy runs when using Data Parallel "VLLM_RANDOMIZE_DP_DUMMY_INPUTS": lambda: os.environ.get( "VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0" @@ -1108,6 +1167,16 @@ def get_vllm_port() -> int | None: "VLLM_MXFP4_USE_MARLIN": lambda: maybe_convert_bool( os.environ.get("VLLM_MXFP4_USE_MARLIN", None) ), + # The activation dtype for marlin kernel + "VLLM_MARLIN_INPUT_DTYPE": env_with_choices( + "VLLM_MARLIN_INPUT_DTYPE", None, ["int8", "fp8"] + ), + # Whether to use DeepEPLL kernels for NVFP4 quantization and dispatch method + # only supported on Blackwell GPUs and with + # https://github.com/deepseek-ai/DeepEP/pull/341 + "VLLM_DEEPEPLL_NVFP4_DISPATCH": lambda: bool( + int(os.getenv("VLLM_DEEPEPLL_NVFP4_DISPATCH", "0")) + ), # Whether to turn on the outlines cache for V1 # This cache is unbounded and on disk, so it's not safe to use in # an environment with potentially malicious users. @@ -1139,6 +1208,10 @@ def get_vllm_port() -> int | None: "VLLM_USE_DEEP_GEMM_E8M0": lambda: bool( int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1")) ), + # Whether to create TMA-aligned scale tensor when DeepGEMM is used. 
+ "VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES": lambda: bool( + int(os.getenv("VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES", "1")) + ), # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm # JIT all the required kernels before model execution so there is no # JIT'ing in the hot-path. However, this warmup increases the engine @@ -1163,18 +1236,27 @@ def get_vllm_port() -> int | None: "VLLM_USE_FUSED_MOE_GROUPED_TOPK": lambda: bool( int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1")) ), - # Allow use of FlashInfer MoE kernels for fused moe ops. + # Allow use of FlashInfer FP8 block-scale GEMM for linear layers. + # This uses TensorRT-LLM kernels and requires SM90+ (Hopper). + "VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER": lambda: bool( + int(os.getenv("VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER", "0")) + ), + # Allow use of FlashInfer BF16 MoE kernels for fused moe ops. "VLLM_USE_FLASHINFER_MOE_FP16": lambda: bool( int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP16", "0")) ), - # Allow use of FlashInfer MoE kernels for fused moe ops. + # Allow use of FlashInfer FP8 MoE kernels for fused moe ops. "VLLM_USE_FLASHINFER_MOE_FP8": lambda: bool( int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0")) ), - # Allow use of FlashInfer CUTLASS kernels for fused moe ops. + # Allow use of FlashInfer NVFP4 MoE kernels for fused moe ops. "VLLM_USE_FLASHINFER_MOE_FP4": lambda: bool( int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP4", "0")) ), + # Allow use of FlashInfer MxInt4 MoE kernels for fused moe ops. + "VLLM_USE_FLASHINFER_MOE_INT4": lambda: bool( + int(os.getenv("VLLM_USE_FLASHINFER_MOE_INT4", "0")) + ), # If set to 1, use the FlashInfer # MXFP8 (activation) x MXFP4 (weight) MoE backend. 
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8": lambda: bool( @@ -1220,7 +1302,12 @@ def get_vllm_port() -> int | None: "VLLM_NIXL_SIDE_CHANNEL_PORT": lambda: int( os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5600") ), - # all2all backend for vllm's expert parallel communication + # Port used for Mooncake handshake between remote agents. + "VLLM_MOONCAKE_BOOTSTRAP_PORT": lambda: int( + os.getenv("VLLM_MOONCAKE_BOOTSTRAP_PORT", "8998") + ), + # [DEPRECATED - will be removed in v0.15.0] all2all backend for vllm's + # expert parallel communication. Use --all2all-backend CLI argument instead. # Available options: # - "naive": naive all2all implementation using broadcasts # - "allgather_reducescatter": all2all implementation based on allgather and @@ -1228,15 +1315,17 @@ def get_vllm_port() -> int | None: # - "pplx": use pplx kernels # - "deepep_high_throughput", use deepep high-throughput kernels # - "deepep_low_latency", use deepep low-latency kernels + # - "mori", use MoRI kernels # - "flashinfer_all2allv", use flashinfer alltoallv kernels for mnnvl "VLLM_ALL2ALL_BACKEND": env_with_choices( "VLLM_ALL2ALL_BACKEND", - "allgather_reducescatter", + None, [ "naive", "pplx", "deepep_high_throughput", "deepep_low_latency", + "mori", "allgather_reducescatter", "flashinfer_all2allv", ], @@ -1249,7 +1338,9 @@ def get_vllm_port() -> int | None: # - "latency": # Uses TensorRT-LLM kernels optimized for low-latency inference. "VLLM_FLASHINFER_MOE_BACKEND": env_with_choices( - "VLLM_FLASHINFER_MOE_BACKEND", "latency", ["throughput", "latency"] + "VLLM_FLASHINFER_MOE_BACKEND", + "latency", + ["throughput", "latency", "masked_gemm"], ), # Control the workspace buffer size for the FlashInfer backend. "VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE": lambda: int( @@ -1274,7 +1365,7 @@ def get_vllm_port() -> int | None: # MoE routing strategy selector. # See `RoutingSimulator.get_available_strategies()` # for available # strategies. 
- # Cutstom routing strategies can be registered by + # Custom routing strategies can be registered by # RoutingSimulator.register_strategy() # Note: custom strategies may not produce correct model outputs "VLLM_MOE_ROUTING_SIMULATION_STRATEGY": lambda: os.environ.get( @@ -1327,25 +1418,23 @@ def get_vllm_port() -> int | None: "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": lambda: int( os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "480") ), - # Controls whether or not to use cudnn prefill - "VLLM_USE_CUDNN_PREFILL": lambda: bool( - int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0")) + # Controls the read mode for the Mori-IO connector + "VLLM_MORIIO_CONNECTOR_READ_MODE": lambda: ( + os.getenv("VLLM_MORIIO_CONNECTOR_READ_MODE", "False").lower() in ("true", "1") ), - # Controls whether to use TRT-LLM ragged DeepSeek prefill - "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL": lambda: bool( - int(os.getenv("VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL", "0")) + # Controls the QP (Queue Pair) per transfer configuration for the Mori-IO connector + "VLLM_MORIIO_QP_PER_TRANSFER": lambda: int( + os.getenv("VLLM_MORIIO_QP_PER_TRANSFER", "1") ), - # If set to 1/True, use the TRTLLM attention backend in flashinfer. - # If set to 0/False, use the default attention backend in flashinfer. - # If not set, auto-detect the attention backend in flashinfer. 
- "VLLM_USE_TRTLLM_ATTENTION": lambda: ( - None - if "VLLM_USE_TRTLLM_ATTENTION" not in os.environ - else os.environ["VLLM_USE_TRTLLM_ATTENTION"].lower() in ("1", "true") + # Controls the post-processing batch size for the Mori-IO connector + "VLLM_MORIIO_POST_BATCH_SIZE": lambda: int( + os.getenv("VLLM_MORIIO_POST_BATCH_SIZE", "-1") ), - # If set to 1, when we use fp8 kv, we do not quantize Q to fp8 - "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION": lambda: bool( - int(os.getenv("VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "0")) + # Controls the number of workers for Mori operations for the Mori-IO connector + "VLLM_MORIIO_NUM_WORKERS": lambda: int(os.getenv("VLLM_MORIIO_NUM_WORKERS", "1")), + # Timeout (in seconds) for MooncakeConnector in PD disaggregated setup. + "VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT": lambda: int( + os.getenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "480") ), # If set, it means we pre-downloaded cubin files and flashinfer will # read the cubin files directly. @@ -1356,11 +1445,18 @@ def get_vllm_port() -> int | None: # - "flashinfer-cudnn": use flashinfer cudnn GEMM backend # - "flashinfer-trtllm": use flashinfer trtllm GEMM backend # - "flashinfer-cutlass": use flashinfer cutlass GEMM backend + # - "marlin": use marlin GEMM backend (for GPUs without native FP4 support) # - : automatically pick an available backend "VLLM_NVFP4_GEMM_BACKEND": env_with_choices( "VLLM_NVFP4_GEMM_BACKEND", None, - ["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass", "cutlass"], + [ + "flashinfer-cudnn", + "flashinfer-trtllm", + "flashinfer-cutlass", + "cutlass", + "marlin", + ], ), # Controls garbage collection during CUDA graph capture. # If set to 0 (default), enables GC freezing to speed up capture time. @@ -1382,7 +1478,7 @@ def get_vllm_port() -> int | None: # kv-cache memory usage and enable longer contexts) # TODO(lucas): Remove this flag once latency regression is resolved. 
"VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE": lambda: bool( - int(os.getenv("VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "0")) + int(os.getenv("VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "1")) ), # Enables support for the "store" option in the OpenAI Responses API. # When set to 1, vLLM's OpenAI server will retain the input and output @@ -1404,6 +1500,10 @@ def get_vllm_port() -> int | None: "VLLM_ALLREDUCE_USE_SYMM_MEM": lambda: bool( int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")) ), + # Experimental: use this to enable MCP tool calling for non harmony models + "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT": lambda: bool( + int(os.getenv("VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", "0")) + ), # Allows vllm to find tuned config under customized folder "VLLM_TUNED_CONFIG_FOLDER": lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None), # Valid values are container,code_interpreter,web_search_preview @@ -1440,8 +1540,11 @@ def get_vllm_port() -> int | None: ), # Name of the shared memory buffer used for object storage. # Only effective when mm_config.mm_processor_cache_type == "shm". - "VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME": lambda: os.getenv( - "VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME", "VLLM_OBJECT_STORAGE_SHM_BUFFER" + # Automatically generates a unique UUID-based name per process tree + # if not explicitly set. + "VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME": get_env_or_set_default( + "VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME", + lambda: f"VLLM_OBJECT_STORAGE_SHM_BUFFER_{uuid.uuid4().hex}", ), # The size in MB of the buffers (NVL and RDMA) used by DeepEP "VLLM_DEEPEP_BUFFER_SIZE_MB": lambda: int( @@ -1486,6 +1589,9 @@ def get_vllm_port() -> int | None: # - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with # top 5 collected objects "VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""), + # Debug workspace allocations. + # logging of workspace resize operations. 
+ "VLLM_DEBUG_WORKSPACE": lambda: bool(int(os.getenv("VLLM_DEBUG_WORKSPACE", "0"))), # Disables parallel execution of shared_experts via separate cuda stream "VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool( int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0")) @@ -1506,13 +1612,28 @@ def get_vllm_port() -> int | None: "VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices( "VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"] ), - # Flag to enable FlatLogprobs whose GC overhead is significantly smaller than - # the original list[dict[int, Logprob]] approach. - # After enabled, PromptLogprobs and SampleLogprobs would populated as - # FlatLogprobs. - "VLLM_FLAT_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLAT_LOGPROBS", "0"))), + # Flag to enable v2 model runner. + "VLLM_USE_V2_MODEL_RUNNER": lambda: bool( + int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0")) + ), + # Log model inspection after loading. + # If enabled, logs a transformers-style hierarchical view of the model + # with quantization methods and attention backends. + "VLLM_LOG_MODEL_INSPECTION": lambda: bool( + int(os.getenv("VLLM_LOG_MODEL_INSPECTION", "0")) + ), + # Debug logging for --enable-mfu-metrics + "VLLM_DEBUG_MFU_METRICS": lambda: bool( + int(os.getenv("VLLM_DEBUG_MFU_METRICS", "0")) + ), + # Disable logging of vLLM logo at server startup time. + "VLLM_DISABLE_LOG_LOGO": lambda: bool(int(os.getenv("VLLM_DISABLE_LOG_LOGO", "0"))), + # Disable PDL for LoRA, as enabling PDL with LoRA on SM100 causes + # Triton compilation to fail. 
+ "VLLM_LORA_DISABLE_PDL": lambda: bool(int(os.getenv("VLLM_LORA_DISABLE_PDL", "0"))), } + # --8<-- [end:env-vars-definition] @@ -1528,6 +1649,12 @@ def __getattr__(name: str): raise AttributeError(f"module {__name__!r} has no attribute {name!r}") +def _is_envs_cache_enabled() -> bool: + """Checked if __getattr__ is wrapped with functools.cache""" + global __getattr__ + return hasattr(__getattr__, "cache_clear") + + def enable_envs_cache() -> None: """ Enables caching of environment variables. This is useful for performance @@ -1538,6 +1665,9 @@ def enable_envs_cache() -> None: runtime overhead. This also means that environment variables should NOT be updated after the service is initialized. """ + if _is_envs_cache_enabled(): + # Avoid wrapping functools.cache multiple times + return # Tag __getattr__ with functools.cache global __getattr__ __getattr__ = functools.cache(__getattr__) @@ -1547,6 +1677,18 @@ def enable_envs_cache() -> None: __getattr__(key) +def disable_envs_cache() -> None: + """ + Resets the environment variables cache. It could be used to isolate environments + between unit tests. + """ + global __getattr__ + # If __getattr__ is wrapped by functions.cache, unwrap the caching layer. + if _is_envs_cache_enabled(): + assert hasattr(__getattr__, "__wrapped__") + __getattr__ = __getattr__.__wrapped__ + + def __dir__(): return list(environment_variables.keys()) @@ -1558,86 +1700,91 @@ def is_set(name: str): raise AttributeError(f"module {__name__!r} has no attribute {name!r}") -def compute_hash() -> str: - """ - WARNING: Whenever a new key is added to this environment - variables, ensure that it is included in the factors list if - it affects the computation graph. For example, different values - of VLLM_PP_LAYER_PARTITION will generate different computation - graphs, so it is included in the factors list. The env vars that - affect the choice of different kernels or attention backends should - also be included in the factors list. 
- """ - - # The values of envs may affects the computation graph. - # TODO(DefTruth): hash all environment variables? - # for key in environment_variables: - # factorize(key) - environment_variables_to_hash = [ - "VLLM_PP_LAYER_PARTITION", - "VLLM_MLA_DISABLE", - "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH", - "VLLM_USE_TRITON_AWQ", - "VLLM_DP_RANK", - "VLLM_DP_SIZE", - "VLLM_USE_STANDALONE_COMPILE", - "VLLM_FUSED_MOE_CHUNK_SIZE", - "VLLM_FLASHINFER_MOE_BACKEND", - "VLLM_V1_USE_PREFILL_DECODE_ATTENTION", - "VLLM_ATTENTION_BACKEND", - "VLLM_USE_FLASHINFER_SAMPLER", - "VLLM_USE_TRITON_SAMPLER", - "VLLM_DISABLED_KERNELS", - "VLLM_USE_DEEP_GEMM", - "VLLM_MOE_USE_DEEP_GEMM", - "VLLM_USE_DEEP_GEMM_E8M0", - "VLLM_USE_FUSED_MOE_GROUPED_TOPK", - "VLLM_USE_FLASHINFER_MOE_FP16", - "VLLM_USE_FLASHINFER_MOE_FP8", - "VLLM_USE_FLASHINFER_MOE_FP4", - "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", - "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", - "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", - "VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE", - "VLLM_USE_CUDNN_PREFILL", - "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL", - "VLLM_USE_TRTLLM_ATTENTION", - "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", - "VLLM_ROCM_USE_AITER", - "VLLM_ROCM_USE_AITER_PAGED_ATTN", - "VLLM_ROCM_USE_AITER_LINEAR", - "VLLM_ROCM_USE_AITER_MOE", - "VLLM_ROCM_USE_AITER_RMSNORM", - "VLLM_ROCM_USE_AITER_MLA", - "VLLM_ROCM_USE_AITER_MHA", - "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", - "VLLM_ROCM_USE_AITER_TRITON_ROPE", - "VLLM_ROCM_USE_AITER_FP8BMM", - "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", - "VLLM_ROCM_USE_AITER_TRITON_GEMM", - "VLLM_ROCM_USE_SKINNY_GEMM", - "VLLM_ROCM_FP8_PADDING", - "VLLM_ROCM_MOE_PADDING", - "VLLM_ROCM_CUSTOM_PAGED_ATTN", - "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", - "VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", - "VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", - "VLLM_ROCM_FP8_MFMA_PAGE_ATTN", - "VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE", - "VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING", - "VLLM_NVFP4_GEMM_BACKEND", - "VLLM_USE_FBGEMM", - 
"VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE", - "VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL", - ] - for key in environment_variables_to_hash: - # if this goes out of sync with environment_variables, - # it's not a user error, it's a bug - assert key in environment_variables, ( - "Please update environment_variables_to_hash in envs.py" - ) +def compile_factors() -> dict[str, object]: + """Return env vars used for torch.compile cache keys. + + Start with every known vLLM env var; drop entries in `ignored_factors`; + hash everything else. This keeps the cache key aligned across workers.""" + + ignored_factors: set[str] = { + "MAX_JOBS", + "VLLM_RPC_BASE_PATH", + "VLLM_USE_MODELSCOPE", + "VLLM_RINGBUFFER_WARNING_INTERVAL", + "VLLM_DEBUG_DUMP_PATH", + "VLLM_PORT", + "VLLM_CACHE_ROOT", + "LD_LIBRARY_PATH", + "VLLM_SERVER_DEV_MODE", + "VLLM_DP_MASTER_IP", + "VLLM_DP_MASTER_PORT", + "VLLM_RANDOMIZE_DP_DUMMY_INPUTS", + "VLLM_CI_USE_S3", + "VLLM_MODEL_REDIRECT_PATH", + "VLLM_HOST_IP", + "VLLM_FORCE_AOT_LOAD", + "S3_ACCESS_KEY_ID", + "S3_SECRET_ACCESS_KEY", + "S3_ENDPOINT_URL", + "VLLM_USAGE_STATS_SERVER", + "VLLM_NO_USAGE_STATS", + "VLLM_DO_NOT_TRACK", + "VLLM_LOGGING_LEVEL", + "VLLM_LOGGING_PREFIX", + "VLLM_LOGGING_STREAM", + "VLLM_LOGGING_CONFIG_PATH", + "VLLM_LOGGING_COLOR", + "VLLM_LOG_STATS_INTERVAL", + "VLLM_DEBUG_LOG_API_SERVER_RESPONSE", + "VLLM_TUNED_CONFIG_FOLDER", + "VLLM_ENGINE_ITERATION_TIMEOUT_S", + "VLLM_HTTP_TIMEOUT_KEEP_ALIVE", + "VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS", + "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", + "VLLM_SLEEP_WHEN_IDLE", + "VLLM_IMAGE_FETCH_TIMEOUT", + "VLLM_VIDEO_FETCH_TIMEOUT", + "VLLM_AUDIO_FETCH_TIMEOUT", + "VLLM_MEDIA_URL_ALLOW_REDIRECTS", + "VLLM_MEDIA_LOADING_THREAD_COUNT", + "VLLM_MAX_AUDIO_CLIP_FILESIZE_MB", + "VLLM_VIDEO_LOADER_BACKEND", + "VLLM_MEDIA_CONNECTOR", + "VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME", + "VLLM_ASSETS_CACHE", + "VLLM_ASSETS_CACHE_MODEL_CLEAN", + "VLLM_WORKER_MULTIPROC_METHOD", + "VLLM_ENABLE_V1_MULTIPROCESSING", + 
"VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", + "VLLM_CPU_KVCACHE_SPACE", + "VLLM_CPU_OMP_THREADS_BIND", + "VLLM_CPU_NUM_OF_RESERVED_CPU", + "VLLM_CPU_MOE_PREPACK", + "VLLM_CPU_SGL_KERNEL", + "VLLM_TEST_FORCE_LOAD_FORMAT", + "LOCAL_RANK", + "CUDA_VISIBLE_DEVICES", + "NO_COLOR", + } + + from vllm.config.utils import normalize_value + + factors: dict[str, object] = {} + for factor, getter in environment_variables.items(): + if factor in ignored_factors: + continue + + try: + raw = getter() + except Exception as exc: # pragma: no cover - defensive logging + logger.warning( + "Skipping environment variable %s while hashing compile factors: %s", + factor, + exc, + ) + continue - factors = [environment_variables[key]() for key in environment_variables_to_hash] + factors[factor] = normalize_value(raw) ray_noset_env_vars = [ # Refer to @@ -1660,8 +1807,8 @@ def compute_hash() -> str: "RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR", "RAY_EXPERIMENTAL_NOSET_RBLN_RT_VISIBLE_DEVICES", ] - factors.extend([os.getenv(var) for var in ray_noset_env_vars]) - hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + for var in ray_noset_env_vars: + factors[var] = normalize_value(os.getenv(var)) - return hash_str + return factors diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 4e4c668f831a..03da3e565e49 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -4,8 +4,6 @@ import torch import torch.nn as nn -import triton -import triton.language as tl from packaging import version from vllm import envs @@ -35,6 +33,16 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None: and current_platform.is_cuda() ): if envs.VLLM_USE_FLASHINFER_SAMPLER: + from vllm.v1.attention.backends.flashinfer import FlashInferBackend + + capability = current_platform.get_device_capability() + assert capability is not None + if not 
FlashInferBackend.supports_compute_capability(capability): + capability_str = capability.as_version_str() + raise RuntimeError( + "FlashInfer does not support compute capability " + f"{capability_str}, unset VLLM_USE_FLASHINFER_SAMPLER=1." + ) # Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1. logger.info_once( "Using FlashInfer for top-p & top-k sampling.", @@ -49,15 +57,6 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None: ) self.forward = self.forward_native - if envs.VLLM_USE_TRITON_SAMPLER: - if envs.VLLM_USE_FLASHINFER_SAMPLER: - logger.info_once( - "Overriding FlashInfer top-p & top-k sampling with " - "Triton top-p & top-k sampling." - ) - else: - logger.info_once("Using Triton for top-p & top-k sampling.") - self.forward = self.forward_native elif current_platform.is_cpu(): arch = current_platform.get_cpu_architecture() # Fall back to native implementation for POWERPC and RISCV. @@ -71,18 +70,24 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None: logprobs_mode not in ("processed_logits", "processed_logprobs") and rocm_aiter_ops.is_enabled() ): - import aiter.ops.sampling # noqa: F401 + try: + import aiter.ops.sampling # noqa: F401 - self.aiter_ops = torch.ops.aiter - logger.info_once( - "Using aiter sampler on ROCm (lazy import, sampling-only)." - ) - self.forward = self.forward_hip + self.aiter_ops = torch.ops.aiter + logger.info_once( + "Using aiter sampler on ROCm (lazy import, sampling-only)." + ) + self.forward = self.forward_hip + except ImportError: + logger.warning_once( + "aiter.ops.sampling is not available on ROCm. " + "Falling back to forward_native implementation." 
+ ) + self.forward = self.forward_native else: self.forward = self.forward_native self.apply_top_k_top_p = apply_top_k_top_p - self.apply_top_k_top_p_triton = apply_top_k_top_p_triton def forward_native( self, @@ -105,22 +110,6 @@ def forward_native( probs = logits.softmax(dim=-1, dtype=torch.float32) return random_sample(probs, generators), logits_to_return - def forward_triton( - self, - logits: torch.Tensor, - generators: dict[int, torch.Generator], - k: torch.Tensor | None, - p: torch.Tensor | None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - logits = self.apply_top_k_top_p_triton(logits, k, p) - logits_to_return = None - if self.logprobs_mode == "processed_logits": - logits_to_return = logits - elif self.logprobs_mode == "processed_logprobs": - logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32) - probs = logits.softmax(dim=-1, dtype=torch.float32) - return random_sample(probs, generators), logits_to_return - def forward_cuda( self, logits: torch.Tensor, @@ -185,6 +174,8 @@ def forward_hip( k: torch.Tensor | None, p: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor | None]: + # FIXME: Fix aiter_sampler's accuracy issue and remove this flag + DISABLE_AITER_SAMPLER = True """Optimized ROCm/aiter path (same structure as forward_cuda).""" if (k is None and p is None) or generators: if generators: @@ -197,6 +188,8 @@ def forward_hip( "processed_logits", "processed_logprobs", ), "aiter sampler does not support returning logits/logprobs." + if DISABLE_AITER_SAMPLER: + return self.forward_native(logits, generators, k, p) return self.aiter_sample(logits, k, p, generators), None def aiter_sample( @@ -278,15 +271,13 @@ def apply_top_k_top_p( if p is not None: # Apply top-p. - # Note: Running softmax on "logits_sort" produces different probability - # values compared to running softmax on the original unsorted logits as the - # non-associativity of floating-points yields different sum(exp(logits)). 
probs_sort = logits_sort.softmax(dim=-1) probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort) top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) # at least one top_p_mask[:, -1] = False logits_sort.masked_fill_(top_p_mask, -float("inf")) + # Re-sort the probabilities. logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) return logits @@ -395,638 +386,3 @@ def _to_tensor_scalar_tuple(x): return (x, 0) else: return (None, x) - - -# fmt: off -_PERCENTILE_TO_STD_TABLE = [ - 2.576, 2.326, 2.054, 1.881, 1.751, - 1.645, 1.555, 1.476, 1.405, 1.341, - 1.282, 1.227, 1.175, 1.126, 1.080, - 1.036, 0.994, 0.954, 0.915, 0.878, - 0.842, 0.806, 0.772, 0.739, 0.706, - 0.674, 0.643, 0.613, 0.583, 0.553, - 0.524, 0.496, 0.468, 0.440, 0.412, - 0.385, 0.358, 0.332, 0.305, 0.279, - 0.253, 0.228, 0.202, 0.176, 0.151, - 0.126, 0.100, 0.075, 0.050, 0.025, - 0.000, -0.025, -0.050, -0.075, -0.100, - -0.126, -0.151, -0.176, -0.202, -0.228, - -0.253, -0.279, -0.305, -0.332, -0.358, - -0.385, -0.412, -0.440, -0.468, -0.496, - -0.524, -0.553, -0.583, -0.613, -0.643, - -0.674, -0.706, -0.739, -0.772, -0.806, - -0.842, -0.878, -0.915, -0.954, -0.994, - -1.036, -1.080, -1.126, -1.175, -1.227, - -1.282, -1.341, -1.405, -1.476, -1.555, - -1.645, -1.751, -1.881, -2.054, -2.326 -] -# fmt: on - - -def apply_top_k_top_p_triton( - logits: torch.Tensor, - k: torch.Tensor | None, - p: torch.Tensor | None, -) -> torch.Tensor: - """ - Uses pivot-based algorithm to filter --> sort - """ - - if k is None and p is None: - return logits - if p is None and k is not None: - return apply_top_k_only_triton(logits, k) - # Fallback to torch for small batch sizes or small vocab sizes for top-p - if logits.shape[0] < 16 or logits.shape[1] < 32768: - return apply_top_k_top_p(logits, k, p) - return apply_top_k_top_p_filtered(logits, k, p) - - -@triton.jit -def _topk_triton_kernel( - LOGITS, - OUTPUT, - PERCENTILE_TO_STD_TABLE, - K, - VOCAB_SIZE: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - 
row_id = tl.program_id(0) - NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE - k = tl.load(K + row_id) - - if k != VOCAB_SIZE: - # THERE IS NO DUPLICATE LOGIT MANAGEMENT FOR THIS TOP-K KERNEL - # CURRENT IMPLEMENTATION INCLUDES ALL DUPLICATE LOGITS, - # WHICH MAY RETURN MORE THAN K LOGITS. - # THIS FOLLOWS THE IMPLEMENTATION IN apply_top_k_only(). - - LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE - OUTPUT_ROW = OUTPUT + row_id * VOCAB_SIZE - search_addr = LOGITS_ROW - search_range = VOCAB_SIZE - search_iters = NUM_TILES - - max_logit = -float("inf") - min_logit = float("inf") - - # Zeroth pass: Compute avg and std from a sample block - # May produce incorrect results if VOCAB_SIZE < BLOCK_SIZE - offs = tl.arange(0, BLOCK_SIZE) - mask_n = offs < VOCAB_SIZE - num_valid = tl.sum(mask_n) - logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) - avg_logit = tl.sum(logits_blk) / num_valid - sq_avg_logit = tl.sum(logits_blk * logits_blk) / num_valid - std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - - percentile = tl.cast(k * 1.6 / VOCAB_SIZE * 100 + 1, tl.uint32) - percentile = tl.minimum(percentile, 99) - sigma = tl.load(PERCENTILE_TO_STD_TABLE + percentile) - outlier_pivot = avg_logit + sigma * std_logit - num_outliers = tl.zeros((), dtype=tl.uint32) - - # First pass: compute max and min logits and gather outliers - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=avg_logit) - - max_logit = tl.maximum(max_logit, tl.max(logits_blk)) - min_logit = tl.minimum(min_logit, tl.min(logits_blk)) - - outlier_mask = (logits_blk > outlier_pivot) & mask_n - num_blk_outliers = tl.sum(outlier_mask) - cumulative_pos = tl.cast( - tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32 - ) - num_outliers += num_blk_outliers - write_pos = tl.where(outlier_mask, cumulative_pos, -1) - tl.store(OUTPUT_ROW + write_pos, 
logits_blk, mask=outlier_mask) - - max_range = max_logit - min_range = min_logit - if num_outliers > k: - max_range = max_logit - min_range = outlier_pivot - search_addr = OUTPUT_ROW - search_range = tl.cast(num_outliers, tl.int32) - search_iters = tl.cast( - (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32 - ) - - # Second passes: Quaternary search for pivots (nlog_4(n)) - num_iters = 0 - k_pivot = float("inf") - if k == VOCAB_SIZE: - k_pivot = -float("inf") - while k_pivot == float("inf"): - k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range - k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range - k_pivot_2 = (max_range - min_range) * 3.0 / 4.0 + min_range - k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) - k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) - k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - logits_blk = tl.load( - search_addr + offs_n, mask=mask_n, other=-float("inf") - ) - - k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) - k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) - k_pivots_num_2 += tl.sum(logits_blk > k_pivot_2) - - # Check if any of the pivots are equal to k - if k_pivots_num_0 == k: - k_pivot = k_pivot_0 - elif k_pivots_num_1 == k: - k_pivot = k_pivot_1 - elif k_pivots_num_2 == k: - k_pivot = k_pivot_2 - # If none of the pivots are equal to k, we update the range - elif k_pivots_num_2 > k: - min_range = k_pivot_2 - elif k_pivots_num_1 > k: - min_range = k_pivot_1 - elif k_pivots_num_0 > k: - min_range = k_pivot_0 - if k_pivots_num_0 < k: - max_range = k_pivot_0 - elif k_pivots_num_1 < k: - max_range = k_pivot_1 - elif k_pivots_num_2 < k: - max_range = k_pivot_2 - - num_iters += 1 - if num_iters >= 32 or tl.abs(min_range - max_range) < 1e-16: - k_pivot = k_pivot_0 - - # Third pass: Apply top-k mask - if k_pivot != -float("inf"): - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + 
tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n) - mask = (logits_blk > k_pivot) & mask_n - logits_blk = tl.where(mask, logits_blk, -float("inf")) - tl.store(OUTPUT_ROW + offs_n, logits_blk, mask=mask_n) - - -def apply_top_k_only_triton( - logits: torch.Tensor, - k: torch.Tensor, -) -> torch.Tensor: - """ - Apply top-k mask to the logits using Triton. - - The logits tensor will be updated out-of-place. - """ - if k is None: - return logits - - batch_size, vocab_size = logits.shape - NUM_PROGRAMS = batch_size # Non-persistent kernel - BLOCK_SIZE = 8192 - NUM_WARPS = 16 - NUM_STAGES = 3 - output = torch.full(logits.shape, -float("inf"), device=logits.device) - PERCENTILE_TO_STD_TABLE = torch.tensor( - _PERCENTILE_TO_STD_TABLE, device=logits.device - ) - - _topk_triton_kernel[(NUM_PROGRAMS,)]( - logits, - output, - PERCENTILE_TO_STD_TABLE, - k, - vocab_size, - BLOCK_SIZE, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - ) - - return output - - -@triton.jit -def top_k_top_p_filter( - LOGITS, - DO_TOP_K, - K, - P_FIL, - BUFFER, - BATCH_SIZE, - SUM_EXCLUDED_PROBS, - FILTERED_LOGITS, - FILTERED_INDICES, - FILTERED_PROBS, - PERCENTILE_TO_STD_TABLE, - VOCAB_SIZE: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE - pid = tl.program_id(0) - num_programs = tl.num_programs(0) - - for row_id in tl.range(pid, BATCH_SIZE, num_programs): - LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE - BUFFER_ROW = BUFFER + pid * VOCAB_SIZE - - search_addr = LOGITS_ROW - search_range = VOCAB_SIZE - search_iters = NUM_TILES - - max_logit = -float("inf") - min_logit = float("inf") - - # Zeroth pass: Compute avg and std from a sample block - offs = tl.arange(0, BLOCK_SIZE) - mask_n = offs < VOCAB_SIZE - num_mask = tl.sum(mask_n) - logits_blk = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) - avg_logit = tl.sum(logits_blk) / num_mask - sq_avg_logit = tl.sum(logits_blk * 
logits_blk) / num_mask - std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - - percentile = tl.cast(P_FIL * 1.6 / VOCAB_SIZE * 100 + 1, tl.uint32) - percentile = tl.minimum(percentile, 99) - sigma = tl.load(PERCENTILE_TO_STD_TABLE + percentile) - outlier_pivot = avg_logit + sigma * std_logit - num_outliers = tl.zeros((), dtype=tl.uint32) - - # First pass: compute max and min logits and gather outliers - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - logits_blk = tl.load(search_addr + offs_n, mask=mask_n, other=avg_logit) - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) - - max_logit = tl.maximum(max_logit, tl.max(logits_blk)) - min_logit = tl.minimum(min_logit, tl.min(logits_blk)) - - outlier_mask = (logits_blk > outlier_pivot) & mask_n - num_blk_outliers = tl.sum(outlier_mask) - cumulative_pos = tl.cast( - tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32 - ) - num_outliers += num_blk_outliers - - write_idx = tl.where(outlier_mask, cumulative_pos, -1) - tl.store(BUFFER_ROW + write_idx, logits_blk, mask=outlier_mask) - - k_max_range = max_logit - k_min_range = min_logit - p_fil_max_range = max_logit - p_fil_min_range = min_logit - - if num_outliers > P_FIL: - search_addr = BUFFER_ROW - search_range = tl.cast(num_outliers, tl.int32) - search_iters = tl.cast( - (num_outliers + BLOCK_SIZE - 1) // BLOCK_SIZE, tl.int32 - ) - k_min_range = outlier_pivot - p_fil_min_range = outlier_pivot - - k = tl.load(K + row_id) - - # Second passes: Quaternary search for pivots (nlog_4(n)) - num_iters = 0 - k_pivot = float("inf") - p_fil_pivot = float("inf") - # For duplicate pivot detection - min_larger_p_fil_pivot = float("inf") - num_duplicate_to_remove = tl.zeros((), dtype=tl.uint32) - do_deduplicate = tl.zeros((), dtype=tl.int32) - min_larger_p_fil_pivot_2 = float("inf") - num_min_larger_p_fil_pivot_2 = tl.zeros((), dtype=tl.uint32) - min_larger_p_fil_pivot_1 = float("inf") - 
num_min_larger_p_fil_pivot_1 = tl.zeros((), dtype=tl.uint32) - min_larger_p_fil_pivot_0 = float("inf") - num_min_larger_p_fil_pivot_0 = tl.zeros((), dtype=tl.uint32) - if k == VOCAB_SIZE: - k_pivot = -float("inf") - if P_FIL == VOCAB_SIZE: - p_fil_pivot = -float("inf") - while k_pivot == float("inf") or p_fil_pivot == float("inf"): - k_pivot_0 = (k_max_range - k_min_range) * 1.0 / 4.0 + k_min_range - k_pivot_1 = (k_max_range - k_min_range) * 2.0 / 4.0 + k_min_range - k_pivot_2 = (k_max_range - k_min_range) * 3.0 / 4.0 + k_min_range - k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) - k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) - k_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - - p_fil_pivot_0 = ( - p_fil_max_range - p_fil_min_range - ) * 1.0 / 4.0 + p_fil_min_range - p_fil_pivot_1 = ( - p_fil_max_range - p_fil_min_range - ) * 2.0 / 4.0 + p_fil_min_range - p_fil_pivot_2 = ( - p_fil_max_range - p_fil_min_range - ) * 3.0 / 4.0 + p_fil_min_range - p_fil_pivots_num_0 = tl.zeros((), dtype=tl.uint32) - p_fil_pivots_num_1 = tl.zeros((), dtype=tl.uint32) - p_fil_pivots_num_2 = tl.zeros((), dtype=tl.uint32) - - if p_fil_pivot == float("inf"): - min_larger_p_fil_pivot_2 = float("inf") - num_min_larger_p_fil_pivot_2 = tl.zeros((), dtype=tl.uint32) - min_larger_p_fil_pivot_1 = float("inf") - num_min_larger_p_fil_pivot_1 = tl.zeros((), dtype=tl.uint32) - min_larger_p_fil_pivot_0 = float("inf") - num_min_larger_p_fil_pivot_0 = tl.zeros((), dtype=tl.uint32) - - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - logits_blk = tl.load( - search_addr + offs_n, mask=mask_n, other=-float("inf") - ) - - k_pivots_num_0 += tl.sum(logits_blk > k_pivot_0) - k_pivots_num_1 += tl.sum(logits_blk > k_pivot_1) - k_pivots_num_2 += tl.sum(logits_blk > k_pivot_2) - - p_fil_pivots_num_0 += tl.sum(logits_blk > p_fil_pivot_0) - p_fil_pivots_num_1 += tl.sum(logits_blk > p_fil_pivot_1) - p_fil_pivots_num_2 += tl.sum(logits_blk > 
p_fil_pivot_2) - - if p_fil_pivot == float("inf"): - larger_p_fil_pivot = tl.where( - (logits_blk > p_fil_pivot_2) & mask_n, logits_blk, float("inf") - ) - min_larger_p_fil_pivot_2 = tl.minimum( - min_larger_p_fil_pivot_2, tl.min(larger_p_fil_pivot) - ) - - larger_p_fil_pivot = tl.where( - (logits_blk > p_fil_pivot_1) & mask_n, logits_blk, float("inf") - ) - min_larger_p_fil_pivot_1 = tl.minimum( - min_larger_p_fil_pivot_1, tl.min(larger_p_fil_pivot) - ) - - larger_p_fil_pivot = tl.where( - (logits_blk > p_fil_pivot_0) & mask_n, logits_blk, float("inf") - ) - min_larger_p_fil_pivot_0 = tl.minimum( - min_larger_p_fil_pivot_0, tl.min(larger_p_fil_pivot) - ) - - if p_fil_pivot == float("inf"): - for i in range(0, search_iters): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < search_range - logits_blk = tl.load( - search_addr + offs_n, mask=mask_n, other=-float("inf") - ) - min_larger_p_fil_pivot_mask = ( - tl.abs(logits_blk - min_larger_p_fil_pivot_2) < 1e-12 - ) & mask_n - num_min_larger_p_fil_pivot_2 += tl.sum(min_larger_p_fil_pivot_mask) - - min_larger_p_fil_pivot_mask = ( - tl.abs(logits_blk - min_larger_p_fil_pivot_1) < 1e-12 - ) & mask_n - num_min_larger_p_fil_pivot_1 += tl.sum(min_larger_p_fil_pivot_mask) - - min_larger_p_fil_pivot_mask = ( - tl.abs(logits_blk - min_larger_p_fil_pivot_0) < 1e-12 - ) & mask_n - num_min_larger_p_fil_pivot_0 += tl.sum(min_larger_p_fil_pivot_mask) - - # Check if any of the pivots are equal to k - if k_pivot == float("inf"): - if k_pivots_num_0 == k: - k_pivot = k_pivot_0 - elif k_pivots_num_1 == k: - k_pivot = k_pivot_1 - elif k_pivots_num_2 == k: - k_pivot = k_pivot_2 - # If none of the pivots are equal to k, we update the range - elif k_pivots_num_2 > k: - k_min_range = k_pivot_2 - elif k_pivots_num_1 > k: - k_min_range = k_pivot_1 - elif k_pivots_num_0 > k: - k_min_range = k_pivot_0 - if k_pivots_num_0 < k: - k_max_range = k_pivot_0 - elif k_pivots_num_1 < k: - k_max_range = k_pivot_1 - elif 
k_pivots_num_2 < k: - k_max_range = k_pivot_2 - - # Check if any of the pivots are equal to P_FIL - if p_fil_pivot == float("inf"): - if p_fil_pivots_num_0 == P_FIL: - p_fil_pivot = p_fil_pivot_0 - elif p_fil_pivots_num_1 == P_FIL: - p_fil_pivot = p_fil_pivot_1 - elif p_fil_pivots_num_2 == P_FIL: - p_fil_pivot = p_fil_pivot_2 - # If none of the pivots are equal to P_FIL, we update the range - elif p_fil_pivots_num_2 > P_FIL: - if p_fil_pivots_num_2 - num_min_larger_p_fil_pivot_2 < P_FIL: - # Duplicate pivot detected - p_fil_pivot = p_fil_pivot_2 - # Number of duplicate pivots to keep in the filtered set - num_duplicate_to_remove = p_fil_pivots_num_2 - P_FIL - min_larger_p_fil_pivot = min_larger_p_fil_pivot_2 - do_deduplicate = 1 - p_fil_min_range = p_fil_pivot_2 - elif p_fil_pivots_num_1 > P_FIL: - p_fil_min_range = p_fil_pivot_1 - if p_fil_pivots_num_1 - num_min_larger_p_fil_pivot_1 < P_FIL: - # Duplicate pivot detected - p_fil_pivot = p_fil_pivot_1 - # Number of duplicate pivots to keep in the filtered set - num_duplicate_to_remove = p_fil_pivots_num_1 - P_FIL - min_larger_p_fil_pivot = min_larger_p_fil_pivot_1 - do_deduplicate = 1 - elif p_fil_pivots_num_0 > P_FIL: - p_fil_min_range = p_fil_pivot_0 - if p_fil_pivots_num_0 - num_min_larger_p_fil_pivot_0 < P_FIL: - # Duplicate pivot detected - p_fil_pivot = p_fil_pivot_0 - # Number of duplicate pivots to keep in the filtered set - num_duplicate_to_remove = p_fil_pivots_num_0 - P_FIL - min_larger_p_fil_pivot = min_larger_p_fil_pivot_0 - do_deduplicate = 1 - if p_fil_pivots_num_0 < P_FIL: - p_fil_max_range = p_fil_pivot_0 - elif p_fil_pivots_num_1 < P_FIL: - p_fil_max_range = p_fil_pivot_1 - elif p_fil_pivots_num_2 < P_FIL: - p_fil_max_range = p_fil_pivot_2 - - num_iters += 1 - if num_iters >= 32 or ( - (tl.abs(k_min_range - k_max_range) < 1e-16 and k_pivot != float("inf")) - and ( - tl.abs(p_fil_min_range - p_fil_max_range) < 1e-16 - and p_fil_pivot != float("inf") - ) - ): - if k_pivot == float("inf"): - k_pivot = 
k_pivot_0 - if p_fil_pivot == float("inf"): - p_fil_pivot = p_fil_pivot_0 - - # Third pass: Calculate exp logits and sum with top-k mask - if not DO_TOP_K or k == VOCAB_SIZE: - k_pivot = -float("inf") - - sum_exp_logits = tl.zeros((), dtype=tl.float32) - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) - - top_k_mask = logits_blk > k_pivot - logits_blk = tl.where(top_k_mask, logits_blk, -float("inf")) - - probs_blk = logits_blk - max_logit - probs_blk = tl.exp(probs_blk) - sum_exp_logits += tl.sum(probs_blk) - tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) - - # Fourth pass: Calculate softmax - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) - probs_blk = probs_blk / sum_exp_logits - tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) - - # Fifth pass : Gather filtered values with top-k mask - write_pos = tl.zeros((), dtype=tl.int32) - sum_excluded_probs = tl.zeros((), dtype=tl.float32) - num_duplicates_removed = tl.zeros((), dtype=tl.uint32) - FILTERED_LOGITS_ROW = FILTERED_LOGITS + row_id * P_FIL - FILTERED_INDICES_ROW = FILTERED_INDICES + row_id * P_FIL - FILTERED_PROBS_ROW = FILTERED_PROBS + row_id * P_FIL - for i in range(0, NUM_TILES): - offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask_n = offs_n < VOCAB_SIZE - logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) - probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) - - keep_mask = logits_blk > p_fil_pivot - if do_deduplicate > 0: - duplicate_mask = ( - tl.abs(logits_blk - min_larger_p_fil_pivot) < 1e-12 - ) & mask_n - - duplicate_count = tl.cumsum(duplicate_mask) + num_duplicates_removed - duplicate_remove_mask = duplicate_mask & ( - duplicate_count <= num_duplicate_to_remove - ) - 
keep_mask = keep_mask & (~duplicate_remove_mask) - num_duplicates_removed += tl.sum(duplicate_remove_mask) - - keep_mask = keep_mask & mask_n - cpos = tl.cumsum(keep_mask) - 1 + write_pos - f_mask = keep_mask - write_idx = tl.where(f_mask, cpos, P_FIL) - - top_k_mask = (logits_blk > k_pivot) & mask_n - logits_blk = tl.where(top_k_mask, logits_blk, -float("inf")) - - # Gather filtered values - tl.store(FILTERED_LOGITS_ROW + write_idx, logits_blk, mask=f_mask) - tl.store(FILTERED_INDICES_ROW + write_idx, offs_n, mask=f_mask) - tl.store(FILTERED_PROBS_ROW + write_idx, probs_blk, mask=f_mask) - - sum_excluded_probs += tl.sum(probs_blk * (keep_mask & (~f_mask) & mask_n)) - write_pos += tl.sum(f_mask, dtype=tl.int32) - tl.store(SUM_EXCLUDED_PROBS + row_id, sum_excluded_probs) - - -def apply_top_k_top_p_filtered( - logits: torch.Tensor, - k: torch.Tensor, - p: torch.Tensor, -) -> torch.Tensor: - """ - Applies top p using pivot based filtering - """ - batch_size, vocab_size = logits.shape - - # If k is too large, speedup is not significant as the filtered set is large. - max_k = k.max().item() if k is not None else 0 - - # Our softmax result is different from the original PyTorch top-p implementation - # which runs softmax after a sort compared to our softmax result which runs - # softmax on the original unsorted logits, yielding different sum(exp(logits)) - # values due to the non-associativity of floating-points. - # If p is too large, the top-p cutoff falls in the tail section of the distribution, - # which consists of very small probabilities which has larger relative errors - # compared to the original PyTorch top-p probabilities. As such, we fallback to - # the original PyTorch top-p implementation for accuracy when p is too large. 
- if max_k > vocab_size / 4 or (k is None and p.max().item() > 0.99): - return apply_top_k_top_p(logits, k, p) - - BLOCK_SIZE = 8192 - device_prop = torch.cuda.get_device_properties(logits.device) - NUM_PROGRAMS = device_prop.multi_processor_count # Persistent kernel - NUM_WARPS = 16 - NUM_STAGES = 3 - buffer = torch.empty( - (NUM_PROGRAMS, vocab_size), device=logits.device, dtype=torch.float32 - ) - p_filter = ( - min(int(max_k * 1.5), vocab_size - 1) if k is not None else int(vocab_size / 32) - ) - filtered_logits = torch.full( - (batch_size, p_filter), -float("inf"), device=logits.device - ) - filtered_indices = torch.full( - (batch_size, p_filter), p_filter, dtype=torch.int64, device=logits.device - ) - filtered_probs = torch.full((batch_size, p_filter), 0.0, device=logits.device) - sum_excluded_probs = torch.zeros( - (batch_size,), device=logits.device, dtype=torch.float32 - ) - PERCENTILE_TO_STD_TABLE = torch.tensor( - _PERCENTILE_TO_STD_TABLE, device=logits.device - ) - - top_k_top_p_filter[(NUM_PROGRAMS,)]( - logits, - (k is not None), - k if k is not None else filtered_indices, - p_filter, - buffer, - batch_size, - sum_excluded_probs, - filtered_logits, - filtered_indices, - filtered_probs, - PERCENTILE_TO_STD_TABLE, - VOCAB_SIZE=vocab_size, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - ) - - if torch.any(sum_excluded_probs >= p): - return apply_top_k_top_p(logits, k, p) - - logits_sort, sort_indices = filtered_logits.sort(dim=-1, descending=False) - logits_sort_indices = torch.gather(filtered_indices, -1, sort_indices) - sorted_probs = torch.gather(filtered_probs, -1, sort_indices) - - sorted_probs[:, 0] = sorted_probs[:, 0] + sum_excluded_probs - probs_sum = torch.cumsum(sorted_probs, dim=-1) - top_p_mask = probs_sum <= (1 - p.unsqueeze(dim=-1)) - top_p_mask[:, -1] = False - logits_sort.masked_fill_(top_p_mask, -float("inf")) - - logits.fill_(-float("inf")) - logits.scatter_(dim=1, index=logits_sort_indices, src=logits_sort) 
- return logits From 37f322a62311f8a1a0241e8e4b3996f5c64c233d Mon Sep 17 00:00:00 2001 From: Jongseok Park <37990712+cakeng@users.noreply.github.com> Date: Sun, 1 Feb 2026 20:56:01 -0800 Subject: [PATCH 89/99] Update vllm/v1/sample/ops/topk_topp_triton.py Bugfix Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Jongseok Park <37990712+cakeng@users.noreply.github.com> --- vllm/v1/sample/ops/topk_topp_triton.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_triton.py b/vllm/v1/sample/ops/topk_topp_triton.py index cb807f4489bf..3ce720ae8126 100644 --- a/vllm/v1/sample/ops/topk_topp_triton.py +++ b/vllm/v1/sample/ops/topk_topp_triton.py @@ -201,7 +201,7 @@ def _topk_topp_kernel( # Check if any of the pivots satisfy termination condition if ( k_pivots_num_0 >= k - and k_pivots_num_0 - (min_larger_0 * num_min_larger_0) < k + and k_pivots_num_0 - num_min_larger_0 < k ): k_pivot = k_pivot_0 k_pivots_num = k_pivots_num_0 @@ -209,7 +209,7 @@ def _topk_topp_kernel( num_min_larger = num_min_larger_0 if ( k_pivots_num_1 >= k - and k_pivots_num_1 - (min_larger_1 * num_min_larger_1) < k + and k_pivots_num_1 - num_min_larger_1 < k ): k_pivot = k_pivot_1 k_pivots_num = k_pivots_num_1 From cb731c57c5e42ba9e4c6f621d409781d8bf1db99 Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 1 Feb 2026 21:00:24 -0800 Subject: [PATCH 90/99] Bugfix Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_triton.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_triton.py b/vllm/v1/sample/ops/topk_topp_triton.py index 3ce720ae8126..4a14baf9b235 100644 --- a/vllm/v1/sample/ops/topk_topp_triton.py +++ b/vllm/v1/sample/ops/topk_topp_triton.py @@ -276,7 +276,7 @@ def _topk_topp_kernel( # Check if any of the pivots satisfy termination condition if ( k_pivots_num_0 >= k - and k_pivots_num_0 - (min_larger_0 * num_min_larger_0) < k + and 
k_pivots_num_0 - num_min_larger_0 < k ): k_pivot = k_pivot_0 k_pivots_num = k_pivots_num_0 @@ -284,7 +284,7 @@ def _topk_topp_kernel( num_min_larger = num_min_larger_0 if ( k_pivots_num_1 >= k - and k_pivots_num_1 - (min_larger_1 * num_min_larger_1) < k + and k_pivots_num_1 - num_min_larger_1 < k ): k_pivot = k_pivot_1 k_pivots_num = k_pivots_num_1 From 0c61b95c233d37cb25c7e08ce19e54a4758a8c17 Mon Sep 17 00:00:00 2001 From: Jongseok Park <37990712+cakeng@users.noreply.github.com> Date: Mon, 2 Feb 2026 15:20:09 -0800 Subject: [PATCH 91/99] Refactor comments for clarity in topk_topp_triton.py Signed-off-by: Jongseok Park <37990712+cakeng@users.noreply.github.com> --- vllm/v1/sample/ops/topk_topp_triton.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_triton.py b/vllm/v1/sample/ops/topk_topp_triton.py index 4a14baf9b235..4ea4d85ad35b 100644 --- a/vllm/v1/sample/ops/topk_topp_triton.py +++ b/vllm/v1/sample/ops/topk_topp_triton.py @@ -140,7 +140,7 @@ def _topk_topp_kernel( write_pos = tl.where(outlier_mask, cumulative_pos, -1) tl.store(BUFFER_ROW + write_pos, logits_blk, mask=outlier_mask) - # Second passes: Quaternary search for pivots (nlog_4(batch_size)) + # Second passes: Ternary search for pivots num_iters = 0 k_pivot = float("inf") k_pivots_num = tl.zeros((), dtype=tl.uint32) @@ -226,6 +226,7 @@ def _topk_topp_kernel( max_range = k_pivot_0 elif k_pivots_num_1 < k: max_range = k_pivot_1 + num_iters += 1 if num_iters >= 18 or tl.abs(min_range - max_range) < 1e-9: k_pivot = (max_range + min_range) / 2.0 @@ -343,7 +344,7 @@ def _topk_topp_kernel( outlier_mask = (probs_blk > min_logit) & mask_n_2 - # Duplicate logit handling + # Duplicate logit handling for Top-k if num_keep < num_duplicate_logit: duplicate_mask = ( tl.abs(probs_blk - duplicate_logit) < 1e-9 @@ -402,7 +403,7 @@ def _topk_topp_kernel( outlier_mask = (probs_blk > min_logit) & mask_n - # Duplicate logit handling + # Duplicate logit handling 
for Top-k duplicate_mask = ( tl.abs(probs_blk - duplicate_logit) < 1e-9 ) @@ -563,6 +564,7 @@ def _topk_topp_kernel( # Top-k + Top-p path final_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit + if TOPP_ENABLED and final_pivot == -float("inf"): #### STANDALONE TOP-P SAMPLING #### p = tl.load(P + row_id) @@ -841,7 +843,7 @@ def _topk_topp_kernel( # Top-p only path final_pivot = tl.log(p_pivot * sum_exp_logits) + max_sample - # Sixth pass: Apply mask + # Sixth pass: Apply mask and store final output if final_pivot != -float("inf"): for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) @@ -919,6 +921,7 @@ def apply_top_k_top_p_triton( num_sm = torch.cuda.get_device_properties(logits.device).multi_processor_count NUM_PROGRAMS = min(num_sm, batch_size) + # Cache per-Triton Program buffer on each device. buf_key = (logits.device, logits.dtype, NUM_PROGRAMS, vocab_size) buffer = _TRITON_BUFFER_CACHE.get(buf_key) if buffer is None or buffer.numel() < NUM_PROGRAMS * vocab_size: @@ -927,7 +930,7 @@ def apply_top_k_top_p_triton( ) _TRITON_BUFFER_CACHE[buf_key] = buffer - # Cache percentile table per device. + # Cache lookup table entries on each device. 
tbl_key = (logits.device, torch.float32) tables = _TRITON_TABLE_CACHE.get(tbl_key) if tables is None: From 503f0b0fcbceda1ef91fc7a7d0feb6e3b5b8dc01 Mon Sep 17 00:00:00 2001 From: js_park Date: Mon, 2 Feb 2026 15:38:37 -0800 Subject: [PATCH 92/99] Pre-commit fix Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_triton.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_triton.py b/vllm/v1/sample/ops/topk_topp_triton.py index 4ea4d85ad35b..59fe56c81db9 100644 --- a/vllm/v1/sample/ops/topk_topp_triton.py +++ b/vllm/v1/sample/ops/topk_topp_triton.py @@ -226,7 +226,7 @@ def _topk_topp_kernel( max_range = k_pivot_0 elif k_pivots_num_1 < k: max_range = k_pivot_1 - + num_iters += 1 if num_iters >= 18 or tl.abs(min_range - max_range) < 1e-9: k_pivot = (max_range + min_range) / 2.0 @@ -564,7 +564,7 @@ def _topk_topp_kernel( # Top-k + Top-p path final_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit - + if TOPP_ENABLED and final_pivot == -float("inf"): #### STANDALONE TOP-P SAMPLING #### p = tl.load(P + row_id) From 576f90ebb53bfd4feb2725b054921c27b59a69fc Mon Sep 17 00:00:00 2001 From: js_park Date: Sun, 8 Feb 2026 06:05:33 -0800 Subject: [PATCH 93/99] Update arxiv Signed-off-by: js_park --- vllm/v1/sample/ops/topk_topp_triton.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/sample/ops/topk_topp_triton.py b/vllm/v1/sample/ops/topk_topp_triton.py index 59fe56c81db9..a60da2887dbe 100644 --- a/vllm/v1/sample/ops/topk_topp_triton.py +++ b/vllm/v1/sample/ops/topk_topp_triton.py @@ -5,6 +5,7 @@ Based on the paper "Qrita: High-performance Top-k and Top-p Algorithm for GPUs using Pivot-based Truncation and Selection" By Park et al. 
+(https://arxiv.org/abs/2602.01518) """ From c18fe71eff9d1ef2a18102b374b0057f8914bcae Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 12 Feb 2026 15:45:31 -0800 Subject: [PATCH 94/99] adjust prob distribution in benchmark, adjust threshold Signed-off-by: Nick Hill --- benchmarks/benchmark_topk_topp.py | 27 +++++++++++++++++++++---- vllm/v1/sample/ops/topk_topp_sampler.py | 13 ++++-------- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/benchmarks/benchmark_topk_topp.py b/benchmarks/benchmark_topk_topp.py index 93a4bd485316..cac332a099d8 100644 --- a/benchmarks/benchmark_topk_topp.py +++ b/benchmarks/benchmark_topk_topp.py @@ -69,8 +69,27 @@ def calculate_ops_pct( def create_logits( batch_size: int, vocab_size: int, device: str = "cuda" ) -> torch.Tensor: - """Create random logits tensor.""" - return torch.randn(batch_size, vocab_size, dtype=torch.float32, device=device) + """Create random logits mimicking a realistic LLM distribution. + + Uses a Zipf-like probability distribution (rank^-1.1) converted to logits + via log, then randomly permuted per row. This produces a peaked distribution + where a small number of tokens capture most probability mass, similar to + real model outputs. 
+ """ + # Create Zipf-like probabilities: p(rank) ~ rank^(-alpha) + ranks = torch.arange(1, vocab_size + 1, dtype=torch.float32, device=device) + probs = ranks.pow(-1.1) + probs = probs / probs.sum() + + # Convert to logits (log-probabilities, unnormalized is fine) + base_logits = probs.log() + + # Broadcast to batch and randomly permute each row + logits = base_logits.unsqueeze(0).expand(batch_size, -1).clone() + for i in range(batch_size): + logits[i] = logits[i, torch.randperm(vocab_size, device=device)] + + return logits def measure_memory() -> tuple[int, int]: @@ -234,7 +253,7 @@ def create_benchmark_configs( third = batch_size // 3 k_mixed[:third] = 50 # Second third: p only - p_mixed[third : 2 * third] = 0.9 + p_mixed[third : 2 * third] = 0.5 # Last third: both k and p k_mixed[2 * third :] = 100 p_mixed[2 * third :] = 0.9 @@ -384,7 +403,7 @@ def main(): "--batch-sizes", type=int, nargs="+", - default=[1, 4, 16, 24, 32, 48, 56, 64, 96, 128, 192, 256, 512, 1024], + default=[1, 4, 16, 64, 128, 512, 1024, 2048], help="Batch sizes to test (default: 1 4 16 64)", ) parser.add_argument( diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index eaf6dcc3c520..33f7090e4e3d 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -248,15 +248,10 @@ def apply_top_k_top_p( if p is None and k is None: return logits - # Rough empirical heuristic - if HAS_TRITON: - batch_size, vocab_size = logits.shape - both_k_and_p = p is not None and k is not None - threshold = vocab_size // (1024 if both_k_and_p else 2048) - if batch_size >= threshold: - # Use pytorch sort implementation for smaller batch sizes. - return apply_top_k_top_p_triton(logits, k, p) + if HAS_TRITON and logits.shape[0] >= 8: + return apply_top_k_top_p_triton(logits, k, p) + # Use pytorch sort implementation for small batch sizes. 
return apply_top_k_top_p_pytorch(logits, k, p) @@ -301,7 +296,7 @@ def apply_top_k_top_p_pytorch( logits_sort.masked_fill_(top_p_mask, -float("inf")) # Re-sort the probabilities. - return logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) + return logits.scatter_(dim=-1, index=logits_idx, src=logits_sort) def apply_top_k_only(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor: From c246c3a75a37fcd947d99e6e0ab40fba03a1828e Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 12 Feb 2026 15:50:30 -0800 Subject: [PATCH 95/99] some simplification/cleanup Signed-off-by: Nick Hill --- vllm/utils/math_utils.py | 8 ++---- vllm/v1/sample/ops/topk_topp_triton.py | 36 ++++++++++---------------- 2 files changed, 15 insertions(+), 29 deletions(-) diff --git a/vllm/utils/math_utils.py b/vllm/utils/math_utils.py index 5fc6c3d664f0..a0e301af471f 100644 --- a/vllm/utils/math_utils.py +++ b/vllm/utils/math_utils.py @@ -14,16 +14,12 @@ def cdiv(a: int, b: int) -> int: def next_power_of_2(n: int) -> int: """The next power of 2 (inclusive)""" - if n < 1: - return 1 - return 1 << (n - 1).bit_length() + return 1 if n < 1 else 1 << (n - 1).bit_length() def prev_power_of_2(n: int) -> int: """The previous power of 2 (inclusive)""" - if n <= 0: - return 0 - return 1 << (n.bit_length() - 1) + return 0 if n <= 0 else 1 << (n.bit_length() - 1) def round_up(x: int, y: int) -> int: diff --git a/vllm/v1/sample/ops/topk_topp_triton.py b/vllm/v1/sample/ops/topk_topp_triton.py index a60da2887dbe..d225ef618fe3 100644 --- a/vllm/v1/sample/ops/topk_topp_triton.py +++ b/vllm/v1/sample/ops/topk_topp_triton.py @@ -12,13 +12,10 @@ import torch from vllm.triton_utils import tl, triton +from vllm.utils.math_utils import next_power_of_2 -_TRITON_TABLE_CACHE: dict[ - tuple[torch.device, torch.dtype], tuple[torch.Tensor, torch.Tensor] -] = {} -_TRITON_BUFFER_CACHE: dict[ - tuple[torch.device, torch.dtype, int, int], torch.Tensor -] = {} +_TRITON_TABLE_CACHE: dict[tuple[torch.device], 
tuple[torch.Tensor, torch.Tensor]] = {} +_TRITON_BUFFER_CACHE: dict[tuple[torch.device, torch.dtype, int], torch.Tensor] = {} # fmt: off _NORMAL_CDF_TO_SIGMA_TABLE = [ @@ -384,7 +381,6 @@ def _topk_topp_kernel( other=-float("inf"), ) - outlier_mask = (probs_blk > min_logit) & mask_n_2 probs_blk = probs_blk - max_logit probs_blk = tl.exp(probs_blk) probs_blk = probs_blk / sum_exp_logits @@ -898,8 +894,6 @@ def apply_top_k_top_p_triton( assert logits.is_cuda batch_size, vocab_size = logits.shape - num_sm = torch.cuda.get_device_properties(logits.device).multi_processor_count - NUM_PROGRAMS = min(num_sm, batch_size) topk_enabled = k is not None topp_enabled = p is not None @@ -923,25 +917,21 @@ def apply_top_k_top_p_triton( NUM_PROGRAMS = min(num_sm, batch_size) # Cache per-Triton Program buffer on each device. - buf_key = (logits.device, logits.dtype, NUM_PROGRAMS, vocab_size) + buf_key = (logits.device, logits.dtype, vocab_size) buffer = _TRITON_BUFFER_CACHE.get(buf_key) - if buffer is None or buffer.numel() < NUM_PROGRAMS * vocab_size: - buffer = torch.empty( - (NUM_PROGRAMS, vocab_size), device=logits.device, dtype=logits.dtype - ) + if buffer is None or buffer.shape[0] < NUM_PROGRAMS: + size = min(next_power_of_2(NUM_PROGRAMS), num_sm) + buffer = logits.new_empty((size, vocab_size)) _TRITON_BUFFER_CACHE[buf_key] = buffer + if NUM_PROGRAMS < buffer.shape[0]: + buffer = buffer[:NUM_PROGRAMS] # Cache lookup table entries on each device. 
- tbl_key = (logits.device, torch.float32) - tables = _TRITON_TABLE_CACHE.get(tbl_key) + tables = _TRITON_TABLE_CACHE.get(logits.device) if tables is None: - normal_cdf_to_sigma_table = torch.tensor( - _NORMAL_CDF_TO_SIGMA_TABLE, device=logits.device, dtype=torch.float32 - ) - percentile_to_std_table = torch.tensor( - _PERCENTILE_TO_STD_TABLE, device=logits.device, dtype=torch.float32 - ) - _TRITON_TABLE_CACHE[tbl_key] = ( + normal_cdf_to_sigma_table = logits.new_tensor(_NORMAL_CDF_TO_SIGMA_TABLE) + percentile_to_std_table = logits.new_tensor(_PERCENTILE_TO_STD_TABLE) + _TRITON_TABLE_CACHE[logits.device] = ( normal_cdf_to_sigma_table, percentile_to_std_table, ) From 4360e923845031dab2aa4f71f22ef8b329e2061c Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 12 Feb 2026 18:06:18 -0800 Subject: [PATCH 96/99] fix precommit Signed-off-by: Nick Hill --- vllm/v1/sample/ops/topk_topp_triton.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/sample/ops/topk_topp_triton.py b/vllm/v1/sample/ops/topk_topp_triton.py index d225ef618fe3..0707e424807b 100644 --- a/vllm/v1/sample/ops/topk_topp_triton.py +++ b/vllm/v1/sample/ops/topk_topp_triton.py @@ -923,7 +923,7 @@ def apply_top_k_top_p_triton( size = min(next_power_of_2(NUM_PROGRAMS), num_sm) buffer = logits.new_empty((size, vocab_size)) _TRITON_BUFFER_CACHE[buf_key] = buffer - if NUM_PROGRAMS < buffer.shape[0]: + if buffer.shape[0] > NUM_PROGRAMS: buffer = buffer[:NUM_PROGRAMS] # Cache lookup table entries on each device. 
From b917a49564508b185fab0d562b6c747fbc79563a Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sun, 15 Feb 2026 16:25:00 -0800 Subject: [PATCH 97/99] fix -inf edge cases and possible infinite loop Signed-off-by: Nick Hill --- tests/v1/sample/test_topk_topp_sampler.py | 273 ++++++++++++++++++++++ vllm/v1/sample/ops/topk_topp_triton.py | 132 ++++++++--- 2 files changed, 377 insertions(+), 28 deletions(-) diff --git a/tests/v1/sample/test_topk_topp_sampler.py b/tests/v1/sample/test_topk_topp_sampler.py index fc6412949dd9..ce1e288a2418 100644 --- a/tests/v1/sample/test_topk_topp_sampler.py +++ b/tests/v1/sample/test_topk_topp_sampler.py @@ -296,3 +296,276 @@ def test_large_batch(self): p = torch.rand(batch_size, generator=self.generator) * 0.5 + 0.5 self._compare_results(logits, k, p) + + # ----------------------------------------------------------------- + # Tests for -inf logits (e.g. from grammar / structured output masks) + # ----------------------------------------------------------------- + + @pytest.mark.parametrize("inf_fraction", [0.5, 0.9, 0.99]) + def test_topk_with_neginf_logits(self, inf_fraction: float): + """Top-k with many -inf logits (simulating grammar bitmask). + + The kernel must not produce NaN when most logits are -inf, which + can happen when structured-output grammar masks are applied before + sampling. + """ + from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton + + batch_size, vocab_size = 32, 128256 + logits = torch.randn( + batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + ) + # Mask a fraction of logits to -inf. 
+ mask = ( + torch.rand(batch_size, vocab_size, generator=self.generator) < inf_fraction + ) + logits[mask] = float("-inf") + + k = torch.randint( + 1, 50, (batch_size,), generator=self.generator, dtype=torch.int32 + ) + result = apply_top_k_top_p_triton(logits.clone(), k, None) + + assert not result.isnan().any(), "NaN found in top-k result with -inf logits" + for i in range(batch_size): + kept = (result[i] > float("-inf")).sum().item() + assert kept <= k[i].item(), f"Row {i}: kept {kept} > k={k[i].item()}" + # At least one value should survive unless the row was all -inf. + finite_in = (logits[i] > float("-inf")).sum().item() + if finite_in > 0: + assert kept > 0, f"Row {i}: no tokens kept despite finite input" + + @pytest.mark.parametrize("inf_fraction", [0.5, 0.9, 0.99]) + def test_topp_with_neginf_logits(self, inf_fraction: float): + """Top-p with many -inf logits.""" + from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton + + batch_size, vocab_size = 32, 128256 + logits = torch.randn( + batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + ) + mask = ( + torch.rand(batch_size, vocab_size, generator=self.generator) < inf_fraction + ) + logits[mask] = float("-inf") + + p = ( + torch.rand(batch_size, generator=self.generator, dtype=torch.float32) * 0.9 + + 0.1 + ) + result = apply_top_k_top_p_triton(logits.clone(), None, p) + + assert not result.isnan().any(), "NaN found in top-p result with -inf logits" + for i in range(batch_size): + finite_in = (logits[i] > float("-inf")).sum().item() + kept = (result[i] > float("-inf")).sum().item() + if finite_in > 0: + assert kept > 0, f"Row {i}: no tokens kept despite finite input" + + @pytest.mark.parametrize("inf_fraction", [0.5, 0.9, 0.99]) + def test_topk_topp_with_neginf_logits(self, inf_fraction: float): + """Combined top-k + top-p with many -inf logits.""" + from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton + + batch_size, vocab_size = 32, 128256 + logits 
= torch.randn( + batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + ) + mask = ( + torch.rand(batch_size, vocab_size, generator=self.generator) < inf_fraction + ) + logits[mask] = float("-inf") + + k = torch.randint( + 1, 50, (batch_size,), generator=self.generator, dtype=torch.int32 + ) + p = ( + torch.rand(batch_size, generator=self.generator, dtype=torch.float32) * 0.9 + + 0.1 + ) + result = apply_top_k_top_p_triton(logits.clone(), k, p) + + assert not result.isnan().any(), ( + "NaN found in top-k+top-p result with -inf logits" + ) + for i in range(batch_size): + kept = (result[i] > float("-inf")).sum().item() + assert kept <= k[i].item(), f"Row {i}: kept {kept} > k={k[i].item()}" + + def test_all_neginf_logits(self): + """All logits are -inf (fully masked). Kernel should be a no-op.""" + from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton + + batch_size, vocab_size = 16, 128256 + logits = torch.full( + (batch_size, vocab_size), float("-inf"), dtype=torch.float32 + ) + + k = torch.randint( + 1, 50, (batch_size,), generator=self.generator, dtype=torch.int32 + ) + p = torch.full((batch_size,), 0.9, dtype=torch.float32) + + # top-k only + result = apply_top_k_top_p_triton(logits.clone(), k, None) + assert not result.isnan().any(), "NaN from all-inf top-k" + assert (result == float("-inf")).all(), "Expected all -inf unchanged" + + # top-p only + result = apply_top_k_top_p_triton(logits.clone(), None, p) + assert not result.isnan().any(), "NaN from all-inf top-p" + assert (result == float("-inf")).all(), "Expected all -inf unchanged" + + # top-k + top-p + result = apply_top_k_top_p_triton(logits.clone(), k, p) + assert not result.isnan().any(), "NaN from all-inf top-k+top-p" + assert (result == float("-inf")).all(), "Expected all -inf unchanged" + + def test_few_valid_tokens_with_neginf(self): + """Only a handful of tokens are finite per row (strict grammar).""" + from vllm.v1.sample.ops.topk_topp_triton import 
apply_top_k_top_p_triton + + batch_size, vocab_size = 32, 128256 + logits = torch.full( + (batch_size, vocab_size), float("-inf"), dtype=torch.float32 + ) + # Allow only 5 random tokens per row to be finite. + for i in range(batch_size): + indices = torch.randperm(vocab_size, generator=self.generator)[:5] + logits[i, indices] = torch.randn( + 5, generator=self.generator, dtype=torch.float32 + ) + + k = torch.full((batch_size,), 50, dtype=torch.int32) + p = torch.full((batch_size,), 0.9, dtype=torch.float32) + + # top-k only (k=50 but only 5 finite → keep all 5) + result = apply_top_k_top_p_triton(logits.clone(), k, None) + assert not result.isnan().any() + for i in range(batch_size): + kept = (result[i] > float("-inf")).sum().item() + assert kept == 5, f"Row {i}: expected 5 kept, got {kept}" + + # top-k with k < num_finite + k_small = torch.full((batch_size,), 3, dtype=torch.int32) + result = apply_top_k_top_p_triton(logits.clone(), k_small, None) + assert not result.isnan().any() + for i in range(batch_size): + kept = (result[i] > float("-inf")).sum().item() + assert kept <= 3, f"Row {i}: expected <=3 kept, got {kept}" + + # top-p only + result = apply_top_k_top_p_triton(logits.clone(), None, p) + assert not result.isnan().any() + for i in range(batch_size): + kept = (result[i] > float("-inf")).sum().item() + assert kept > 0, f"Row {i}: no tokens kept" + + @pytest.mark.parametrize("num_valid", [1, 2, 5, 10, 50]) + @pytest.mark.parametrize( + "mode", + ["topk_only", "topp_only", "topk_and_topp"], + ) + def test_equal_logits_few_valid(self, num_valid: int, mode: str): + """Few valid tokens all sharing the same logit value. + + This is the pattern produced by grammar bitmask filtering when + the model assigns similar scores to the few allowed tokens. + The ternary search can converge to a pivot equal to max_logit, + causing the strict `>` keep_mask to exclude everything. + Regression test for the `final_pivot >= max_logit` guard. 
+ """ + from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton + + batch_size, vocab_size = 32, 128256 + logits = torch.full( + (batch_size, vocab_size), float("-inf"), dtype=torch.float32 + ) + # Set exactly `num_valid` tokens per row to the SAME finite value. + for i in range(batch_size): + indices = torch.randperm(vocab_size, generator=self.generator)[:num_valid] + logits[i, indices] = 1.0 # all equal + + k: torch.Tensor | None = None + p: torch.Tensor | None = None + if mode in ("topk_only", "topk_and_topp"): + k = torch.full((batch_size,), max(1, num_valid - 1), dtype=torch.int32) + if mode in ("topp_only", "topk_and_topp"): + p = torch.full((batch_size,), 0.95, dtype=torch.float32) + + result = apply_top_k_top_p_triton(logits.clone(), k, p) + + assert not result.isnan().any(), "NaN in equal-logit result" + for i in range(batch_size): + kept = (result[i] > float("-inf")).sum().item() + # The key invariant: at least one token must survive. + # With all-equal logits the pivot search can't differentiate + # tokens, so the guard may keep more than k — that is the + # intended safe fallback. + assert kept > 0, ( + f"Row {i}: all tokens masked with {num_valid} equal-valued " + f"finite logits ({mode})" + ) + + @pytest.mark.parametrize("num_valid", [2, 5, 10]) + def test_nearly_equal_logits_topp(self, num_valid: int): + """Few valid tokens with very similar (but not identical) logits. + + Ensures the kernel handles near-degenerate probability + distributions where the ternary search range collapses. 
+ """ + from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton + + batch_size, vocab_size = 32, 128256 + logits = torch.full( + (batch_size, vocab_size), float("-inf"), dtype=torch.float32 + ) + for i in range(batch_size): + indices = torch.randperm(vocab_size, generator=self.generator)[:num_valid] + # Tiny spread: values in [1.0, 1.0 + 1e-6] + logits[i, indices] = ( + 1.0 + + torch.rand(num_valid, generator=self.generator, dtype=torch.float32) + * 1e-6 + ) + + p = torch.full((batch_size,), 0.95, dtype=torch.float32) + result = apply_top_k_top_p_triton(logits.clone(), None, p) + + assert not result.isnan().any(), "NaN in nearly-equal-logit result" + for i in range(batch_size): + kept = (result[i] > float("-inf")).sum().item() + assert kept > 0, ( + f"Row {i}: all tokens masked with {num_valid} " + f"nearly-equal finite logits" + ) + + def test_mixed_neginf_and_normal_rows(self): + """Batch with a mix of normal rows and heavily-masked rows.""" + from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton + + batch_size, vocab_size = 32, 32000 + logits = torch.randn( + batch_size, vocab_size, generator=self.generator, dtype=torch.float32 + ) + # Mask even rows heavily (99% -inf), leave odd rows normal. 
+ for i in range(0, batch_size, 2): + mask = torch.rand(vocab_size, generator=self.generator) < 0.99 + logits[i][mask] = float("-inf") + + k = torch.randint( + 1, 50, (batch_size,), generator=self.generator, dtype=torch.int32 + ) + p = ( + torch.rand(batch_size, generator=self.generator, dtype=torch.float32) * 0.9 + + 0.1 + ) + + result = apply_top_k_top_p_triton(logits.clone(), k, p) + assert not result.isnan().any(), "NaN in mixed normal/-inf batch" + for i in range(batch_size): + kept = (result[i] > float("-inf")).sum().item() + assert kept <= k[i].item() + finite_in = (logits[i] > float("-inf")).sum().item() + if finite_in > 0: + assert kept > 0, f"Row {i}: no tokens kept" diff --git a/vllm/v1/sample/ops/topk_topp_triton.py b/vllm/v1/sample/ops/topk_topp_triton.py index 0707e424807b..f776e94d6680 100644 --- a/vllm/v1/sample/ops/topk_topp_triton.py +++ b/vllm/v1/sample/ops/topk_topp_triton.py @@ -104,12 +104,25 @@ def _topk_topp_kernel( # Zeroth pass: Compute avg and std from a sample block offs = tl.arange(0, BLOCK_SIZE) mask_n = offs < VOCAB_SIZE - num_valid = tl.sum(mask_n) - logits_blk0 = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) - avg_logit = tl.sum(logits_blk0) / num_valid - sq_avg_logit = tl.sum(logits_blk0 * logits_blk0) / num_valid - std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - std_logit = tl.maximum(std_logit, 0.0) + logits_blk0 = tl.load( + LOGITS_ROW + offs, mask=mask_n, other=-float("inf") + ) + # Exclude -inf values (e.g. from grammar bitmasks) from + # statistics to avoid NaN in pivot computation. 
+ finite_mask = (logits_blk0 > -float("inf")) & mask_n + num_finite = tl.sum(finite_mask) + finite_logits = tl.where(finite_mask, logits_blk0, 0.0) + avg_logit = tl.where( + num_finite > 0, tl.sum(finite_logits) / num_finite, 0.0 + ) + sq_avg_logit = tl.where( + num_finite > 0, + tl.sum(finite_logits * finite_logits) / num_finite, + 0.0, + ) + std_logit = tl.sqrt( + tl.maximum(sq_avg_logit - avg_logit * avg_logit, 0.0) + ) # Calculate outlier pivot t for Gaussian sigma-truncation percentile = tl.cast(k / VOCAB_SIZE * 200, tl.uint32) @@ -120,15 +133,21 @@ def _topk_topp_kernel( num_outliers = tl.zeros((), dtype=tl.uint32) # First pass: compute max and min logits and gather outliers + num_finite_total = tl.zeros((), dtype=tl.uint32) for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE logits_blk = tl.load( - LOGITS_ROW + offs_n, mask=mask_n, other=avg_logit + LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") ) max_logit = tl.maximum(max_logit, tl.max(logits_blk)) - min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + # Exclude -inf from min to keep binary search bounds + # finite (avoids NaN pivots). + finite_blk_mask = logits_blk > -float("inf") + finite_blk = tl.where(finite_blk_mask, logits_blk, float("inf")) + min_logit = tl.minimum(min_logit, tl.min(finite_blk)) + num_finite_total += tl.sum(finite_blk_mask & mask_n) outlier_mask = (logits_blk > outlier_pivot) & mask_n cumulative_pos = tl.cast( @@ -138,6 +157,10 @@ def _topk_topp_kernel( write_pos = tl.where(outlier_mask, cumulative_pos, -1) tl.store(BUFFER_ROW + write_pos, logits_blk, mask=outlier_mask) + # If no finite logits exist (all -inf), clamp min to + # max so the search converges to -inf (no masking). 
+ min_logit = tl.minimum(min_logit, max_logit) + # Second passes: Ternary search for pivots num_iters = 0 k_pivot = float("inf") @@ -152,7 +175,8 @@ def _topk_topp_kernel( (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, tl.int32, ) - while k_pivot == float("inf"): + found_pivot = 0 + while found_pivot == 0: k_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) min_larger_0 = float("inf") @@ -205,6 +229,7 @@ def _topk_topp_kernel( k_pivots_num = k_pivots_num_0 min_larger = min_larger_0 num_min_larger = num_min_larger_0 + found_pivot = 1 if ( k_pivots_num_1 >= k and k_pivots_num_1 - num_min_larger_1 < k @@ -213,6 +238,7 @@ def _topk_topp_kernel( k_pivots_num = k_pivots_num_1 min_larger = min_larger_1 num_min_larger = num_min_larger_1 + found_pivot = 1 # Update range if k_pivots_num_1 > k: @@ -228,11 +254,13 @@ def _topk_topp_kernel( num_iters += 1 if num_iters >= 18 or tl.abs(min_range - max_range) < 1e-9: k_pivot = (max_range + min_range) / 2.0 + found_pivot = 1 else: # If top-k outlier gathering failed, search whole logit space max_range = max_logit min_range = min_logit - while k_pivot == float("inf"): + found_pivot = 0 + while found_pivot == 0: k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) min_larger_0 = float("inf") @@ -254,8 +282,13 @@ def _topk_topp_kernel( k_pivots_num_0 += tl.sum(logits_blk2 > k_pivot_0) k_pivots_num_1 += tl.sum(logits_blk2 > k_pivot_1) - min_larger_0 = tl.minimum(min_larger_0, tl.min(logits_blk2)) - min_larger_1 = tl.minimum(min_larger_1, tl.min(logits_blk2)) + # Exclude -inf from min_larger to avoid + # poisoning the convergence check. 
+ finite_blk2 = tl.where( + logits_blk2 > -float("inf"), logits_blk2, float("inf") + ) + min_larger_0 = tl.minimum(min_larger_0, tl.min(finite_blk2)) + min_larger_1 = tl.minimum(min_larger_1, tl.min(finite_blk2)) # Second pass: Calculate num_min_larger for i in range(0, NUM_TILES): @@ -281,6 +314,7 @@ def _topk_topp_kernel( k_pivots_num = k_pivots_num_0 min_larger = min_larger_0 num_min_larger = num_min_larger_0 + found_pivot = 1 if ( k_pivots_num_1 >= k and k_pivots_num_1 - num_min_larger_1 < k @@ -289,6 +323,7 @@ def _topk_topp_kernel( k_pivots_num = k_pivots_num_1 min_larger = min_larger_1 num_min_larger = num_min_larger_1 + found_pivot = 1 # Update range if k_pivots_num_1 > k: @@ -304,16 +339,18 @@ def _topk_topp_kernel( num_iters += 1 if num_iters >= 18 or tl.abs(min_range - max_range) < 1e-9: k_pivot = (max_range + min_range) / 2.0 + found_pivot = 1 duplicate_logit = min_larger num_duplicate_logit = num_min_larger num_keep = num_duplicate_logit - (k_pivots_num - k) num_kept = tl.zeros((), dtype=tl.uint32) - # Top-k only path - final_pivot = k_pivot + # Top-k only path. If there are fewer finite values + # than k (e.g. grammar mask), keep everything. 
+ final_pivot = k_pivot if num_finite_total > k else -float("inf") - if TOPP_ENABLED: + if TOPP_ENABLED and num_finite_total > k: #### TOP-P SAMPLING AFTER TOP-K #### p = tl.load(P + row_id) if p < 1.0: @@ -461,7 +498,8 @@ def _topk_topp_kernel( p_pivots_sum = 0.0 # Fifth passes: Search for p_pivot - while p_pivot == 1.0: + found_pivot = 0 + while found_pivot == 0: p_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range p_pivots_sum_0 = 0.0 min_larger_0 = 1.0 @@ -527,6 +565,7 @@ def _topk_topp_kernel( min_larger_prob = min_larger_1 num_min_larger = num_min_larger_1 p_pivots_sum = p_pivots_sum_1 + found_pivot = 1 if p_pivots_sum_0 >= p and ( p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p ): @@ -534,6 +573,7 @@ def _topk_topp_kernel( min_larger_prob = min_larger_0 num_min_larger = num_min_larger_0 p_pivots_sum = p_pivots_sum_0 + found_pivot = 1 # Update range if p_pivots_sum_1 > p: @@ -549,6 +589,7 @@ def _topk_topp_kernel( num_iters += 1 if (max_range - min_range) < 1e-9 or num_iters >= 18: p_pivot = (max_range + min_range) / 2.0 + found_pivot = 1 duplicate_logit = ( tl.log(min_larger_prob * sum_exp_logits) + max_logit @@ -569,12 +610,25 @@ def _topk_topp_kernel( # Zeroth pass: Compute avg and std from a sample block offs = tl.arange(0, BLOCK_SIZE) mask_n = offs < VOCAB_SIZE - num_valid = tl.sum(mask_n) - logits_blk0 = tl.load(LOGITS_ROW + offs, mask=mask_n, other=0.0) - avg_logit = tl.sum(logits_blk0) / num_valid - sq_avg_logit = tl.sum(logits_blk0 * logits_blk0) / num_valid - std_logit = tl.sqrt(sq_avg_logit - avg_logit * avg_logit) - std_logit = tl.maximum(std_logit, 0.0) + logits_blk0 = tl.load( + LOGITS_ROW + offs, mask=mask_n, other=-float("inf") + ) + # Exclude -inf values (e.g. from grammar bitmasks) from + # statistics to avoid NaN in pivot computation. 
+ finite_mask = (logits_blk0 > -float("inf")) & mask_n + num_finite = tl.sum(finite_mask) + finite_logits = tl.where(finite_mask, logits_blk0, 0.0) + avg_logit = tl.where( + num_finite > 0, tl.sum(finite_logits) / num_finite, 0.0 + ) + sq_avg_logit = tl.where( + num_finite > 0, + tl.sum(finite_logits * finite_logits) / num_finite, + 0.0, + ) + std_logit = tl.sqrt( + tl.maximum(sq_avg_logit - avg_logit * avg_logit, 0.0) + ) max_sample = avg_logit + std_logit * 10.0 sum_exp_logits = 0.0 @@ -583,15 +637,24 @@ def _topk_topp_kernel( offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE logits_blk = tl.load( - LOGITS_ROW + offs_n, mask=mask_n, other=avg_logit + LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") ) max_logit = tl.maximum(max_logit, tl.max(logits_blk)) - min_logit = tl.minimum(min_logit, tl.min(logits_blk)) + # Exclude -inf from min to keep binary search bounds + # finite (avoids NaN pivots). + finite_blk = tl.where( + logits_blk > -float("inf"), logits_blk, float("inf") + ) + min_logit = tl.minimum(min_logit, tl.min(finite_blk)) probs_blk = tl.exp(logits_blk - max_sample) probs_blk = tl.where(mask_n, probs_blk, 0.0) sum_exp_logits += tl.sum(probs_blk) + # If no finite logits exist (all -inf), clamp min to + # max so the search converges to -inf (no masking). 
+ min_logit = tl.minimum(min_logit, max_logit) + idx = tl.cast(p * 200, tl.int32) idx = tl.maximum(0, tl.minimum(idx, 199)) sigma = tl.load(NORMAL_CDF_TO_SIGMA_TABLE + idx) @@ -640,7 +703,8 @@ def _topk_topp_kernel( tl.int32, ) - while p_pivot == 1.0: + found_pivot = 0 + while found_pivot == 0: p_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range p_pivots_sum_0 = 0.0 min_larger_0 = 1.0 @@ -707,6 +771,7 @@ def _topk_topp_kernel( min_larger_prob = min_larger_1 num_min_larger = num_min_larger_1 p_pivots_sum = p_pivots_sum_1 + found_pivot = 1 if ( p_pivots_sum_0 >= p and p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p @@ -715,6 +780,7 @@ def _topk_topp_kernel( min_larger_prob = min_larger_0 num_min_larger = num_min_larger_0 p_pivots_sum = p_pivots_sum_0 + found_pivot = 1 # Update range if p_pivots_sum_1 > p: @@ -730,6 +796,7 @@ def _topk_topp_kernel( num_iters += 1 if (max_range - min_range) < 1e-9 or num_iters >= 18: p_pivot = (max_range + min_range) / 2.0 + found_pivot = 1 else: # Re-populate the buffer with full softmax probabilities for i in range(0, NUM_TILES): @@ -743,7 +810,8 @@ def _topk_topp_kernel( probs_blk = probs_blk / sum_exp_logits tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) - while p_pivot == 1.0: + found_pivot = 0 + while found_pivot == 0: p_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range p_pivots_sum_0 = 0.0 min_larger_0 = 1.0 @@ -806,6 +874,7 @@ def _topk_topp_kernel( min_larger_prob = min_larger_1 num_min_larger = num_min_larger_1 p_pivots_sum = p_pivots_sum_1 + found_pivot = 1 if ( p_pivots_sum_0 >= p and p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p @@ -814,6 +883,7 @@ def _topk_topp_kernel( min_larger_prob = min_larger_0 num_min_larger = num_min_larger_0 p_pivots_sum = p_pivots_sum_0 + found_pivot = 1 # Update range if p_pivots_sum_1 > p: @@ -829,6 +899,7 @@ def _topk_topp_kernel( num_iters += 1 if (max_range - min_range) < 1e-9 or num_iters >= 18: p_pivot = (max_range + min_range) / 2.0 + found_pivot = 
1 duplicate_logit = tl.log(min_larger_prob * sum_exp_logits) + max_logit num_duplicate_logit = num_min_larger @@ -840,8 +911,13 @@ def _topk_topp_kernel( # Top-p only path final_pivot = tl.log(p_pivot * sum_exp_logits) + max_sample - # Sixth pass: Apply mask and store final output - if final_pivot != -float("inf"): + # Sixth pass: Apply mask and store final output. + # If the pivot >= max logit (or is NaN), no token would + # survive the strict `>` keep_mask. Skip masking. + # Using `not <` instead of `>=` so that NaN is also caught. + if not (final_pivot < max_logit): + final_pivot = -float("inf") + elif final_pivot != -float("inf"): for i in range(0, NUM_TILES): offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask_n = offs_n < VOCAB_SIZE From 9dadec16b32812ac36b54f9b9c9c2194af38df07 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 17 Feb 2026 09:41:00 -0800 Subject: [PATCH 98/99] add async yield in cancellation test Signed-off-by: Nick Hill --- tests/entrypoints/instrumentator/test_basic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/entrypoints/instrumentator/test_basic.py b/tests/entrypoints/instrumentator/test_basic.py index 1ff30de31bbe..d562b0d849e4 100644 --- a/tests/entrypoints/instrumentator/test_basic.py +++ b/tests/entrypoints/instrumentator/test_basic.py @@ -148,6 +148,7 @@ async def test_request_cancellation(server: RemoteOpenAIServer): ) ) tasks.append(task) + await asyncio.sleep(0.001) done, pending = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED) From 0e46d904ccc6b12046f0a26da6c6398a95821233 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 17 Feb 2026 11:21:34 -0800 Subject: [PATCH 99/99] use temperature=0 in cancellation test Signed-off-by: Nick Hill --- tests/entrypoints/instrumentator/test_basic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/entrypoints/instrumentator/test_basic.py b/tests/entrypoints/instrumentator/test_basic.py index d562b0d849e4..9c2986ebe6c9 100644 --- 
a/tests/entrypoints/instrumentator/test_basic.py +++ b/tests/entrypoints/instrumentator/test_basic.py @@ -145,10 +145,10 @@ async def test_request_cancellation(server: RemoteOpenAIServer): model=MODEL_NAME, max_tokens=10000, extra_body={"min_tokens": 10000}, + temperature=0.0, ) ) tasks.append(task) - await asyncio.sleep(0.001) done, pending = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED) @@ -164,7 +164,7 @@ async def test_request_cancellation(server: RemoteOpenAIServer): # be able to respond to this one within the timeout client = server.get_async_client(timeout=5) response = await client.chat.completions.create( - messages=chat_input, model=MODEL_NAME, max_tokens=10 + messages=chat_input, model=MODEL_NAME, max_tokens=10, temperature=0.0 ) assert len(response.choices) == 1