8 changes: 4 additions & 4 deletions requirements-tpu.txt
@@ -1,6 +1,5 @@
# Common dependencies
-r requirements-common.txt

# Dependencies for TPU
cmake>=3.26
ninja
@@ -9,15 +8,16 @@ setuptools-scm>=8
wheel
jinja2
ray[default]

# Install torch_xla
--pre
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
--find-links https://storage.googleapis.com/libtpu-wheels/index.html
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch==2.6.0.dev20241216+cpu
torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
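A hypothetical post-install sanity check (not part of this change), assuming the pinned nightlies above were installed on a TPU VM with pip install -r requirements-tpu.txt:

# Hypothetical check that the pinned nightly wheels resolved as expected.
import torch
import torch_xla

assert torch.__version__.startswith("2.6.0.dev20241216"), torch.__version__
assert torch_xla.__version__.startswith("2.7.0.dev20250124"), torch_xla.__version__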
Empty file added tests/v1/tpu/__init__.py
Empty file.
63 changes: 63 additions & 0 deletions tests/v1/tpu/test_sampler.py
@@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
from time import time

import pytest

from vllm import LLM, envs
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams

if not envs.VLLM_USE_V1:
pytest.skip(
"Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
allow_module_level=True,
)


@pytest.mark.parametrize("model_name", ["D4nt3/Qwen2.5-two-layers"])
@pytest.mark.skipif(not current_platform.is_tpu(),
reason="This test needs a TPU")
def test_sampler_compilation(model_name: str):
"""
Check that no recompilation happens despite changing sampling parameters.
We can't read XLA metrics from the engine process, hence we measure time.
"""
    # Compiling model init may still take time; use enforce_eager to skip it.
llm = LLM(model_name,
enforce_eager=True,
max_num_seqs=16,
max_model_len=1024,
gpu_memory_utilization=0.5)
prompts = [
"A robot may not injure a human being",
"It is only with the heart that one can see rightly;",
]
# First inference should be slow
sampling_params = SamplingParams(
temperature=0.7,
# top_p=0.6, # too slow!
# top_k=10,
min_p=0.2,
max_tokens=16)
s = time()
_ = llm.generate(prompts, sampling_params)
run1 = time() - s

    # Second request with different parameter values, which were already
    # compiled for in the previous eager iteration.
sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=24)
s = time()
_ = llm.generate(prompts, sampling_params)
run2 = time() - s

# much faster after compiling
assert run1 * 0.1 > run2

    # Third request with min_p left unset. It will not trigger recompilation,
    # since the default value of 0 is used.
sampling_params = SamplingParams(max_tokens=24, temperature=1.0)
s = time()
_ = llm.generate(prompts, sampling_params)
run3 = time() - s

assert run1 * 0.1 > run3
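As the docstring notes, the test falls back to wall-clock timing because XLA metrics are not readable from the engine process. For reference, a rough in-process sketch (assuming torch_xla is importable in the same process as the model) that would count compilations directly:

import torch_xla.debug.metrics as met

def xla_compile_count() -> int:
    # metric_data returns None until the first compilation has been recorded.
    data = met.metric_data("CompileTime")
    return data[0] if data else 0

compiles_before = xla_compile_count()
# ... run llm.generate(prompts, sampling_params) here ...
compiles_after = xla_compile_count()
assert compiles_after == compiles_before, "sampling params triggered recompilation"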
16 changes: 14 additions & 2 deletions vllm/v1/sample/ops/topk_topp_sampler.py
@@ -22,7 +22,7 @@ class TopKTopPSampler(nn.Module):

def __init__(self):
super().__init__()
if current_platform.is_cuda:
if current_platform.is_cuda():
if is_flashinfer_available:
if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
# NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
@@ -48,6 +48,8 @@ def __init__(self):
"native implementation of top-p & top-k sampling. For the "
"best performance, please install FlashInfer.")
self.forward = self.forward_native
elif current_platform.is_tpu():
self.forward = self.forward_tpu
else:
self.forward = self.forward_native

@@ -79,6 +81,16 @@ def forward_cuda(
return random_sample(probs, generators)
return flashinfer_sample(probs, k, p, generators)

def forward_tpu(
self,
logits: torch.Tensor,
generators: Dict[int, torch.Generator],
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
# TODO Placeholder for TPU optimized topk/p kernel
return self.forward_native(logits, generators, k, p)


def apply_top_k_top_p(
logits: torch.Tensor,
@@ -95,7 +107,7 @@ def apply_top_k_top_p(

if k is not None:
# Apply top-k.
top_k_mask = logits_sort.size(1) - k.to(torch.long)
top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B
# Get all the top_k values.
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
top_k_mask = logits_sort < top_k_mask
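A standalone illustration (not part of this change) of the batched top-k masking above, assuming logits_sort is sorted ascending per row as in the surrounding function, so index vocab_size - k picks each row's k-th largest value and everything strictly below it gets masked:

import torch

logits = torch.tensor([[0.1, 2.0, -1.0, 0.5]])
k = torch.tensor([2])
logits_sort, _ = logits.sort(dim=-1, descending=False)  # ascending per row
top_k_mask = logits_sort.size(1) - k.to(torch.long)  # index of k-th largest value
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
top_k_mask = logits_sort < top_k_mask
logits_sort.masked_fill_(top_k_mask, -float("inf"))
# logits_sort is now [-inf, -inf, 0.5, 2.0]: only the top-2 values survive.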
2 changes: 1 addition & 1 deletion vllm/v1/sample/rejection_sampler.py
@@ -23,7 +23,7 @@ class RejectionSampler(nn.Module):

def __init__(self):
super().__init__()
if current_platform.is_cuda:
if current_platform.is_cuda():
if is_flashinfer_available:
if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
# NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
4 changes: 2 additions & 2 deletions vllm/v1/sample/sampler.py
@@ -213,8 +213,8 @@ def apply_min_p(
adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
# Identify valid tokens using threshold comparison
valid_token_mask = probability_values >= adjusted_min_p
# Apply mask using boolean indexing
logits[~valid_token_mask] = -float('inf')
# Apply mask using boolean indexing (xla friendly)
logits.masked_fill_(~valid_token_mask, -float("inf"))
return logits

def apply_logits_bias(
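For context, a self-contained sketch (not part of this change) of the min-p filter using masked_fill_: the output keeps the input's static shape, which is what makes it friendlier to XLA than boolean indexing with a dynamic number of masked elements:

import torch

logits = torch.tensor([[2.0, 0.0, -1.0, 1.0]])
min_p = torch.tensor([0.2])
probs = torch.softmax(logits, dim=-1)
threshold = min_p.unsqueeze(1) * probs.amax(dim=-1, keepdim=True)
filtered = logits.masked_fill(probs < threshold, -float("inf"))
# Tokens whose probability falls below 0.2 * max_prob are dropped, while the
# tensor shape stays fixed, so XLA compiles one graph regardless of how many
# tokens are filtered.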