4 changes: 1 addition & 3 deletions .buildkite/test_areas/samplers.yaml
@@ -11,9 +11,7 @@ steps:
- tests/samplers
- tests/conftest.py
commands:
# VLLM_USE_FLASHINFER_SAMPLER defaults to 1 now, so we need to pin both
# values explicitly to still cover the PyTorch-native (Triton) path.
- VLLM_USE_FLASHINFER_SAMPLER=0 pytest -v -s samplers
- pytest -v -s samplers
- VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
mirror:
amd:
2 changes: 1 addition & 1 deletion tests/models/language/generation/test_hybrid.py
@@ -881,7 +881,7 @@ def test_apc_common_prefix_same_batch(
"hello what is one plus one what is one plus one what is one plus one the answer is", # noqa: E501
"hello what is one plus one what is one plus one what is one plus one the answer is", # noqa: E501
]
sampling_params = SamplingParams(temperature=0.0, max_tokens=20)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=20)
Contributor
Severity: high

Using stochastic sampling (temperature=0.8) in a regression test for prefix caching can introduce flakiness and makes it harder to detect subtle state corruption. Since the LLM is initialized with a fixed seed (line 877), greedy sampling (temperature=0.0) would provide a more robust and deterministic check for the prefix caching logic, ensuring that the cached hidden states produce identical results to the non-cached path.

Suggested change
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=20)
sampling_params = SamplingParams(temperature=0.0, max_tokens=20)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
assert "two" in output.outputs[0].text
301 changes: 0 additions & 301 deletions tests/v1/sample/test_topk_topp_sampler.py
@@ -13,30 +13,6 @@
VOCAB_SIZE = 128 * 1024


def _flashinfer_topk_topp_supported() -> bool:
"""True iff the FlashInfer top-k/top-p sampler is usable on this host.

Mirrors the gate in `TopKTopPSampler.__init__`: CUDA + flashinfer
importable + GPU compute capability supported by the FlashInfer
backend.
"""
if not current_platform.is_cuda():
return False
try:
import flashinfer # noqa: F401

from vllm.v1.attention.backends.flashinfer import FlashInferBackend
except ImportError:
return False
capability = current_platform.get_device_capability()
if capability is None:
return False
return FlashInferBackend.supports_compute_capability(capability)


FLASHINFER_TOPK_TOPP_SUPPORTED = _flashinfer_topk_topp_supported()


@pytest.fixture(autouse=True)
def reset_default_device():
"""
@@ -592,280 +568,3 @@ def test_mixed_neginf_and_normal_rows(self):
finite_in = (logits[i] > float("-inf")).sum().item()
if finite_in > 0:
assert kept > 0, f"Row {i}: no tokens kept"


# =============================================================================
# FlashInfer top-k/top-p robustness tests
# =============================================================================


@pytest.mark.skipif(
not FLASHINFER_TOPK_TOPP_SUPPORTED,
reason="FlashInfer top-k/top-p sampler requires CUDA "
"and a GPU with FlashInfer support.",
)
class TestFlashInferTopkToppRobustness:
"""Robustness of FlashInfer top-k / top-p sampling to NaN / Inf logits.

The FlashInfer sampler is enabled by default on supported GPUs. A
single poisoned request (NaN / +Inf / -Inf in row 0) must not:

1. crash or hang the process;
2. produce out-of-range token ids (anything outside ``[0, vocab)``);
3. corrupt other batch rows — neighbours of a poisoned row must
still receive valid token ids (regression for cross-row
corruption in a DP batch where one bad request would otherwise
poison its peers).

The reference is "no crash + valid token ids", not bit-exact equality
against the PyTorch-native path.
"""

BATCH = 8
VOCAB = 32768
TOPK = 50
TOPP = 0.9

@pytest.fixture(autouse=True)
def setup(self):
torch.set_default_device(DEVICE_TYPE)
self.generator = Generator(device=DEVICE_TYPE).manual_seed(1234)

def _make_logits(self, pattern: str) -> torch.Tensor:
"""Build (BATCH, VOCAB) logits with `pattern` applied to row 0
(rows 1..B-1 stay clean so we can detect cross-row corruption)."""
logits = (
torch.randn(
self.BATCH,
self.VOCAB,
generator=self.generator,
dtype=torch.float32,
)
* 5.0
)
if pattern == "clean":
return logits
if pattern == "nan_one_row":
logits[0, :] = float("nan")
elif pattern == "nan_few":
# Scatter 16 NaNs across row 0, keep the rest finite.
idx = torch.randperm(self.VOCAB, generator=self.generator)[:16]
logits[0, idx] = float("nan")
elif pattern == "nan_at_top":
# Poison the top-32 highest-scoring positions of row 0 — worst
# case for top-k since these are exactly the tokens that would
# otherwise be selected. Use argsort instead of topk to avoid
# a known compute-sanitizer false positive in mbtopk.
top_idx = logits[0].argsort(descending=True)[:32]
logits[0, top_idx] = float("nan")
elif pattern == "nan_all_rows":
logits[:, :] = float("nan")
elif pattern == "pos_inf_one_row":
logits[0, :] = float("inf")
elif pattern == "neg_inf_one_row":
logits[0, :] = float("-inf")
elif pattern == "mixed_inf_nan":
assert self.BATCH >= 3
logits[0, :] = float("nan")
logits[1, :] = float("inf")
logits[2, :] = float("-inf")
elif pattern == "degenerate_flat":
logits[:, :] = 1.0
else:
raise ValueError(f"unknown pattern: {pattern}")
return logits

def _check_tokens(self, tokens: torch.Tensor, ctx: str):
assert tokens.dim() == 1, f"{ctx}: expected 1-D output, got {tokens.shape}"
assert tokens.shape[0] == self.BATCH, (
f"{ctx}: expected batch size {self.BATCH}, got {tokens.shape[0]}"
)
ids = tokens.tolist()
min_id, max_id = min(ids), max(ids)
assert 0 <= min_id < self.VOCAB and 0 <= max_id < self.VOCAB, (
f"{ctx}: token id(s) outside [0, {self.VOCAB}): min={min_id}, max={max_id}"
)

@pytest.mark.parametrize(
"pattern",
[
"clean",
"nan_one_row",
"nan_few",
"nan_at_top",
"nan_all_rows",
"pos_inf_one_row",
"neg_inf_one_row",
"mixed_inf_nan",
"degenerate_flat",
],
)
@pytest.mark.parametrize("path", ["topk_only", "topp_only", "topk_topp"])
def test_flashinfer_handles_pathological_logits(self, pattern: str, path: str):
"""flashinfer_sample must return valid ids even on poisoned logits.

Direct call into ``flashinfer_sample`` — exactly the code path
``TopKTopPSampler.forward_cuda`` takes when FI is enabled.
"""
from vllm.v1.sample.ops.topk_topp_sampler import flashinfer_sample

logits = self._make_logits(pattern)
k = (
torch.full(
(self.BATCH,),
self.TOPK,
device=DEVICE_TYPE,
dtype=torch.int32,
)
if path in ("topk_only", "topk_topp")
else None
)
p = (
torch.full(
(self.BATCH,),
self.TOPP,
device=DEVICE_TYPE,
dtype=torch.float32,
)
if path in ("topp_only", "topk_topp")
else None
)

# flashinfer_sample may mutate its input in-place; pass a clone so
# the parametrize iterations stay independent.
tokens = flashinfer_sample(logits.clone().contiguous(), k, p, {})
# Surface any async CUDA error synchronously (e.g. illegal memory
# access from a malformed FlashInfer call) so it's attributed to
# this test rather than a later, unrelated GPU op.
torch.accelerator.synchronize()
self._check_tokens(tokens, ctx=f"pattern={pattern}, path={path}")


# =============================================================================
# FlashInfer top-k/top-p distribution-match tests
# =============================================================================


@pytest.mark.skipif(
not FLASHINFER_TOPK_TOPP_SUPPORTED,
reason="FlashInfer top-k/top-p sampler requires CUDA "
"and a GPU with FlashInfer support.",
)
class TestFlashInferDistributionMatch:
"""Chi-square goodness-of-fit: FlashInfer and PyTorch-native samplers
both reproduce the expected token distribution after top-k / top-p.

Regression guard against historical FlashInfer distribution-shift.
Each impl is compared to the theoretical distribution (softmax of
filtered logits); if both pass they are statistically equivalent
to each other by transitivity.
"""

VOCAB = 32
N_SAMPLES = 50_000
ALPHA = 1e-6
SEED = 0

@pytest.mark.parametrize(
"topk,topp",
[
(8, None),
(16, None),
(None, 0.5),
(None, 0.7),
(None, 0.99),
(8, 0.9),
(4, 0.5),
],
)
def test_distribution_matches_theoretical(self, topk, topp):
from scipy.stats import chisquare

from vllm.v1.sample.ops.topk_topp_sampler import (
apply_top_k_top_p,
flashinfer_sample,
random_sample,
)

torch.set_default_device(DEVICE_TYPE)
torch.manual_seed(self.SEED)

# Same logits row used for both impls so the comparison is fair.
logits_one = (
torch.randn(
(1, self.VOCAB),
dtype=torch.float32,
)
* 2.0
)

# Theoretical expected distribution from PyTorch-native filter.
k_one = torch.tensor([topk], dtype=torch.int32) if topk is not None else None
p_one = torch.tensor([topp], dtype=torch.float32) if topp is not None else None
masked = apply_top_k_top_p_pytorch(logits_one.clone(), k_one, p_one)
expected_probs = masked.softmax(dim=-1).flatten().cpu().numpy()
expected_counts = expected_probs * self.N_SAMPLES

# Build a batch of N identical rows for both impls.
batch = logits_one.expand(self.N_SAMPLES, self.VOCAB).contiguous()
k_batch = (
torch.full((self.N_SAMPLES,), topk, dtype=torch.int32)
if topk is not None
else None
)
p_batch = (
torch.full((self.N_SAMPLES,), topp, dtype=torch.float32)
if topp is not None
else None
)

# FlashInfer dispatch path.
fi_tokens = flashinfer_sample(batch.contiguous(), k_batch, p_batch, {})
fi_counts = torch.bincount(fi_tokens, minlength=self.VOCAB).cpu().numpy()
self._chi2_check(
fi_counts,
expected_counts,
chisquare,
label=f"flashinfer top-k={topk} top-p={topp}",
)

# PyTorch-native dispatch path (Triton-routed filter + Gumbel sample).
processed = apply_top_k_top_p(batch.clone(), k_batch, p_batch)
probs = processed.softmax(dim=-1, dtype=torch.float32)
pt_tokens = random_sample(probs, {})
pt_counts = torch.bincount(pt_tokens, minlength=self.VOCAB).cpu().numpy()
self._chi2_check(
pt_counts,
expected_counts,
chisquare,
label=f"native top-k={topk} top-p={topp}",
)

def _chi2_check(self, empirical, expected, chisquare_fn, *, label):
import numpy as np

# Hard check: the sampler must never produce a token outside the
# expected support (zero theoretical probability).
outside = (expected == 0) & (empirical > 0)
assert not outside.any(), (
f"{label}: sampled out-of-support tokens "
f"(zero expected prob): indices={outside.nonzero()[0].tolist()}"
)
# Skip chi-square in the degenerate case where the support
# collapses to a single token (e.g. very restrictive joint
# top-k + top-p): all samples must land there and the hard
# check above already verified they do.
in_support = expected > 0
if int(in_support.sum()) <= 1:
return
# Soft check: chi-square goodness-of-fit on in-support tokens.
# Cast to float64 so the rescaling step below stays within
# scipy.chisquare's strict 1.5e-8 sum-equality tolerance.
emp = empirical[in_support].astype(np.float64)
exp = expected[in_support].astype(np.float64)
exp = exp * (emp.sum() / exp.sum())
chi2, p_value = chisquare_fn(emp, exp)
assert p_value > self.ALPHA, (
f"{label}: distribution differs from theoretical: "
f"chi2={chi2:.2f} p_value={p_value:.2e} alpha={self.ALPHA}"
)
8 changes: 3 additions & 5 deletions vllm/envs.py
@@ -45,7 +45,7 @@
NO_COLOR: bool = False
VLLM_LOG_STATS_INTERVAL: float = 10.0
VLLM_TRACE_FUNCTION: int = 0
VLLM_USE_FLASHINFER_SAMPLER: bool = True
VLLM_USE_FLASHINFER_SAMPLER: bool | None = None
VLLM_PP_LAYER_PARTITION: str | None = None
VLLM_CPU_KVCACHE_SPACE: int | None = 0
VLLM_CPU_OMP_THREADS_BIND: str = "auto"
@@ -712,13 +712,11 @@ def _get_or_set_default() -> str:
# If set to 1, vllm will trace function calls
# Useful for debugging
"VLLM_TRACE_FUNCTION": lambda: int(os.getenv("VLLM_TRACE_FUNCTION", "0")),
# Whether to use the FlashInfer top-k / top-p sampler on CUDA. Enabled
# by default when the hardware supports it — set to 0 to opt out
# explicitly, which forces the PyTorch-native (Triton for bs>=8) path.
# If set, vllm will use flashinfer sampler
"VLLM_USE_FLASHINFER_SAMPLER": lambda: (
bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))
if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ
else True
else None
),
# Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION": lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),