From d904cfa00cf851fb8af3bd151f4f8ec53590d820 Mon Sep 17 00:00:00 2001
From: Zhewen Li
Date: Wed, 29 Apr 2026 19:05:44 -0700
Subject: [PATCH] Revert "[Perf] Enable FlashInfer top-k/top-p sampler by default (#40376)"

This reverts commit b92ef9ec5a041b538f44d9584bef0e34bfbeecd1.
---
 .buildkite/test_areas/samplers.yaml           |   4 +-
 .../models/language/generation/test_hybrid.py |   2 +-
 tests/v1/sample/test_topk_topp_sampler.py     | 301 ------------------
 vllm/envs.py                                  |   8 +-
 vllm/v1/sample/ops/topk_topp_sampler.py       |  44 ++-
 5 files changed, 23 insertions(+), 336 deletions(-)

diff --git a/.buildkite/test_areas/samplers.yaml b/.buildkite/test_areas/samplers.yaml
index 37f8eaa6883c..554203825b49 100644
--- a/.buildkite/test_areas/samplers.yaml
+++ b/.buildkite/test_areas/samplers.yaml
@@ -11,9 +11,7 @@ steps:
   - tests/samplers
   - tests/conftest.py
   commands:
-  # VLLM_USE_FLASHINFER_SAMPLER defaults to 1 now, so we need to pin both
-  # values explicitly to still cover the PyTorch-native (Triton) path.
-  - VLLM_USE_FLASHINFER_SAMPLER=0 pytest -v -s samplers
+  - pytest -v -s samplers
   - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
   mirror:
     amd:
diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py
index e410daf2fcdd..01d395b1e0d8 100644
--- a/tests/models/language/generation/test_hybrid.py
+++ b/tests/models/language/generation/test_hybrid.py
@@ -881,7 +881,7 @@ def test_apc_common_prefix_same_batch(
         "hello what is one plus one what is one plus one what is one plus one the answer is",  # noqa: E501
         "hello what is one plus one what is one plus one what is one plus one the answer is",  # noqa: E501
     ]
-    sampling_params = SamplingParams(temperature=0.0, max_tokens=20)
+    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=20)
     outputs = llm.generate(prompts, sampling_params)
     for output in outputs:
         assert "two" in output.outputs[0].text
diff --git a/tests/v1/sample/test_topk_topp_sampler.py b/tests/v1/sample/test_topk_topp_sampler.py
index 659577b754f6..23f1f1c1f98a 100644
--- a/tests/v1/sample/test_topk_topp_sampler.py
+++ b/tests/v1/sample/test_topk_topp_sampler.py
@@ -13,30 +13,6 @@
 VOCAB_SIZE = 128 * 1024
 
 
-def _flashinfer_topk_topp_supported() -> bool:
-    """True iff the FlashInfer top-k/top-p sampler is usable on this host.
-
-    Mirrors the gate in `TopKTopPSampler.__init__`: CUDA + flashinfer
-    importable + GPU compute capability supported by the FlashInfer
-    backend.
-    """
-    if not current_platform.is_cuda():
-        return False
-    try:
-        import flashinfer  # noqa: F401
-
-        from vllm.v1.attention.backends.flashinfer import FlashInferBackend
-    except ImportError:
-        return False
-    capability = current_platform.get_device_capability()
-    if capability is None:
-        return False
-    return FlashInferBackend.supports_compute_capability(capability)
-
-
-FLASHINFER_TOPK_TOPP_SUPPORTED = _flashinfer_topk_topp_supported()
-
-
 @pytest.fixture(autouse=True)
 def reset_default_device():
     """
@@ -592,280 +568,3 @@ def test_mixed_neginf_and_normal_rows(self):
             finite_in = (logits[i] > float("-inf")).sum().item()
             if finite_in > 0:
                 assert kept > 0, f"Row {i}: no tokens kept"
-
-
-# =============================================================================
-# FlashInfer top-k/top-p robustness tests
-# =============================================================================
-
-
-@pytest.mark.skipif(
-    not FLASHINFER_TOPK_TOPP_SUPPORTED,
-    reason="FlashInfer top-k/top-p sampler requires CUDA "
-    "and a GPU with FlashInfer support.",
-)
-class TestFlashInferTopkToppRobustness:
-    """Robustness of FlashInfer top-k / top-p sampling to NaN / Inf logits.
-
-    The FlashInfer sampler is enabled by default on supported GPUs. A
-    single poisoned request (NaN / +Inf / -Inf in row 0) must not:
-
-    1. crash or hang the process;
-    2. produce out-of-range token ids (anything outside ``[0, vocab)``);
-    3. corrupt other batch rows — neighbours of a poisoned row must
-       still receive valid token ids (regression for cross-row
-       corruption in a DP batch where one bad request would otherwise
-       poison its peers).
-
-    The reference is "no crash + valid token ids", not bit-exact equality
-    against the PyTorch-native path.
-    """
-
-    BATCH = 8
-    VOCAB = 32768
-    TOPK = 50
-    TOPP = 0.9
-
-    @pytest.fixture(autouse=True)
-    def setup(self):
-        torch.set_default_device(DEVICE_TYPE)
-        self.generator = Generator(device=DEVICE_TYPE).manual_seed(1234)
-
-    def _make_logits(self, pattern: str) -> torch.Tensor:
-        """Build (BATCH, VOCAB) logits with `pattern` applied to row 0
-        (rows 1..B-1 stay clean so we can detect cross-row corruption)."""
-        logits = (
-            torch.randn(
-                self.BATCH,
-                self.VOCAB,
-                generator=self.generator,
-                dtype=torch.float32,
-            )
-            * 5.0
-        )
-        if pattern == "clean":
-            return logits
-        if pattern == "nan_one_row":
-            logits[0, :] = float("nan")
-        elif pattern == "nan_few":
-            # Scatter 16 NaNs across row 0, keep the rest finite.
-            idx = torch.randperm(self.VOCAB, generator=self.generator)[:16]
-            logits[0, idx] = float("nan")
-        elif pattern == "nan_at_top":
-            # Poison the top-32 highest-scoring positions of row 0 — worst
-            # case for top-k since these are exactly the tokens that would
-            # otherwise be selected. Use argsort instead of topk to avoid
-            # a known compute-sanitizer false positive in mbtopk.
-            top_idx = logits[0].argsort(descending=True)[:32]
-            logits[0, top_idx] = float("nan")
-        elif pattern == "nan_all_rows":
-            logits[:, :] = float("nan")
-        elif pattern == "pos_inf_one_row":
-            logits[0, :] = float("inf")
-        elif pattern == "neg_inf_one_row":
-            logits[0, :] = float("-inf")
-        elif pattern == "mixed_inf_nan":
-            assert self.BATCH >= 3
-            logits[0, :] = float("nan")
-            logits[1, :] = float("inf")
-            logits[2, :] = float("-inf")
-        elif pattern == "degenerate_flat":
-            logits[:, :] = 1.0
-        else:
-            raise ValueError(f"unknown pattern: {pattern}")
-        return logits
-
-    def _check_tokens(self, tokens: torch.Tensor, ctx: str):
-        assert tokens.dim() == 1, f"{ctx}: expected 1-D output, got {tokens.shape}"
-        assert tokens.shape[0] == self.BATCH, (
-            f"{ctx}: expected batch size {self.BATCH}, got {tokens.shape[0]}"
-        )
-        ids = tokens.tolist()
-        min_id, max_id = min(ids), max(ids)
-        assert 0 <= min_id < self.VOCAB and 0 <= max_id < self.VOCAB, (
-            f"{ctx}: token id(s) outside [0, {self.VOCAB}): min={min_id}, max={max_id}"
-        )
-
-    @pytest.mark.parametrize(
-        "pattern",
-        [
-            "clean",
-            "nan_one_row",
-            "nan_few",
-            "nan_at_top",
-            "nan_all_rows",
-            "pos_inf_one_row",
-            "neg_inf_one_row",
-            "mixed_inf_nan",
-            "degenerate_flat",
-        ],
-    )
-    @pytest.mark.parametrize("path", ["topk_only", "topp_only", "topk_topp"])
-    def test_flashinfer_handles_pathological_logits(self, pattern: str, path: str):
-        """flashinfer_sample must return valid ids even on poisoned logits.
-
-        Direct call into ``flashinfer_sample`` — exactly the code path
-        ``TopKTopPSampler.forward_cuda`` takes when FI is enabled.
-        """
-        from vllm.v1.sample.ops.topk_topp_sampler import flashinfer_sample
-
-        logits = self._make_logits(pattern)
-        k = (
-            torch.full(
-                (self.BATCH,),
-                self.TOPK,
-                device=DEVICE_TYPE,
-                dtype=torch.int32,
-            )
-            if path in ("topk_only", "topk_topp")
-            else None
-        )
-        p = (
-            torch.full(
-                (self.BATCH,),
-                self.TOPP,
-                device=DEVICE_TYPE,
-                dtype=torch.float32,
-            )
-            if path in ("topp_only", "topk_topp")
-            else None
-        )
-
-        # flashinfer_sample may mutate its input in-place; pass a clone so
-        # the parametrize iterations stay independent.
-        tokens = flashinfer_sample(logits.clone().contiguous(), k, p, {})
-        # Surface any async CUDA error synchronously (e.g. illegal memory
-        # access from a malformed FlashInfer call) so it's attributed to
-        # this test rather than a later, unrelated GPU op.
-        torch.accelerator.synchronize()
-        self._check_tokens(tokens, ctx=f"pattern={pattern}, path={path}")
-
-
-# =============================================================================
-# FlashInfer top-k/top-p distribution-match tests
-# =============================================================================
-
-
-@pytest.mark.skipif(
-    not FLASHINFER_TOPK_TOPP_SUPPORTED,
-    reason="FlashInfer top-k/top-p sampler requires CUDA "
-    "and a GPU with FlashInfer support.",
-)
-class TestFlashInferDistributionMatch:
-    """Chi-square goodness-of-fit: FlashInfer and PyTorch-native samplers
-    both reproduce the expected token distribution after top-k / top-p.
-
-    Regression guard against historical FlashInfer distribution-shift.
-    Each impl is compared to the theoretical distribution (softmax of
-    filtered logits); if both pass they are statistically equivalent
-    to each other by transitivity.
-    """
-
-    VOCAB = 32
-    N_SAMPLES = 50_000
-    ALPHA = 1e-6
-    SEED = 0
-
-    @pytest.mark.parametrize(
-        "topk,topp",
-        [
-            (8, None),
-            (16, None),
-            (None, 0.5),
-            (None, 0.7),
-            (None, 0.99),
-            (8, 0.9),
-            (4, 0.5),
-        ],
-    )
-    def test_distribution_matches_theoretical(self, topk, topp):
-        from scipy.stats import chisquare
-
-        from vllm.v1.sample.ops.topk_topp_sampler import (
-            apply_top_k_top_p,
-            flashinfer_sample,
-            random_sample,
-        )
-
-        torch.set_default_device(DEVICE_TYPE)
-        torch.manual_seed(self.SEED)
-
-        # Same logits row used for both impls so the comparison is fair.
-        logits_one = (
-            torch.randn(
-                (1, self.VOCAB),
-                dtype=torch.float32,
-            )
-            * 2.0
-        )
-
-        # Theoretical expected distribution from PyTorch-native filter.
-        k_one = torch.tensor([topk], dtype=torch.int32) if topk is not None else None
-        p_one = torch.tensor([topp], dtype=torch.float32) if topp is not None else None
-        masked = apply_top_k_top_p_pytorch(logits_one.clone(), k_one, p_one)
-        expected_probs = masked.softmax(dim=-1).flatten().cpu().numpy()
-        expected_counts = expected_probs * self.N_SAMPLES
-
-        # Build a batch of N identical rows for both impls.
-        batch = logits_one.expand(self.N_SAMPLES, self.VOCAB).contiguous()
-        k_batch = (
-            torch.full((self.N_SAMPLES,), topk, dtype=torch.int32)
-            if topk is not None
-            else None
-        )
-        p_batch = (
-            torch.full((self.N_SAMPLES,), topp, dtype=torch.float32)
-            if topp is not None
-            else None
-        )
-
-        # FlashInfer dispatch path.
-        fi_tokens = flashinfer_sample(batch.contiguous(), k_batch, p_batch, {})
-        fi_counts = torch.bincount(fi_tokens, minlength=self.VOCAB).cpu().numpy()
-        self._chi2_check(
-            fi_counts,
-            expected_counts,
-            chisquare,
-            label=f"flashinfer top-k={topk} top-p={topp}",
-        )
-
-        # PyTorch-native dispatch path (Triton-routed filter + Gumbel sample).
-        processed = apply_top_k_top_p(batch.clone(), k_batch, p_batch)
-        probs = processed.softmax(dim=-1, dtype=torch.float32)
-        pt_tokens = random_sample(probs, {})
-        pt_counts = torch.bincount(pt_tokens, minlength=self.VOCAB).cpu().numpy()
-        self._chi2_check(
-            pt_counts,
-            expected_counts,
-            chisquare,
-            label=f"native top-k={topk} top-p={topp}",
-        )
-
-    def _chi2_check(self, empirical, expected, chisquare_fn, *, label):
-        import numpy as np
-
-        # Hard check: the sampler must never produce a token outside the
-        # expected support (zero theoretical probability).
-        outside = (expected == 0) & (empirical > 0)
-        assert not outside.any(), (
-            f"{label}: sampled out-of-support tokens "
-            f"(zero expected prob): indices={outside.nonzero()[0].tolist()}"
-        )
-        # Skip chi-square in the degenerate case where the support
-        # collapses to a single token (e.g. very restrictive joint
-        # top-k + top-p): all samples must land there and the hard
-        # check above already verified they do.
-        in_support = expected > 0
-        if int(in_support.sum()) <= 1:
-            return
-        # Soft check: chi-square goodness-of-fit on in-support tokens.
-        # Cast to float64 so the rescaling step below stays within
-        # scipy.chisquare's strict 1.5e-8 sum-equality tolerance.
-        emp = empirical[in_support].astype(np.float64)
-        exp = expected[in_support].astype(np.float64)
-        exp = exp * (emp.sum() / exp.sum())
-        chi2, p_value = chisquare_fn(emp, exp)
-        assert p_value > self.ALPHA, (
-            f"{label}: distribution differs from theoretical: "
-            f"chi2={chi2:.2f} p_value={p_value:.2e} alpha={self.ALPHA}"
-        )
diff --git a/vllm/envs.py b/vllm/envs.py
index fa7a5fe70f95..85f673041551 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -45,7 +45,7 @@
     NO_COLOR: bool = False
     VLLM_LOG_STATS_INTERVAL: float = 10.0
     VLLM_TRACE_FUNCTION: int = 0
-    VLLM_USE_FLASHINFER_SAMPLER: bool = True
+    VLLM_USE_FLASHINFER_SAMPLER: bool | None = None
     VLLM_PP_LAYER_PARTITION: str | None = None
     VLLM_CPU_KVCACHE_SPACE: int | None = 0
     VLLM_CPU_OMP_THREADS_BIND: str = "auto"
@@ -712,13 +712,11 @@ def _get_or_set_default() -> str:
     # If set to 1, vllm will trace function calls
    # Useful for debugging
     "VLLM_TRACE_FUNCTION": lambda: int(os.getenv("VLLM_TRACE_FUNCTION", "0")),
-    # Whether to use the FlashInfer top-k / top-p sampler on CUDA. Enabled
-    # by default when the hardware supports it — set to 0 to opt out
-    # explicitly, which forces the PyTorch-native (Triton for bs>=8) path.
+    # If set, vllm will use flashinfer sampler
     "VLLM_USE_FLASHINFER_SAMPLER": lambda: (
         bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))
         if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ
-        else True
+        else None
     ),
     # Pipeline stage partition strategy
     "VLLM_PP_LAYER_PARTITION": lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py
index 70843be39695..33f7090e4e3d 100644
--- a/vllm/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm/v1/sample/ops/topk_topp_sampler.py
@@ -41,35 +41,23 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None:
             capability = current_platform.get_device_capability()
             assert capability is not None
-            if FlashInferBackend.supports_compute_capability(capability):
-                logger.info_once(
-                    "Using FlashInfer for top-p & top-k sampling.",
-                    scope="global",
-                )
-                self.forward = self.forward_cuda
-            elif envs.is_set("VLLM_USE_FLASHINFER_SAMPLER"):
-                # User explicitly opted in but the GPU can't run FlashInfer.
+            if not FlashInferBackend.supports_compute_capability(capability):
                 capability_str = capability.as_version_str()
                 raise RuntimeError(
                     "FlashInfer does not support compute capability "
                     f"{capability_str}, unset VLLM_USE_FLASHINFER_SAMPLER=1."
                 )
-            else:
-                # Default-on path; hardware can't run FlashInfer →
-                # quietly fall back to the PyTorch-native sampler
-                # instead of failing server startup.
-                logger.warning_once(
-                    "FlashInfer top-p/top-k sampling not supported on "
-                    "compute capability %s; falling back to PyTorch-native "
-                    "sampler. Set VLLM_USE_FLASHINFER_SAMPLER=0 to silence.",
-                    capability.as_version_str(),
-                )
-                self.forward = self.forward_native
-        else:
-            # User explicitly set VLLM_USE_FLASHINFER_SAMPLER=0.
+            # Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1.
             logger.info_once(
-                "FlashInfer top-p/top-k sampling disabled via "
-                "VLLM_USE_FLASHINFER_SAMPLER=0; using PyTorch-native sampler."
+                "Using FlashInfer for top-p & top-k sampling.",
+                scope="global",
+            )
+            self.forward = self.forward_cuda
+        else:
+            logger.debug_once(
+                "FlashInfer top-p/top-k sampling is available but disabled "
+                "by default. Set VLLM_USE_FLASHINFER_SAMPLER=1 to opt in "
+                "after verifying accuracy for your workloads."
             )
             self.forward = self.forward_native
@@ -132,9 +120,9 @@ def forward_cuda(
         p: torch.Tensor | None,
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
         """More optimized implementation for top-k and top-p sampling."""
-        # Fall back to the PyTorch-native path when FlashInfer has nothing
-        # to do (no top-k / top-p filter) or when per-request generators
-        # are present (unsupported by FlashInfer 0.2.3+).
+        # We prefer `random_sample` over `flashinfer_sample` when sorting is
+        # not needed. This is because `random_sample` does not require
+        # CPU-GPU synchronization while `flashinfer_sample` does.
         if (k is None and p is None) or generators:
             if generators:
                 logger.debug_once(
@@ -373,6 +361,10 @@ def flashinfer_sample(
     NOTE: The outputs of this function do not necessarily match the outputs
     of the `random_sample` function. It only guarantees that the outputs are
     statistically equivalent.
+
+    NOTE: This function includes CPU-GPU synchronization, while `random_sample`
+    does not. Call this function at the end of the forward pass to minimize
+    the synchronization overhead.
     """
     import flashinfer
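
Note (outside the patch): after this revert, VLLM_USE_FLASHINFER_SAMPLER is a
tri-state flag again: unset parses to None, "0" to False, "1" to True, and only
an explicit "1" selects the FlashInfer sampler. Below is a minimal sketch of the
restored gating; the env-var parsing mirrors the envs.py lambda in the hunk
above, while choose_forward is a hypothetical helper that condenses the
TopKTopPSampler.__init__ dispatch for illustration, not the actual vLLM code:

    import os

    def use_flashinfer_sampler() -> bool | None:
        # Mirrors the restored envs.py lambda: unset -> None (default off),
        # "0" -> False (explicit opt-out), "1" -> True (explicit opt-in).
        if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ:
            return bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))
        return None

    def choose_forward(flag: bool | None, capability_supported: bool) -> str:
        # Condensed view of the restored __init__ dispatch (hypothetical
        # helper): FlashInfer is opt-in, and an unsupported GPU under an
        # explicit opt-in is a hard error rather than a silent fallback.
        if flag:
            if not capability_supported:
                raise RuntimeError(
                    "FlashInfer does not support this compute capability; "
                    "unset VLLM_USE_FLASHINFER_SAMPLER=1."
                )
            return "forward_cuda"  # FlashInfer top-k/top-p kernels
        return "forward_native"  # PyTorch-native sampling path

The tri-state default is also what the pre-revert envs.is_set() check relied on:
with the old default of True, is_set() was the only way to tell an explicit "1"
apart from "flag left unset"; after the revert, anything falsy simply routes to
forward_native.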