diff --git a/requirements-tpu.txt b/requirements-tpu.txt
index 1abde714af7c..7c5aadb8dd79 100644
--- a/requirements-tpu.txt
+++ b/requirements-tpu.txt
@@ -1,6 +1,5 @@
 # Common dependencies
 -r requirements-common.txt
-
 # Dependencies for TPU
 cmake>=3.26
 ninja
@@ -9,7 +8,6 @@ setuptools-scm>=8
 wheel
 jinja2
 ray[default]
-
 # Install torch_xla
 --pre
 --extra-index-url https://download.pytorch.org/whl/nightly/cpu
@@ -17,7 +15,9 @@ ray[default]
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch==2.6.0.dev20241216+cpu
+torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
 torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
 torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
\ No newline at end of file
diff --git a/tests/v1/tpu/__init__.py b/tests/v1/tpu/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/v1/tpu/test_sampler.py b/tests/v1/tpu/test_sampler.py
new file mode 100644
index 000000000000..6e0f65bf51b0
--- /dev/null
+++ b/tests/v1/tpu/test_sampler.py
@@ -0,0 +1,63 @@
+# SPDX-License-Identifier: Apache-2.0
+from time import time
+
+import pytest
+
+from vllm import LLM, envs
+from vllm.platforms import current_platform
+from vllm.sampling_params import SamplingParams
+
+if not envs.VLLM_USE_V1:
+    pytest.skip(
+        "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
+        allow_module_level=True,
+    )
+
+
+@pytest.mark.parametrize("model_name", ["D4nt3/Qwen2.5-two-layers"])
+@pytest.mark.skipif(not current_platform.is_tpu(),
+                    reason="This test needs a TPU")
+def test_sampler_compilation(model_name: str):
+    """
+    Check that no recompilation happens despite changing sampling parameters.
+    We can't read XLA metrics from the engine process, hence we measure time.
+    """
+    # Model init compilation may still take a while; enforce_eager skips it.
+    llm = LLM(model_name,
+              enforce_eager=True,
+              max_num_seqs=16,
+              max_model_len=1024,
+              gpu_memory_utilization=0.5)
+    prompts = [
+        "A robot may not injure a human being",
+        "It is only with the heart that one can see rightly;",
+    ]
+    # First inference should be slow
+    sampling_params = SamplingParams(
+        temperature=0.7,
+        # top_p=0.6, # too slow!
+        # top_k=10,
+        min_p=0.2,
+        max_tokens=16)
+    s = time()
+    _ = llm.generate(prompts, sampling_params)
+    run1 = time() - s
+
+    # Second request with different params that were already compiled for
+    # in the previous eager iteration.
+    sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=24)
+    s = time()
+    _ = llm.generate(prompts, sampling_params)
+    run2 = time() - s
+
+    # much faster after compiling
+    assert run1 * 0.1 > run2
+
+    # Third request with min_p left unset. It will not trigger recompilation
+    # as the default value of 0 will be used.
+    sampling_params = SamplingParams(max_tokens=24, temperature=1.0)
+    s = time()
+    _ = llm.generate(prompts, sampling_params)
+    run3 = time() - s
+
+    assert run1 * 0.1 > run3
diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py
index 78c88ad8b830..5e1d3ff05c11 100644
--- a/vllm/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm/v1/sample/ops/topk_topp_sampler.py
@@ -22,7 +22,7 @@ class TopKTopPSampler(nn.Module):
 
     def __init__(self):
         super().__init__()
-        if current_platform.is_cuda:
+        if current_platform.is_cuda():
             if is_flashinfer_available:
                 if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
                     # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
@@ -48,6 +48,8 @@ def __init__(self):
                     "native implementation of top-p & top-k sampling. For the "
                     "best performance, please install FlashInfer.")
                 self.forward = self.forward_native
+        elif current_platform.is_tpu():
+            self.forward = self.forward_tpu
         else:
             self.forward = self.forward_native
 
@@ -79,6 +81,16 @@ def forward_cuda(
             return random_sample(probs, generators)
         return flashinfer_sample(probs, k, p, generators)
 
+    def forward_tpu(
+        self,
+        logits: torch.Tensor,
+        generators: Dict[int, torch.Generator],
+        k: Optional[torch.Tensor],
+        p: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # TODO Placeholder for a TPU-optimized top-k/top-p kernel
+        return self.forward_native(logits, generators, k, p)
+
 
 def apply_top_k_top_p(
     logits: torch.Tensor,
@@ -95,7 +107,7 @@ def apply_top_k_top_p(
 
     if k is not None:
         # Apply top-k.
-        top_k_mask = logits_sort.size(1) - k.to(torch.long)
+        top_k_mask = logits_sort.size(1) - k.to(torch.long)  # shape: B
         # Get all the top_k values.
         top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
         top_k_mask = logits_sort < top_k_mask
diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index 580ad44297aa..f3c5268aca09 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -23,7 +23,7 @@ class RejectionSampler(nn.Module):
 
     def __init__(self):
         super().__init__()
-        if current_platform.is_cuda:
+        if current_platform.is_cuda():
             if is_flashinfer_available:
                 if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
                     # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py
index 47ec26d42024..e697a3e2fa5f 100644
--- a/vllm/v1/sample/sampler.py
+++ b/vllm/v1/sample/sampler.py
@@ -213,8 +213,8 @@ def apply_min_p(
         adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
         # Identify valid tokens using threshold comparison
        valid_token_mask = probability_values >= adjusted_min_p
-        # Apply mask using boolean indexing
-        logits[~valid_token_mask] = -float('inf')
+        # Apply mask with masked_fill_ (XLA-friendly, no boolean indexing)
+        logits.masked_fill_(~valid_token_mask, -float("inf"))
         return logits
 
     def apply_logits_bias(
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index f7d72d26e045..e54643975c4a 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 import enum
 import time
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 from unittest.mock import patch
 
@@ -26,6 +26,8 @@
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
+from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.sampler import Sampler
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
@@ -69,6 +71,150 @@ class DecodeData:
     attn_metadata: Optional[PallasMetadata] = None
 
 
+# TODO (NickLucche) keep in sync with SamplingMetadata until this class can
+# be dropped, once most sampling options are supported.
+@dataclass
+class TPUSupportedSamplingMetadata:
+    # This class exposes a more XLA-friendly interface; in particular,
+    # all arguments must be traceable and no Optionals are allowed,
+    # to avoid graph recompilation on None values.
+    temperature: torch.Tensor
+
+    min_p: torch.Tensor
+    # Still too slow on forward_native!
+    top_k: torch.Tensor = None
+    top_p: torch.Tensor = None
+
+    # XLA-unfriendly control flow
+    all_greedy: bool = False
+    all_random: bool = False
+
+    # Speculative decoding not supported
+    spec_token_ids = None
+
+    # Generators not supported by XLA
+    generators: Dict[int,
+                     torch.Generator] = field(default_factory=lambda: dict())
+
+    # Unsupported; would require an extra output tensor of static size BxV
+    max_num_logprobs = None
+
+    # TODO No penalties for now
+    no_penalties: bool = True
+    prompt_token_ids = None
+    frequency_penalties = None
+    presence_penalties = None
+    repetition_penalties = None
+    # TODO should use a tensor here as well
+    output_token_ids: List[List[int]] = field(default_factory=lambda: list())
+
+    min_tokens = None  # impl is not vectorized
+
+    logit_bias: List[Optional[Dict[int, float]]] = field(
+        default_factory=lambda: list())
+
+    allowed_token_ids_mask = None
+
+    @classmethod
+    def from_sampling_metadata(
+            cls, metadata: SamplingMetadata, batch_size: int,
+            device: torch.device) -> "TPUSupportedSamplingMetadata":
+        metadata = cls._validate_sampling_metadata(metadata)
+        # NOTE we have to initialize default tensor-based params first and
+        # skip None values altogether to always produce the same XLA graph.
+        new_metadata = cls.get_default_sampling_params(batch_size, device)
+
+        supported_params = \
+            TPUSupportedSamplingMetadata._get_default_params_values()
+        # Copy `metadata` non-None values into `new_metadata`, while
+        # broadcasting tensor params to match sequence batch padding.
+        for p_name in supported_params:
+            old_val = getattr(metadata, p_name)
+            new_val = getattr(new_metadata, p_name)
+            # Branching in pre-processing will trigger re-compilation.
+            if isinstance(old_val,
+                          torch.Tensor) and old_val.numel() != batch_size:
+                # TODO not efficient; should keep a tensor of compiled size B
+                # Handle padded batch.
+                new_val[:old_val.shape[0]] = old_val
+            elif isinstance(old_val, torch.Tensor):
+                # Either a single value broadcast to the whole batch, or
+                # already B values; sizes match, so copy directly.
+                new_val[:] = old_val
+            setattr(new_metadata, p_name, new_val)
+
+        xm.mark_step()
+        xm.wait_device_ops()
+        return new_metadata
+
+    @classmethod
+    def from_single_prefill_metadata(
+            cls, metadata: SamplingMetadata, prefill_idx: int,
+            device: torch.device) -> "TPUSupportedSamplingMetadata":
+        # TODO temporary constructor until the ragged kernel covers B=1.
+        metadata = cls._validate_sampling_metadata(metadata)
+        new_metadata = cls.get_default_sampling_params(1, device)
+
+        supported_params = \
+            TPUSupportedSamplingMetadata._get_default_params_values()
+        # Copy `metadata` non-None values into `new_metadata`.
+        for p_name in supported_params:
+            old_val = getattr(metadata, p_name)
+            new_val = getattr(new_metadata, p_name)
+            if isinstance(old_val, torch.Tensor) and old_val.numel() > 1:
+                # Select the right prefill metadata.
+                new_val[:] = old_val[prefill_idx]
+            elif isinstance(old_val, torch.Tensor):
+                # num_prefills==1 or a single param value for the whole batch
+                new_val[:] = old_val
+            setattr(new_metadata, p_name, new_val)
+        xm.mark_step()
+        xm.wait_device_ops()
+        return new_metadata
+
+    @classmethod
+    def get_default_sampling_params(
+            cls, batch_size: int,
+            device: torch.device) -> "TPUSupportedSamplingMetadata":
+        # As sampling happens in a single traced function, options
+        # are "disabled" by having them evaluate to an identity op.
+        # Note that initialization depends on batch_size.
+        sampling_metadata_disable_value = \
+            TPUSupportedSamplingMetadata._get_default_params_values()
+        kwargs = dict()
+        for p_name, default_val in sampling_metadata_disable_value.items():
+            default_tensor = torch.full((batch_size, ),
+                                        default_val,
+                                        device=device)
+            kwargs[p_name] = default_tensor
+
+        return cls(**kwargs)
+
+    @staticmethod
+    def _validate_sampling_metadata(
+            sampling_metadata: SamplingMetadata) -> SamplingMetadata:
+        if sampling_metadata.all_greedy:
+            # Greedy sampling is still performed as long as temperature is 0,
+            # but the control flow must remain constant for tracing.
+            sampling_metadata.all_greedy = False
+            # TODO isn't this already checked somewhere else?
+            assert torch.count_nonzero(sampling_metadata.temperature) == 0
+        return sampling_metadata
+
+    @staticmethod
+    def _get_default_params_values():
+        return dict(
+            temperature=0.0,
+            min_p=0.0,
+            # strictly disabled for now
+            # top_k=-1,
+            # top_p=0.0,
+            # frequency_penalties=0.0,
+            # presence_penalties=0.0,
+            # repetition_penalties=0.0,
+        )
+
+
 class TPUModelRunner:
 
     def __init__(
@@ -113,7 +259,7 @@ def __init__(
         self.head_size = model_config.get_head_size()
         self.hidden_size = model_config.get_hidden_size()
 
-        self.model: Optional[nn.Module] = None
+        self.model: Optional[ModelWrapperV1] = None
 
         # Persistent batch.
         self.input_batch = InputBatch(
@@ -284,6 +430,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
                 req_data.new_block_ids)
             self.input_batch.block_table.append_row(req_index, start_index,
                                                     req_data.new_block_ids)
+        # Check if the batch has changed. If not, we can skip copying the
+        # sampling metadata from CPU to TPU.
+        batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
 
         # Add the new or resumed requests to the persistent batch.
         # The smaller empty indices are filled first.
@@ -301,6 +450,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         # Condense the batched states if there are empty indices.
         if removed_req_indices:
             self.input_batch.condense(removed_req_indices)
+
+        if batch_changed:
+            self.input_batch.refresh_sampling_metadata()
         return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
 
     def swap_step(self):
@@ -597,6 +749,8 @@ def execute_model(
         num_decodes = len(pd_info.decode_req_ids)
         decode_data = None
         sampled_token_ids = [0] * self.input_batch.num_reqs
+        sampling_metadata = self.input_batch.get_sampling_metadata(
+            scheduler_output.scheduled_spec_decode_tokens)
 
         # Run each prompt individually
         is_first = True
@@ -609,6 +763,10 @@ def execute_model(
             num_scheduled_tokens = pd_info.prompt_scheduled_tokens[i]
             prompt_len = num_scheduled_tokens
             seq_len = req_state.num_computed_tokens + num_scheduled_tokens
+            # Select the portion of the sampling params corresponding to req i
+            prefill_sampling_meta = TPUSupportedSamplingMetadata\
+                .from_single_prefill_metadata(sampling_metadata, req_index,
+                                              self.device)
 
             # Prepare first prompt
             if is_first:
@@ -622,7 +780,8 @@ def execute_model(
             assert self.model is not None
             selected_token_ids = self.model(prompt_data.input_tokens,
                                             prompt_data.input_positions,
-                                            self.kv_caches)
+                                            self.kv_caches,
+                                            prefill_sampling_meta)
 
             # In parallel to TPU execution, prepare the next iteration
             if i < num_prompts - 1:
@@ -649,7 +808,10 @@ def execute_model(
 
         # Run decodes (a single batch)
         if num_decodes > 0:
-
+            decode_sampling_meta = TPUSupportedSamplingMetadata.\
+                from_sampling_metadata(sampling_metadata,
+                                       _get_padded_batch_size(num_decodes),
+                                       self.device)
             # Prepare decode (if was not yet prepared)
             if decode_data is None:
                 decode_data = self._prepare_decode(pd_info.decode_req_ids)
@@ -660,8 +822,8 @@ def execute_model(
             assert self.model is not None
             selected_token_ids = self.model(decode_data.input_tokens,
                                             decode_data.input_positions,
-                                            self.kv_caches)
-
+                                            self.kv_caches,
+                                            decode_sampling_meta)
             # Transfer sampled tokens from TPU to CPU
             decode_token_ids_cpu = selected_token_ids.cpu()
             # Convert to list
@@ -722,6 +884,7 @@ def load_model(self) -> None:
         xm.mark_step()
         xm.wait_device_ops()
         model = ModelWrapperV1(model)
+        self.model = model
         self.model = torch.compile(model,
                                    backend="openxla",
                                    fullgraph=True,
@@ -828,15 +991,16 @@ def dummy_run(
             torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
         else:
             # Decode
-            torch._dynamo.mark_dynamic(token_ids, 0)
-            torch._dynamo.mark_dynamic(position_ids, 0)
             torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
             torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
             torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
+        # To allow sampling, trace the forward with all supported sampling args
+        sampling_meta = TPUSupportedSamplingMetadata.\
+            get_default_sampling_params(num_tokens, self.device)
 
         with set_forward_context(attn_metadata, self.vllm_config, 0):
-            assert self.model is not None
-            self.model(token_ids, position_ids, kv_caches)
+            assert self.model
+            self.model(token_ids, position_ids, kv_caches, sampling_meta)
 
     def capture_model(self) -> None:
         """Compile the model."""
@@ -865,7 +1029,7 @@ def capture_model(self) -> None:
                         end - start)
 
         # Prefix prefill
-        if self.scheduler_config.enable_chunked_prefill:
+        if self.cache_config.enable_prefix_caching:
             logger.info("Compiling the model with different input shapes for "
                         "prefix prefill:")
             start = time.time()
@@ -955,22 +1119,27 @@ class ModelWrapperV1(nn.Module):
     def __init__(self, model: nn.Module):
         super().__init__()
         self.model = model
+        self.sampler = Sampler()
+
+    def sample(
+        self, logits: torch.Tensor,
+        sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
+        sampler_out = self.sampler(logits, sampling_metadata)
+        sampled_token_ids = sampler_out.sampled_token_ids
+        return sampled_token_ids.squeeze(dim=-1)
 
     def forward(
         self,
         token_ids: torch.Tensor,
         position_ids: torch.Tensor,
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
+        sampling_metadata: TPUSupportedSamplingMetadata,
     ) -> torch.Tensor:
         """Executes the forward pass of the model and samples the next token.
 
         Args:
             token_ids: The input token IDs of shape [batch_size, seq_len].
             position_ids: The input position IDs of shape [batch_size,
                 seq_len].
-            input_lens: The actual input lengths of shape [batch_size].
-            t: The sampling temperature of shape [batch_size].
-            p: The top-p probability of shape [batch_size].
-            num_samples: Number of samples to draw from each logits vector.
             kv_caches: The key and value caches. They can be None during the
                 memory profiling at initialization.
         """
@@ -996,16 +1165,11 @@ def forward(
             slot_mapping = slot_mapping.flatten()
             attn_metadata.slot_mapping = slot_mapping
 
-        assert self.model is not None
         hidden_states = self.model(token_ids, position_ids)
         hidden_states = hidden_states.flatten(0, 1)
         logits = self.model.compute_logits(hidden_states, None)
-
-        # Greedy sampling.
-        argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
-        argmax_token_ids = argmax_token_ids.squeeze(dim=-1)
-        return argmax_token_ids
+        return self.sample(logits, sampling_metadata)
 
 
 def swap_positions(b: InputBatch, id_1, id_2):
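Note on the sampling-metadata handling added in tpu_model_runner.py: from_sampling_metadata boils down to copying each supported parameter into a fixed-shape tensor that matches the padded (compiled) batch size, so the traced sampling graph never changes shape between steps. Below is a minimal, standalone sketch of that copy/broadcast step, using plain CPU tensors and a hypothetical pad_param helper name; the real code additionally issues xm.mark_step()/xm.wait_device_ops() on TPU.

import torch


def pad_param(values: torch.Tensor, padded_batch_size: int,
              disable_value: float) -> torch.Tensor:
    # Start from the "disabled" default so padded / unset slots are no-ops.
    out = torch.full((padded_batch_size, ), disable_value)
    if values.numel() == padded_batch_size:
        out[:] = values  # already matches the compiled batch size
    else:
        out[:values.shape[0]] = values  # pad the tail with the default
    return out


# Example: 3 live decode requests, graph compiled for a padded batch of 8.
temperature = pad_param(torch.tensor([0.7, 0.1, 1.0]), 8, 0.0)
min_p = pad_param(torch.tensor([0.2, 0.8, 0.0]), 8, 0.0)
print(temperature)  # tensor([0.7000, 0.1000, 1.0000, 0.0000, ..., 0.0000])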