diff --git a/.buildkite/test_areas/samplers.yaml b/.buildkite/test_areas/samplers.yaml index 2052a379827a..6fafda70bd6f 100644 --- a/.buildkite/test_areas/samplers.yaml +++ b/.buildkite/test_areas/samplers.yaml @@ -10,7 +10,9 @@ steps: - tests/samplers - tests/conftest.py commands: - - pytest -v -s samplers + # VLLM_USE_FLASHINFER_SAMPLER defaults to 1 now, so we need to pin both + # values explicitly to still cover the PyTorch-native (Triton) path. + - VLLM_USE_FLASHINFER_SAMPLER=0 pytest -v -s samplers - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers mirror: amd: diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 01d395b1e0d8..e410daf2fcdd 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -881,7 +881,7 @@ def test_apc_common_prefix_same_batch( "hello what is one plus one what is one plus one what is one plus one the answer is", # noqa: E501 "hello what is one plus one what is one plus one what is one plus one the answer is", # noqa: E501 ] - sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=20) + sampling_params = SamplingParams(temperature=0.0, max_tokens=20) outputs = llm.generate(prompts, sampling_params) for output in outputs: assert "two" in output.outputs[0].text diff --git a/tests/v1/sample/test_topk_topp_sampler.py b/tests/v1/sample/test_topk_topp_sampler.py index 23f1f1c1f98a..659577b754f6 100644 --- a/tests/v1/sample/test_topk_topp_sampler.py +++ b/tests/v1/sample/test_topk_topp_sampler.py @@ -13,6 +13,30 @@ VOCAB_SIZE = 128 * 1024 +def _flashinfer_topk_topp_supported() -> bool: + """True iff the FlashInfer top-k/top-p sampler is usable on this host. + + Mirrors the gate in `TopKTopPSampler.__init__`: CUDA + flashinfer + importable + GPU compute capability supported by the FlashInfer + backend. + """ + if not current_platform.is_cuda(): + return False + try: + import flashinfer # noqa: F401 + + from vllm.v1.attention.backends.flashinfer import FlashInferBackend + except ImportError: + return False + capability = current_platform.get_device_capability() + if capability is None: + return False + return FlashInferBackend.supports_compute_capability(capability) + + +FLASHINFER_TOPK_TOPP_SUPPORTED = _flashinfer_topk_topp_supported() + + @pytest.fixture(autouse=True) def reset_default_device(): """ @@ -568,3 +592,280 @@ def test_mixed_neginf_and_normal_rows(self): finite_in = (logits[i] > float("-inf")).sum().item() if finite_in > 0: assert kept > 0, f"Row {i}: no tokens kept" + + +# ============================================================================= +# FlashInfer top-k/top-p robustness tests +# ============================================================================= + + +@pytest.mark.skipif( + not FLASHINFER_TOPK_TOPP_SUPPORTED, + reason="FlashInfer top-k/top-p sampler requires CUDA " + "and a GPU with FlashInfer support.", +) +class TestFlashInferTopkToppRobustness: + """Robustness of FlashInfer top-k / top-p sampling to NaN / Inf logits. + + The FlashInfer sampler is enabled by default on supported GPUs. A + single poisoned request (NaN / +Inf / -Inf in row 0) must not: + + 1. crash or hang the process; + 2. produce out-of-range token ids (anything outside ``[0, vocab)``); + 3. corrupt other batch rows — neighbours of a poisoned row must + still receive valid token ids (regression for cross-row + corruption in a DP batch where one bad request would otherwise + poison its peers). 
+ + The reference is "no crash + valid token ids", not bit-exact equality + against the PyTorch-native path. + """ + + BATCH = 8 + VOCAB = 32768 + TOPK = 50 + TOPP = 0.9 + + @pytest.fixture(autouse=True) + def setup(self): + torch.set_default_device(DEVICE_TYPE) + self.generator = Generator(device=DEVICE_TYPE).manual_seed(1234) + + def _make_logits(self, pattern: str) -> torch.Tensor: + """Build (BATCH, VOCAB) logits with `pattern` applied to row 0 + (rows 1..B-1 stay clean so we can detect cross-row corruption).""" + logits = ( + torch.randn( + self.BATCH, + self.VOCAB, + generator=self.generator, + dtype=torch.float32, + ) + * 5.0 + ) + if pattern == "clean": + return logits + if pattern == "nan_one_row": + logits[0, :] = float("nan") + elif pattern == "nan_few": + # Scatter 16 NaNs across row 0, keep the rest finite. + idx = torch.randperm(self.VOCAB, generator=self.generator)[:16] + logits[0, idx] = float("nan") + elif pattern == "nan_at_top": + # Poison the top-32 highest-scoring positions of row 0 — worst + # case for top-k since these are exactly the tokens that would + # otherwise be selected. Use argsort instead of topk to avoid + # a known compute-sanitizer false positive in mbtopk. + top_idx = logits[0].argsort(descending=True)[:32] + logits[0, top_idx] = float("nan") + elif pattern == "nan_all_rows": + logits[:, :] = float("nan") + elif pattern == "pos_inf_one_row": + logits[0, :] = float("inf") + elif pattern == "neg_inf_one_row": + logits[0, :] = float("-inf") + elif pattern == "mixed_inf_nan": + assert self.BATCH >= 3 + logits[0, :] = float("nan") + logits[1, :] = float("inf") + logits[2, :] = float("-inf") + elif pattern == "degenerate_flat": + logits[:, :] = 1.0 + else: + raise ValueError(f"unknown pattern: {pattern}") + return logits + + def _check_tokens(self, tokens: torch.Tensor, ctx: str): + assert tokens.dim() == 1, f"{ctx}: expected 1-D output, got {tokens.shape}" + assert tokens.shape[0] == self.BATCH, ( + f"{ctx}: expected batch size {self.BATCH}, got {tokens.shape[0]}" + ) + ids = tokens.tolist() + min_id, max_id = min(ids), max(ids) + assert 0 <= min_id < self.VOCAB and 0 <= max_id < self.VOCAB, ( + f"{ctx}: token id(s) outside [0, {self.VOCAB}): min={min_id}, max={max_id}" + ) + + @pytest.mark.parametrize( + "pattern", + [ + "clean", + "nan_one_row", + "nan_few", + "nan_at_top", + "nan_all_rows", + "pos_inf_one_row", + "neg_inf_one_row", + "mixed_inf_nan", + "degenerate_flat", + ], + ) + @pytest.mark.parametrize("path", ["topk_only", "topp_only", "topk_topp"]) + def test_flashinfer_handles_pathological_logits(self, pattern: str, path: str): + """flashinfer_sample must return valid ids even on poisoned logits. + + Direct call into ``flashinfer_sample`` — exactly the code path + ``TopKTopPSampler.forward_cuda`` takes when FI is enabled. + """ + from vllm.v1.sample.ops.topk_topp_sampler import flashinfer_sample + + logits = self._make_logits(pattern) + k = ( + torch.full( + (self.BATCH,), + self.TOPK, + device=DEVICE_TYPE, + dtype=torch.int32, + ) + if path in ("topk_only", "topk_topp") + else None + ) + p = ( + torch.full( + (self.BATCH,), + self.TOPP, + device=DEVICE_TYPE, + dtype=torch.float32, + ) + if path in ("topp_only", "topk_topp") + else None + ) + + # flashinfer_sample may mutate its input in-place; pass a clone so + # the parametrize iterations stay independent. + tokens = flashinfer_sample(logits.clone().contiguous(), k, p, {}) + # Surface any async CUDA error synchronously (e.g. 
illegal memory + # access from a malformed FlashInfer call) so it's attributed to + # this test rather than a later, unrelated GPU op. + torch.accelerator.synchronize() + self._check_tokens(tokens, ctx=f"pattern={pattern}, path={path}") + + +# ============================================================================= +# FlashInfer top-k/top-p distribution-match tests +# ============================================================================= + + +@pytest.mark.skipif( + not FLASHINFER_TOPK_TOPP_SUPPORTED, + reason="FlashInfer top-k/top-p sampler requires CUDA " + "and a GPU with FlashInfer support.", +) +class TestFlashInferDistributionMatch: + """Chi-square goodness-of-fit: FlashInfer and PyTorch-native samplers + both reproduce the expected token distribution after top-k / top-p. + + Regression guard against historical FlashInfer distribution-shift. + Each impl is compared to the theoretical distribution (softmax of + filtered logits); if both pass they are statistically equivalent + to each other by transitivity. + """ + + VOCAB = 32 + N_SAMPLES = 50_000 + ALPHA = 1e-6 + SEED = 0 + + @pytest.mark.parametrize( + "topk,topp", + [ + (8, None), + (16, None), + (None, 0.5), + (None, 0.7), + (None, 0.99), + (8, 0.9), + (4, 0.5), + ], + ) + def test_distribution_matches_theoretical(self, topk, topp): + from scipy.stats import chisquare + + from vllm.v1.sample.ops.topk_topp_sampler import ( + apply_top_k_top_p, + flashinfer_sample, + random_sample, + ) + + torch.set_default_device(DEVICE_TYPE) + torch.manual_seed(self.SEED) + + # Same logits row used for both impls so the comparison is fair. + logits_one = ( + torch.randn( + (1, self.VOCAB), + dtype=torch.float32, + ) + * 2.0 + ) + + # Theoretical expected distribution from PyTorch-native filter. + k_one = torch.tensor([topk], dtype=torch.int32) if topk is not None else None + p_one = torch.tensor([topp], dtype=torch.float32) if topp is not None else None + masked = apply_top_k_top_p_pytorch(logits_one.clone(), k_one, p_one) + expected_probs = masked.softmax(dim=-1).flatten().cpu().numpy() + expected_counts = expected_probs * self.N_SAMPLES + + # Build a batch of N identical rows for both impls. + batch = logits_one.expand(self.N_SAMPLES, self.VOCAB).contiguous() + k_batch = ( + torch.full((self.N_SAMPLES,), topk, dtype=torch.int32) + if topk is not None + else None + ) + p_batch = ( + torch.full((self.N_SAMPLES,), topp, dtype=torch.float32) + if topp is not None + else None + ) + + # FlashInfer dispatch path. + fi_tokens = flashinfer_sample(batch.contiguous(), k_batch, p_batch, {}) + fi_counts = torch.bincount(fi_tokens, minlength=self.VOCAB).cpu().numpy() + self._chi2_check( + fi_counts, + expected_counts, + chisquare, + label=f"flashinfer top-k={topk} top-p={topp}", + ) + + # PyTorch-native dispatch path (Triton-routed filter + Gumbel sample). + processed = apply_top_k_top_p(batch.clone(), k_batch, p_batch) + probs = processed.softmax(dim=-1, dtype=torch.float32) + pt_tokens = random_sample(probs, {}) + pt_counts = torch.bincount(pt_tokens, minlength=self.VOCAB).cpu().numpy() + self._chi2_check( + pt_counts, + expected_counts, + chisquare, + label=f"native top-k={topk} top-p={topp}", + ) + + def _chi2_check(self, empirical, expected, chisquare_fn, *, label): + import numpy as np + + # Hard check: the sampler must never produce a token outside the + # expected support (zero theoretical probability). 
+ outside = (expected == 0) & (empirical > 0) + assert not outside.any(), ( + f"{label}: sampled out-of-support tokens " + f"(zero expected prob): indices={outside.nonzero()[0].tolist()}" + ) + # Skip chi-square in the degenerate case where the support + # collapses to a single token (e.g. very restrictive joint + # top-k + top-p): all samples must land there and the hard + # check above already verified they do. + in_support = expected > 0 + if int(in_support.sum()) <= 1: + return + # Soft check: chi-square goodness-of-fit on in-support tokens. + # Cast to float64 so the rescaling step below stays within + # scipy.chisquare's strict 1.5e-8 sum-equality tolerance. + emp = empirical[in_support].astype(np.float64) + exp = expected[in_support].astype(np.float64) + exp = exp * (emp.sum() / exp.sum()) + chi2, p_value = chisquare_fn(emp, exp) + assert p_value > self.ALPHA, ( + f"{label}: distribution differs from theoretical: " + f"chi2={chi2:.2f} p_value={p_value:.2e} alpha={self.ALPHA}" + ) diff --git a/vllm/envs.py b/vllm/envs.py index 806aed2a0414..73278bf8dc4f 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -45,7 +45,7 @@ NO_COLOR: bool = False VLLM_LOG_STATS_INTERVAL: float = 10.0 VLLM_TRACE_FUNCTION: int = 0 - VLLM_USE_FLASHINFER_SAMPLER: bool | None = None + VLLM_USE_FLASHINFER_SAMPLER: bool = True VLLM_PP_LAYER_PARTITION: str | None = None VLLM_CPU_KVCACHE_SPACE: int | None = 0 VLLM_CPU_OMP_THREADS_BIND: str = "auto" @@ -712,11 +712,13 @@ def _get_or_set_default() -> str: # If set to 1, vllm will trace function calls # Useful for debugging "VLLM_TRACE_FUNCTION": lambda: int(os.getenv("VLLM_TRACE_FUNCTION", "0")), - # If set, vllm will use flashinfer sampler + # Whether to use the FlashInfer top-k / top-p sampler on CUDA. Enabled + # by default when the hardware supports it — set to 0 to opt out + # explicitly, which forces the PyTorch-native (Triton for bs>=8) path. "VLLM_USE_FLASHINFER_SAMPLER": lambda: ( bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"])) if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ - else None + else True ), # Pipeline stage partition strategy "VLLM_PP_LAYER_PARTITION": lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 33f7090e4e3d..70843be39695 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -41,23 +41,35 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None: capability = current_platform.get_device_capability() assert capability is not None - if not FlashInferBackend.supports_compute_capability(capability): + if FlashInferBackend.supports_compute_capability(capability): + logger.info_once( + "Using FlashInfer for top-p & top-k sampling.", + scope="global", + ) + self.forward = self.forward_cuda + elif envs.is_set("VLLM_USE_FLASHINFER_SAMPLER"): + # User explicitly opted in but the GPU can't run FlashInfer. capability_str = capability.as_version_str() raise RuntimeError( "FlashInfer does not support compute capability " f"{capability_str}, unset VLLM_USE_FLASHINFER_SAMPLER=1." ) - # Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1. - logger.info_once( - "Using FlashInfer for top-p & top-k sampling.", - scope="global", - ) - self.forward = self.forward_cuda + else: + # Default-on path; hardware can't run FlashInfer → + # quietly fall back to the PyTorch-native sampler + # instead of failing server startup. 
+ logger.warning_once( + "FlashInfer top-p/top-k sampling not supported on " + "compute capability %s; falling back to PyTorch-native " + "sampler. Set VLLM_USE_FLASHINFER_SAMPLER=0 to silence.", + capability.as_version_str(), + ) + self.forward = self.forward_native else: - logger.debug_once( - "FlashInfer top-p/top-k sampling is available but disabled " - "by default. Set VLLM_USE_FLASHINFER_SAMPLER=1 to opt in " - "after verifying accuracy for your workloads." + # User explicitly set VLLM_USE_FLASHINFER_SAMPLER=0. + logger.info_once( + "FlashInfer top-p/top-k sampling disabled via " + "VLLM_USE_FLASHINFER_SAMPLER=0; using PyTorch-native sampler." ) self.forward = self.forward_native @@ -120,9 +132,9 @@ def forward_cuda( p: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """More optimized implementation for top-k and top-p sampling.""" - # We prefer `random_sample` over `flashinfer_sample` when sorting is - # not needed. This is because `random_sample` does not require - # CPU-GPU synchronization while `flashinfer_sample` does. + # Fall back to the PyTorch-native path when FlashInfer has nothing + # to do (no top-k / top-p filter) or when per-request generators + # are present (unsupported by FlashInfer 0.2.3+). if (k is None and p is None) or generators: if generators: logger.debug_once( @@ -361,10 +373,6 @@ def flashinfer_sample( NOTE: The outputs of this function do not necessarily match the outputs of the `random_sample` function. It only guarantees that the outputs are statistically equivalent. - - NOTE: This function includes CPU-GPU synchronization, while `random_sample` - does not. Call this function at the end of the forward pass to minimize - the synchronization overhead. """ import flashinfer
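
Reviewer note: with this diff, FlashInfer top-k/top-p sampling becomes the default on supported CUDA GPUs, and the PyTorch-native path stays one environment variable away (the same `VLLM_USE_FLASHINFER_SAMPLER=0` pin the buildkite change adds). Below is a minimal opt-out usage sketch, not part of the diff; the model name and prompt are placeholders, and it assumes a CUDA host with vLLM installed:

```python
# Hypothetical opt-out sketch (illustration only, not part of this change).
# vllm.envs reads the variable lazily, so setting it before constructing the
# engine should take effect; exporting it in the shell works equally well,
# e.g. `VLLM_USE_FLASHINFER_SAMPLER=0 python script.py`.
import os

os.environ["VLLM_USE_FLASHINFER_SAMPLER"] = "0"  # force the PyTorch-native sampler

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model
params = SamplingParams(temperature=0.8, top_k=50, top_p=0.9, max_tokens=16)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```

Leaving the variable unset exercises the new default: FlashInfer on supported GPUs, and the warn-and-fall-back branch on unsupported compute capabilities, which is the behavior the added robustness and distribution-match tests cover.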