From 566398e44d1331bde96749bcc2563e1a954e4b71 Mon Sep 17 00:00:00 2001 From: Carol Zheng Date: Mon, 12 May 2025 12:06:59 -0700 Subject: [PATCH 01/15] [CI/Build] Fix TPU V1 Test mixed use of & and && across tests (#17968) Signed-off-by: Siyuan Liu --- .../scripts/hardware_ci/run-tpu-v1-test.sh | 42 +++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh index 939daddad92b..2d375d7e9d87 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh @@ -26,27 +26,27 @@ docker run --privileged --net host --shm-size=16G -it \ && tpu-info \ && { \ echo TEST_0: Running test_perf.py; \ - pytest -s -v /workspace/vllm/tests/tpu/test_perf.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_perf.py; \ echo TEST_0_EXIT_CODE: \$?; \ } & \ - && { \ + { \ echo TEST_1: Running test_compilation.py; \ - pytest -s -v /workspace/vllm/tests/tpu/test_compilation.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_compilation.py; \ echo TEST_1_EXIT_CODE: \$?; \ } & \ { \ echo TEST_2: Running test_basic.py; \ - pytest -s -v /workspace/vllm/tests/v1/tpu/test_basic.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_basic.py; \ echo TEST_2_EXIT_CODE: \$?; \ } & \ { \ echo TEST_3: Running test_accuracy.py::test_lm_eval_accuracy_v1_engine; \ - pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine; \ + python3 -m pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine; \ echo TEST_3_EXIT_CODE: \$?; \ } & \ { \ echo TEST_4: Running test_quantization_accuracy.py; \ - pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py; \ echo TEST_4_EXIT_CODE: \$?; \ } & \ { \ @@ -56,43 +56,43 @@ docker 
run --privileged --net host --shm-size=16G -it \ } & \ { \ echo TEST_6: Running test_tpu_model_runner.py; \ - pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py; \ echo TEST_6_EXIT_CODE: \$?; \ } & \ - && { \ + { \ echo TEST_7: Running test_sampler.py; \ - pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py; \ echo TEST_7_EXIT_CODE: \$?; \ } & \ - && { \ + { \ echo TEST_8: Running test_topk_topp_sampler.py; \ - pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py; \ echo TEST_8_EXIT_CODE: \$?; \ } & \ - && { \ + { \ echo TEST_9: Running test_multimodal.py; \ - pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py; \ echo TEST_9_EXIT_CODE: \$?; \ } & \ - && { \ + { \ echo TEST_10: Running test_pallas.py; \ - pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py; \ echo TEST_10_EXIT_CODE: \$?; \ } & \ - && { \ + { \ echo TEST_11: Running test_struct_output_generate.py; \ - pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py; \ echo TEST_11_EXIT_CODE: \$?; \ } & \ - && { \ + { \ echo TEST_12: Running test_moe_pallas.py; \ - pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py; \ echo TEST_12_EXIT_CODE: \$?; \ } & \ # Disable the TPU LoRA tests until the feature is activated - # && { \ + # & { \ # echo TEST_13: Running test_moe_pallas.py; \ - # pytest -s -v /workspace/vllm/tests/tpu/lora/; \ + # python3 -m pytest -s -v 
/workspace/vllm/tests/tpu/lora/; \ # echo TEST_13_EXIT_CODE: \$?; \ # } & \ wait \ From 9bff53d41369cd60177a75bc6069d985b1f42cbc Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Tue, 13 May 2025 03:09:16 +0800 Subject: [PATCH 02/15] [Core] Use platform-agnostic device control for DP engine core (#17245) Signed-off-by: Jade Zheng Signed-off-by: Siyuan Liu --- vllm/platforms/cuda.py | 26 ++++---------------------- vllm/platforms/interface.py | 19 +++++++++++++++++++ vllm/platforms/rocm.py | 11 +---------- vllm/v1/engine/core.py | 13 ++++++------- 4 files changed, 30 insertions(+), 39 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index f116285870ec..dd3a54f7daf2 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -34,24 +34,6 @@ torch.backends.cuda.enable_cudnn_sdp(False) -def device_id_to_physical_device_id(device_id: int) -> int: - if "CUDA_VISIBLE_DEVICES" in os.environ: - device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") - if device_ids == [""]: - msg = ( - "CUDA_VISIBLE_DEVICES is set to empty string, which means" - " GPU support is disabled. If you are using ray, please unset" - " the environment variable `CUDA_VISIBLE_DEVICES` inside the" - " worker/actor. 
" - "Check https://github.com/vllm-project/vllm/issues/8402 for" - " more information.") - raise RuntimeError(msg) - physical_device_id = device_ids[device_id] - return int(physical_device_id) - else: - return device_id - - def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: @wraps(fn) @@ -338,7 +320,7 @@ def get_device_capability(cls, device_id: int = 0 ) -> Optional[DeviceCapability]: try: - physical_device_id = device_id_to_physical_device_id(device_id) + physical_device_id = cls.device_id_to_physical_device_id(device_id) handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id) major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle) return DeviceCapability(major=major, minor=minor) @@ -360,20 +342,20 @@ def has_device_capability( @classmethod @with_nvml_context def get_device_name(cls, device_id: int = 0) -> str: - physical_device_id = device_id_to_physical_device_id(device_id) + physical_device_id = cls.device_id_to_physical_device_id(device_id) return cls._get_physical_device_name(physical_device_id) @classmethod @with_nvml_context def get_device_uuid(cls, device_id: int = 0) -> str: - physical_device_id = device_id_to_physical_device_id(device_id) + physical_device_id = cls.device_id_to_physical_device_id(device_id) handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id) return pynvml.nvmlDeviceGetUUID(handle) @classmethod @with_nvml_context def get_device_total_memory(cls, device_id: int = 0) -> int: - physical_device_id = device_id_to_physical_device_id(device_id) + physical_device_id = cls.device_id_to_physical_device_id(device_id) handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id) return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 68b90796ece2..a0c9e2ae374d 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import enum +import os import platform 
import random from platform import uname @@ -161,6 +162,24 @@ def is_cuda_alike(self) -> bool: def is_sleep_mode_available(self) -> bool: return self._enum == PlatformEnum.CUDA + @classmethod + def device_id_to_physical_device_id(cls, device_id: int): + if cls.device_control_env_var in os.environ: + device_ids = os.environ[cls.device_control_env_var].split(",") + if device_ids == [""]: + msg = (f"{cls.device_control_env_var} is set to empty string, " + "which means current platform support is disabled. If " + "you are using ray, please unset the environment " + f"variable `{cls.device_control_env_var}` inside the " + "worker/actor. Check " + "https://github.com/vllm-project/vllm/issues/8402 for " + "more information.") + raise RuntimeError(msg) + physical_device_id = device_ids[device_id] + return int(physical_device_id) + else: + return device_id + @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index ea028e13fc4d..f3d64f01b0f7 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -95,15 +95,6 @@ def wrapper(*args, **kwargs): return wrapper -def device_id_to_physical_device_id(device_id: int) -> int: - if "CUDA_VISIBLE_DEVICES" in os.environ: - device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") - physical_device_id = device_ids[device_id] - return int(physical_device_id) - else: - return device_id - - @cache def on_mi250_mi300() -> bool: GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName @@ -238,7 +229,7 @@ def is_fully_connected(physical_device_ids: List[int]) -> bool: @with_amdsmi_context @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: - physical_device_id = device_id_to_physical_device_id(device_id) + physical_device_id = cls.device_id_to_physical_device_id(device_id) handle = amdsmi_get_processor_handles()[physical_device_id] asic_info = 
amdsmi_get_gpu_asic_info(handle) device_name: str = asic_info["device_id"] diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index c1aa0ce27d3f..fde60bbfa51f 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -622,13 +622,12 @@ def __init__( assert 0 <= local_dp_rank <= dp_rank < dp_size from vllm.platforms import current_platform - if current_platform.is_cuda_alike(): - from vllm.platforms.cuda import device_id_to_physical_device_id - tp_size = vllm_config.parallel_config.tensor_parallel_size - os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( - str(device_id_to_physical_device_id(i)) - for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) * - tp_size)) + device_control_env_var = current_platform.device_control_env_var + tp_size = vllm_config.parallel_config.tensor_parallel_size + os.environ[device_control_env_var] = ",".join( + str(current_platform.device_id_to_physical_device_id(i)) + for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) * + tp_size)) self.local_dp_rank = local_dp_rank self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() From df3a9d3bb92a2d9123c329d4de2093c5d131d522 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Mon, 12 May 2025 23:09:35 +0000 Subject: [PATCH 03/15] delete runner, import from tpu commons Signed-off-by: Siyuan Liu --- examples/offline_inference/tpu.py | 17 +- vllm/v1/worker/tpu_model_runner.py | 1501 ---------------------------- vllm/v1/worker/tpu_worker.py | 4 +- 3 files changed, 16 insertions(+), 1506 deletions(-) delete mode 100644 vllm/v1/worker/tpu_model_runner.py diff --git a/examples/offline_inference/tpu.py b/examples/offline_inference/tpu.py index 71cd88f2788a..5433db6df575 100644 --- a/examples/offline_inference/tpu.py +++ b/examples/offline_inference/tpu.py @@ -1,7 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 +import sys + from vllm import LLM, SamplingParams +# Using absolute path to import from tpu_commons in home directory +# Change to yours 
+sys.path.append('/home/lsiyuan') + prompts = [ "A robot may not injure a human being", "It is only with the heart that one can see rightly;", @@ -20,10 +26,13 @@ def main(): # Set `enforce_eager=True` to avoid ahead-of-time compilation. # In real workloads, `enforace_eager` should be `False`. - llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", - max_num_batched_tokens=64, - max_num_seqs=4, - max_model_len=128) + llm = LLM( + model="Qwen/Qwen2-1.5B-Instruct", + max_num_batched_tokens=64, + max_num_seqs=4, + max_model_len=128, + enforce_eager=True, + ) outputs = llm.generate(prompts, sampling_params) print("-" * 50) for output, answer in zip(outputs, answers): diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py deleted file mode 100644 index 687dabee2290..000000000000 --- a/vllm/v1/worker/tpu_model_runner.py +++ /dev/null @@ -1,1501 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -import bisect -import gc -import time -from typing import TYPE_CHECKING, Optional, cast -from unittest.mock import patch - -import numpy as np -import torch -import torch.distributed -import torch.nn as nn -# TPU XLA related -import torch_xla.core.xla_model as xm -import torch_xla.runtime as xr - -import vllm.envs as envs -from vllm.attention.backends.abstract import AttentionType -from vllm.attention.layer import Attention -from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import VllmConfig, get_layers_from_vllm_config -from vllm.forward_context import set_forward_context -from vllm.logger import init_logger -from vllm.model_executor.model_loader import get_model -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs, - PlaceholderRange) -from vllm.multimodal.utils import group_mm_inputs_by_modality -from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available -from 
vllm.v1.attention.backends.pallas import (PallasAttentionBackend, - PallasMetadata) -from vllm.v1.core.encoder_cache_manager import compute_encoder_budget -from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, - KVCacheConfig, KVCacheSpec, - SlidingWindowSpec) -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, - ModelRunnerOutput) -from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata -from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler -from vllm.v1.utils import bind_kv_cache -from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch -from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin - -from .utils import sanity_check_mm_encoder_outputs - -if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput - -logger = init_logger(__name__) - -# Here we utilize the behavior that out-of-bound index is ignored. -# FIXME(woosuk): Find a more reliable way to prevent possible bugs. -_PAD_SLOT_ID = 1_000_000_000 -INVALID_TOKEN_ID = -1 -# Smallest output size -MIN_NUM_SEQS = 8 - - -######################################################### -# Ways to avoid recompilation -######################################################### -# -# The model executor has two primary components: -# 1. preparing the model and sampler inputs -# 2. executing the model and sampler. -# The core idea is to avoid any TPU computation during input preparation. For -# better compilation tracking and increased flexibility, the model execution and -# sampler are divided into several distinct components. -# -# Below are the detailed steps: -# -# Step 1 -# It is recommended to avoid TPU operations when preparing the model and sampler -# inputs. CPU tensors can be prepared and transferred to the XLA device using -# cpu_tensor.to(xla_device), which only triggers CPU to TPU transfers and avoids -# compilation. 
-# -# Step 2 -# The TPU execution should be decomposed into subgraphs (4 at the moment): -# 1. the main model -# 2. selecting hidden states for each request -# 3. sampler -# 4. encoder. -# Each subgraph should be decorated in a torch.compile. This is used to make -# sure that we have the same subgraph topology in both dummy_run and -# xecute_model. The results from these subgraphs should either be passed to -# other subgraphs, or transferred from TPU to CPU using xla_tensor.cpu() for -# subsequent processing on the CPU. -# -# Step 3 -# The dummy_run should be comprehensive, ensuring all potential input shapes and -# branch predictions are included as subgraph inputs to facilitate -# pre-compilation. -class TPUModelRunner(LoRAModelRunnerMixin): - - def __init__( - self, - vllm_config: VllmConfig, - device: torch.device, - ): - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = vllm_config.parallel_config - self.scheduler_config = vllm_config.scheduler_config - self.speculative_config = vllm_config.speculative_config - self.prompt_adapter_config = vllm_config.prompt_adapter_config - self.observability_config = vllm_config.observability_config - self.device_config = vllm_config.device_config - - model_config = self.model_config - cache_config = self.cache_config - scheduler_config = self.scheduler_config - parallel_config = self.parallel_config - self.device = device - self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION - - self.enforce_eager = model_config.enforce_eager - - self.num_xla_graphs = 0 - self._update_num_xla_graphs("init") - - self.pin_memory = is_pin_memory_available() - self.dtype = self.model_config.dtype - self._hidden_states_dtype = self.dtype - - self.is_multimodal_model = model_config.is_multimodal_model - self.sliding_window = 
model_config.get_sliding_window() - self.block_size = cache_config.block_size - self.max_model_len = model_config.max_model_len - self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) - # InputBatch needs to work with sampling tensors greater than padding - # to avoid dynamic shapes. Also, avoid suboptimal alignment. - self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS) - self.num_tokens_paddings = _get_token_paddings( - min_token_size=16, - max_token_size=scheduler_config.max_num_batched_tokens, - padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP) - # In case `max_num_tokens < max(num_tokens_paddings)` use the actual - # padded max value to pre-allocate data structures and pre-compile. - self.max_num_tokens = self.num_tokens_paddings[-1] - - # Model-related. - self.num_attn_layers = model_config.get_num_layers_by_block_type( - parallel_config, LayerBlockType.attention) - self.num_query_heads = model_config.get_num_attention_heads( - parallel_config) - self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) - self.head_size = model_config.get_head_size() - self.hidden_size = model_config.get_hidden_size() - self.vocab_size = model_config.get_vocab_size() - - # Multi-modal data support - self.mm_registry = MULTIMODAL_REGISTRY - self.uses_mrope = model_config.uses_mrope - # TODO: Support M-RoPE (e.g, Qwen2-VL) - assert not self.uses_mrope, "TPU does not support M-RoPE yet." - - encoder_compute_budget, encoder_cache_size = compute_encoder_budget( - model_config=model_config, - scheduler_config=scheduler_config, - mm_registry=self.mm_registry, - ) - self.max_num_encoder_input_tokens = encoder_compute_budget - self.encoder_cache_size = encoder_cache_size - - # Lazy initialization - # self.model: nn.Module # Set after load_model - self.kv_caches: list[torch.Tensor] = [] - # req_id -> (input_id -> encoder_output) - self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} - - # Request states. 
- self.requests: dict[str, CachedRequestState] = {} - # Persistent batch. - self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, - max_num_blocks_per_req=self.max_num_blocks_per_req, - max_num_batched_tokens=self.max_num_tokens, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=self.vocab_size, - ) - - # Cached torch/numpy tensor - # The pytorch tensor and numpy array share the same buffer. - # Sometimes the numpy op is faster so we create both. - self.input_ids_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device="cpu") - self.input_ids_np = self.input_ids_cpu.numpy() - - self.positions_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device="cpu") - self.positions_np = self.positions_cpu.numpy() - - self.block_table_cpu = torch.zeros( - (self.max_num_reqs, self.max_num_blocks_per_req), - dtype=self.input_batch.block_table.get_cpu_tensor().dtype, - device="cpu") - - self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - self.query_start_loc_np = self.query_start_loc_cpu.numpy() - - self.seq_lens_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - self.seq_lens_np = self.seq_lens_cpu.numpy() - - # Range tensor with values [0 .. self.max_num_tokens - 1]. 
- # Used to initialize positions / context_lens / seq_lens - # Keep in int64 to avoid overflow with long context - self.arange_np = np.arange(self.max_num_tokens, dtype=np.int64) - self.num_reqs_paddings = _get_req_paddings( - min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs) - - # tensors for structured decoding - self.grammar_bitmask_cpu = torch.zeros( - (self.max_num_reqs, cdiv(self.vocab_size, 32)), - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - self.require_structured_out_cpu = torch.zeros( - (self.max_num_reqs, 1), - dtype=torch.bool, - device="cpu", - pin_memory=self.pin_memory) - self.structured_decode_arange = torch.arange( - 0, 32, device="cpu", pin_memory=self.pin_memory) - - # Get maximum number of mm items per modality (batch size). - self.max_num_mm_items_by_modality = dict() - if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0 - and self.encoder_cache_size > 0): - max_tokens_by_modality_dict = ( - MULTIMODAL_REGISTRY. - get_max_tokens_per_item_by_nonzero_modality(self.model_config)) - for modality, max_tokens in max_tokens_by_modality_dict.items(): - # Check how many items of this modality can be supported by - # the encoder budget. - encoder_budget = min(self.max_num_encoder_input_tokens, - self.encoder_cache_size) - - max_num_mm_items_encoder_budget = cdiv(encoder_budget, - max_tokens) - - # Check how many items of this modality can be supported by - # the decoder budget. - max_mm_items_per_req = self.mm_registry.\ - get_mm_limits_per_prompt(self.model_config)[modality] - - # NOTE: We do not consider max_num_batched_tokens on purpose - # because the multimodal embeddings can be generated in advance - # and chunked prefilled. 
- max_num_mm_items_decoder_budget = self.max_num_reqs * \ - max_mm_items_per_req - - max_num_mm_items = min(max_num_mm_items_encoder_budget, - max_num_mm_items_decoder_budget) - self.max_num_mm_items_by_modality[modality] = max_num_mm_items - - def _update_num_xla_graphs(self, case_str): - check_comp = self.check_recompilation and not self.enforce_eager - if not check_comp: - return - - total_cached_graphs = xr.get_num_cached_compilation_graph() - new_compiled_graphs = total_cached_graphs - self.num_xla_graphs - if new_compiled_graphs == 0: - return - - logger.info("Add new %d compiled XLA graphs due to %s", - new_compiled_graphs, case_str) - self.num_xla_graphs += new_compiled_graphs - - def _verify_num_xla_graphs(self, case_str): - check_comp = self.check_recompilation and not self.enforce_eager - if not check_comp: - return - - curr_cached_graph = xr.get_num_cached_compilation_graph() - assert self.num_xla_graphs == curr_cached_graph, ( - "Recompilation after warm up is detected during {}." - " num_xla_graphs = {} curr_cached_graph = {}".format( - case_str, self.num_xla_graphs, curr_cached_graph)) - - def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: - """Update the cached states and the persistent batch with the scheduler - output. - - The updated states are used by the `_prepare_inputs` function to create - the input GPU tensors for the model. - - Returns: - True if there is a new/resumed/paused/finished request. - If False, we can skip copying SamplingMetadata to the GPU. - """ - # Remove finished requests from the cached states. - for req_id in scheduler_output.finished_req_ids: - self.requests.pop(req_id, None) - self.encoder_cache.pop(req_id, None) - - # Remove the finished requests from the persistent batch. - # NOTE(woosuk): There could be an edge case where finished_req_ids and - # scheduled_req_ids overlap. This happens when a request is aborted and - # then resubmitted with the same ID. 
In this case, we treat them as two - # distinct requests - clearing the cached states for the first request - # and handling the second as a new request. - removed_req_indices: list[int] = [] - for req_id in scheduler_output.finished_req_ids: - req_index = self.input_batch.remove_request(req_id) - if req_index is not None: - removed_req_indices.append(req_index) - - # Free the cached encoder outputs. - for req_id, input_id in scheduler_output.free_encoder_input_ids: - encoder_outputs = self.encoder_cache.get(req_id) - if encoder_outputs is not None: - encoder_outputs.pop(input_id, None) - if not encoder_outputs: - self.encoder_cache.pop(req_id, None) - - # Remove the unscheduled requests from the persistent batch. - # NOTE(woosuk): The unscheduled requests are either preempted requests - # or running requests that are not scheduled in this step. We remove - # them from the persistent batch but keep their cached states since - # they will be scheduled again sometime in the future. - scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() - cached_req_ids = self.input_batch.req_id_to_index.keys() - unscheduled_req_ids = cached_req_ids - scheduled_req_ids - # NOTE(woosuk): The persistent batch optimization assumes that - # consecutive batches contain mostly the same requests. If batches - # have low request overlap (e.g., alternating between two distinct - # sets of requests), this optimization becomes very inefficient. - for req_id in unscheduled_req_ids: - req_index = self.input_batch.remove_request(req_id) - assert req_index is not None - removed_req_indices.append(req_index) - - req_ids_to_add: list[str] = [] - # Add new requests to the cached states. 
- for new_req_data in scheduler_output.scheduled_new_reqs: - req_id = new_req_data.req_id - sampling_params = new_req_data.sampling_params - - self.requests[req_id] = CachedRequestState( - req_id=req_id, - prompt_token_ids=new_req_data.prompt_token_ids, - mm_inputs=new_req_data.mm_inputs, - mm_positions=new_req_data.mm_positions, - sampling_params=sampling_params, - generator=None, - block_ids=new_req_data.block_ids, - num_computed_tokens=new_req_data.num_computed_tokens, - output_token_ids=[], - lora_request=new_req_data.lora_request, - ) - - req_ids_to_add.append(req_id) - - # Update the states of the running/resumed requests. - for req_data in scheduler_output.scheduled_cached_reqs: - req_id = req_data.req_id - req_state = self.requests[req_id] - - # Update the cached states. - req_state.num_computed_tokens = req_data.num_computed_tokens - if not req_data.resumed_from_preemption: - # Append the new blocks to the existing block IDs. - req_state.block_ids.extend(req_data.new_block_ids) - else: - # The request is resumed from preemption. - # Replace the existing block IDs with the new ones. - req_state.block_ids = req_data.new_block_ids - - req_index = self.input_batch.req_id_to_index.get(req_id) - if req_index is None: - # The request is not in the persistent batch. - # The request was either preempted and resumed later, or was not - # scheduled in the previous step and needs to be added again. - req_ids_to_add.append(req_id) - continue - - # Update the persistent batch. - self.input_batch.num_computed_tokens_cpu[req_index] = ( - req_data.num_computed_tokens) - self.input_batch.block_table.append_row(req_data.new_block_ids, - req_index) - - # Add the new or resumed requests to the persistent batch. - # The smaller empty indices are filled first. - removed_req_indices = sorted(removed_req_indices, reverse=True) - for req_id in req_ids_to_add: - req_state = self.requests[req_id] - if removed_req_indices: - # Fill the empty index. 
- req_index = removed_req_indices.pop() - else: - # Append to the end. - req_index = None - self.input_batch.add_request(req_state, req_index) - - # Condense the batched states if there are empty indices. - if removed_req_indices: - self.input_batch.condense(removed_req_indices) - - return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0 - - def get_model(self) -> nn.Module: - assert self.model is not None - return self.model - - def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: - """ - Generates the KVCacheSpec by parsing the kv cache format from each - Attention module in the static forward context. - Returns: - KVCacheSpec: A dictionary mapping layer names to their KV cache - format. Layers that do not need KV cache are not included. - """ - - layers = get_layers_from_vllm_config(self.vllm_config, Attention) - block_size = self.vllm_config.cache_config.block_size - kv_cache_spec: dict[str, KVCacheSpec] = {} - for layer_name, attn_module in layers.items(): - if attn_module.attn_type == AttentionType.DECODER: - if attn_module.sliding_window is not None: - kv_cache_spec[layer_name] = SlidingWindowSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=attn_module.dtype, - sliding_window=attn_module.sliding_window, - use_mla=False, - ) - else: - kv_cache_spec[layer_name] = FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=attn_module.dtype, - use_mla=False, - ) - elif attn_module.attn_type in (AttentionType.ENCODER, - AttentionType.ENCODER_ONLY): - # encoder-only attention does not need KV cache. 
- continue - elif attn_module.attn_type == AttentionType.ENCODER_DECODER: - raise NotImplementedError - else: - raise ValueError( - f"Unknown attention type: {attn_module.attn_type}") - - return kv_cache_spec - - def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): - total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - assert total_num_scheduled_tokens > 0 - num_reqs = self.input_batch.num_reqs - assert num_reqs > 0 - - # Get the number of scheduled tokens for each request. - num_scheduled_tokens_per_req = [] - max_num_scheduled_tokens_all_reqs = 0 - for req_id in self.input_batch.req_ids[:num_reqs]: - assert req_id is not None - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - num_scheduled_tokens_per_req.append(num_tokens) - max_num_scheduled_tokens_all_reqs = max( - max_num_scheduled_tokens_all_reqs, num_tokens) - num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req, - dtype=np.int32) - assert max_num_scheduled_tokens_all_reqs > 0 - - # Get request indices. - # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - # For each scheduled token, what are the corresponding req index. - req_indices = np.repeat(self.arange_np[:num_reqs], - num_scheduled_tokens_per_req) - - # Get batched arange. - # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # For each scheduled token, what is its position in corresponding req. - arange = np.concatenate( - [self.arange_np[:n] for n in num_scheduled_tokens_per_req]) - - # Get positions. - positions_np = self.positions_np[:total_num_scheduled_tokens] - np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np) - - # Get token indices. - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] - # where M is the max_model_len. 
- token_indices = (positions_np + - req_indices * self.input_batch.token_ids_cpu.shape[1]) - - # NOTE(woosuk): We use torch.index_select instead of np.take here - # because torch.index_select is much faster than np.take for large - # tensors. - torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), - 0, - torch.from_numpy(token_indices), - out=self.input_ids_cpu[:total_num_scheduled_tokens]) - - # Calculate the slot mapping. - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] - # where K is the max_num_blocks_per_req and the block size is 2. - # NOTE(woosuk): We can't simply use `token_indices // block_size` here - # because M (max_model_len) is not necessarily divisible by block_size. - # req_indices: # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - block_table_indices = (req_indices * self.max_num_blocks_per_req + - positions_np // self.block_size) - # NOTE(woosuk): We use torch.index_select instead of np.take here - # because torch.index_select is much faster than np.take for large - # tensors. - block_table_cpu = self.input_batch.block_table.get_cpu_tensor() - block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() - block_offsets = positions_np % self.block_size - np.add(block_numbers * self.block_size, - block_offsets, - out=self.input_batch.block_table. - slot_mapping_np[:total_num_scheduled_tokens]) - - # Prepare the attention metadata. - self.query_start_loc_np[0] = 0 - np.cumsum(num_scheduled_tokens_per_req, - out=self.query_start_loc_np[1:num_reqs + 1]) - self.query_start_loc_np[num_reqs + 1:] = 1 - - self.seq_lens_np[:num_reqs] = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] + - num_scheduled_tokens_per_req) - - # Do the padding and copy the tensors to the TPU. 
- padded_total_num_scheduled_tokens = _get_padded_token_len( - self.num_tokens_paddings, total_num_scheduled_tokens) - # Zero out to avoid spurious values from prev iteration (last cp chunk) - self.input_ids_cpu[ - total_num_scheduled_tokens:padded_total_num_scheduled_tokens] = 0 - self.input_ids = self.input_ids_cpu[: - padded_total_num_scheduled_tokens].to( - self.device) - self.position_ids = self.positions_cpu[: - padded_total_num_scheduled_tokens].to( - self.device) - self.input_batch.block_table.slot_mapping_cpu[ - total_num_scheduled_tokens:] = _PAD_SLOT_ID - slot_mapping = ( - self.input_batch.block_table. - slot_mapping_cpu[:padded_total_num_scheduled_tokens].to( - self.device)) - block_tables = self.block_table_cpu[:self.max_num_reqs] - block_tables[:num_reqs, :self.max_num_blocks_per_req] = ( - self.input_batch.block_table.get_cpu_tensor()[:num_reqs]) - block_tables = block_tables.to(self.device) - query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to( - self.device) - seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device) - - if self.lora_config is not None: - # We need to respect padding when activating LoRA adapters - padded_num_scheduled_tokens_per_req = np.copy( - num_scheduled_tokens_per_req - ) # Copying to avoid accidental state corruption bugs - padded_num_scheduled_tokens_per_req[-1] += \ - padded_total_num_scheduled_tokens - total_num_scheduled_tokens - - self.set_active_loras(self.input_batch, - padded_num_scheduled_tokens_per_req) - - attn_metadata = PallasMetadata( - slot_mapping=slot_mapping, - block_tables=block_tables, - context_lens=seq_lens, - query_start_loc=query_start_loc, - num_seqs=torch.tensor([num_reqs], - dtype=torch.int32, - device=self.device), - ) - # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial - # request in the batch. While we should not sample any token from this - # partial request, we do so for simplicity. We will ignore the sampled - # token from the partial request. 
- # TODO: Support prompt logprobs. - padded_num_reqs = _get_padded_num_reqs_with_upper_limit( - num_reqs, self.max_num_reqs) - # Indices at which we sample (positions of last token in the sequence). - # Padded to avoid recompiling when `num_reqs` varies. - logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1 - logits_indices = logits_indices.to(self.device) - - layer_names = get_layers_from_vllm_config(self.vllm_config, - Attention).keys() - per_layer_attn_metadata = { - layer_name: attn_metadata - for layer_name in layer_names - } - return per_layer_attn_metadata, logits_indices, padded_num_reqs - - def _scatter_placeholders( - self, - embeds: torch.Tensor, - is_embed: Optional[torch.Tensor], - ) -> torch.Tensor: - if is_embed is None: - return embeds - - placeholders = embeds.new_full( - (is_embed.shape[0], embeds.shape[-1]), - fill_value=torch.nan, - ) - placeholders[is_embed] = embeds - return placeholders - - def _gather_placeholders( - self, - placeholders: torch.Tensor, - is_embed: Optional[torch.Tensor], - ) -> torch.Tensor: - if is_embed is None: - return placeholders - - return placeholders[is_embed] - - def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): - scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs - if not scheduled_encoder_inputs: - return - - # Batch the multi-modal inputs. - mm_inputs = list[MultiModalKwargs]() - req_ids_pos = list[tuple[str, int, PlaceholderRange]]() - for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): - req_state = self.requests[req_id] - - for mm_input_id in encoder_input_ids: - mm_inputs.append(req_state.mm_inputs[mm_input_id]) - req_ids_pos.append( - (req_id, mm_input_id, req_state.mm_positions[mm_input_id])) - - # Batch mm inputs as much as we can: if a request in the batch has - # multiple modalities or a different modality than the previous one, - # we process it separately to preserve item order. 
- # FIXME(ywang96): This is a hacky way to deal with multiple modalities - # in the same batch while still being able to benefit from batching - # multimodal inputs. The proper solution should be reordering the - # encoder outputs. - grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs) - - encoder_outputs = [] - for grouped_mm_inputs in grouped_mm_inputs_list: - batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs) - batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs, - device=self.device) - - # Run the encoder. - # `curr_group_outputs` is either of the following: - # 1. A tensor of shape (num_items, feature_size, hidden_size) - # in case feature_size is fixed across all multimodal items. - # 2. A list or tuple (length: num_items) of tensors, each of shape - # (feature_size, hidden_size) in case the feature size is dynamic - # depending on the input multimodal items. - xm.mark_step() - curr_group_outputs = self.model.get_multimodal_embeddings( - **batched_mm_inputs) - xm.mark_step() - - sanity_check_mm_encoder_outputs( - curr_group_outputs, - expected_num_items=len(grouped_mm_inputs), - ) - - if isinstance(curr_group_outputs, torch.Tensor): - encoder_outputs.append(curr_group_outputs) - else: - assert isinstance(curr_group_outputs, (list, tuple)) - for output in curr_group_outputs: - encoder_outputs.append(output) - - # Cache the encoder outputs. - # NOTE (NickLucche) here we diverge from logic in other runners, as we - # assume to only have whole mm items to process. Hence we avoid the - # intrinsic dynamism that `scatter_mm_placeholders` introduces. - for (req_id, input_id, pos_info), output in zip( - req_ids_pos, - encoder_outputs, - ): - if req_id not in self.encoder_cache: - self.encoder_cache[req_id] = {} - assert pos_info.is_embed is None, "Expected all positions to be"\ - " contiguous and embeddings." 
- self.encoder_cache[req_id][input_id] = output - - def _gather_mm_embeddings( - self, - scheduler_output: "SchedulerOutput", - ) -> list[torch.Tensor]: - mm_embeds: list[torch.Tensor] = [] - for req_id in self.input_batch.req_ids: - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ - req_id] - req_state = self.requests[req_id] - num_computed_tokens = req_state.num_computed_tokens - mm_positions = req_state.mm_positions - # TODO unroll loop and assume/enforce --disable_chunked_mm_input - # NOTE (NickLucche) here we diverge from logic in other runners, as - # we assume to only have whole mm items to process. Hence we avoid - # the intrinsic dynamism that `gather_mm_placeholders` introduces. - for i, pos_info in enumerate(mm_positions): - start_pos = pos_info.offset - num_encoder_tokens = pos_info.length - - # The encoder output is needed if the two ranges overlap: - # [num_computed_tokens, - # num_computed_tokens + num_scheduled_tokens) and - # [start_pos, start_pos + num_encoder_tokens) - if start_pos >= num_computed_tokens + num_scheduled_tokens: - # The encoder output is not needed in this step. - break - if start_pos + num_encoder_tokens <= num_computed_tokens: - # The encoder output is already processed and stored - # in the decoder's KV cache. - continue - - assert req_id in self.encoder_cache - assert i in self.encoder_cache[req_id] - assert pos_info.is_embed is None, "Expected all positions to"\ - " be contiguous and embeddings." - encoder_output = self.encoder_cache[req_id][i] - mm_embeds.append(encoder_output) - return mm_embeds - - def _get_model_inputs(self, input_ids: torch.Tensor, - mm_embeds: list[torch.Tensor]): - if self.is_multimodal_model: - # NOTE(woosuk): To unify token ids and soft tokens (vision - # embeddings), we always use embeddings (rather than token ids) - # as input to the multimodal model, even when the input is text. 
- if mm_embeds: - inputs_embeds = self.model.get_input_embeddings( - input_ids, mm_embeds) - else: - inputs_embeds = self.model.get_input_embeddings(input_ids) - return None, inputs_embeds - else: - # For text-only models, we use token ids as input. - # While it is possible to use embeddings as input just like the - # multimodal models, it is not desirable for performance since - # then the embedding layer is not included in the CUDA graph. - return input_ids, None - - @torch.no_grad() - def execute_model( - self, - scheduler_output: "SchedulerOutput", - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> ModelRunnerOutput: - # Update cached state - self._update_states(scheduler_output) - if not scheduler_output.total_num_scheduled_tokens: - # Return empty ModelRunnerOutput if there's no work to do. - return EMPTY_MODEL_RUNNER_OUTPUT - - if self.is_multimodal_model: - # Run the multimodal encoder if any. - self._execute_mm_encoder(scheduler_output) - mm_embeds = self._gather_mm_embeddings(scheduler_output) - else: - mm_embeds = [] - xm.mark_step() - # Prepare inputs - attn_metadata, logits_indices, padded_num_reqs = self._prepare_inputs( - scheduler_output) - input_ids, inputs_embeds = self._get_model_inputs( - self.input_ids, mm_embeds) - xm.mark_step() - num_reqs = self.input_batch.num_reqs - # Run the decoder - with set_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=scheduler_output.total_num_scheduled_tokens): - hidden_states = self.model( - input_ids=input_ids, - positions=self.position_ids, - inputs_embeds=inputs_embeds, - ) - hidden_states = self.select_hidden_states(hidden_states, - logits_indices) - logits = self.compute_logits(hidden_states) - tpu_sampling_metadata = TPUSupportedSamplingMetadata.\ - from_input_batch(self.input_batch, padded_num_reqs, self.device) - if scheduler_output.grammar_bitmask is not None: - require_struct_decoding, grammar_bitmask_padded, arange = \ - 
self.prepare_structured_decoding_input(logits, scheduler_output) - logits = self.structured_decode(require_struct_decoding, - grammar_bitmask_padded, logits, - arange) - selected_token_ids = self.sample_from_logits(logits, - tpu_sampling_metadata) - - # NOTE (NickLucche) Use the original logits (before any penalties or - # temperature scaling) for the top-k logprobs. We can't enforce it due - # to recompilations outside torch.compiled code, so just make sure - # `sample_from_logits` does not modify the logits in-place. - logprobs = self.gather_logprobs(logits, selected_token_ids) \ - if tpu_sampling_metadata.logprobs else None - - # Remove padding on cpu and keep dynamic op outside of xla graph. - selected_token_ids = selected_token_ids.cpu()[:num_reqs] - logprobs_lists = logprobs.tolists() \ - if tpu_sampling_metadata.logprobs else None - - # Update the cache state concurrently. Code above will not block until - # we use `selected_token_ids`. Add mark_step if post-processing changes - request_seq_lens: list[tuple[int, CachedRequestState, int]] = [] - discard_sampled_tokens_req_indices = [] - for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): - assert req_id is not None - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - if seq_len >= req_state.num_tokens: - request_seq_lens.append((i, req_state, seq_len)) - else: - # Ignore the sampled token from the partial request. - # Rewind the generator state as if the token was not sampled. - generator = self.input_batch.generators.get(i) - if generator is not None: - # This relies on cuda-specific torch-internal impl details - generator.set_offset(generator.get_offset() - 4) - - # Record the index of the request that should not be sampled, - # so that we could clear the sampled tokens before returning. 
- discard_sampled_tokens_req_indices.append(i) - - assert all( - req_id is not None for req_id in - self.input_batch.req_ids[:num_reqs]), "req_ids contains None" - req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs]) - - prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} - for req_id in self.input_batch.req_ids[:num_reqs]: - prompt_logprobs_dict[req_id] = None - - max_gen_len = selected_token_ids.shape[-1] - if max_gen_len == 1: - valid_sampled_token_ids = selected_token_ids.tolist() - - # Mask out the sampled tokens that should not be sampled. - # TODO: Keep in sync with gpu_model_runner.py, in particular - # the "else" case here - for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[i].clear() - - # Append sampled tokens - for i, req_state, seq_len in request_seq_lens: - token_id = valid_sampled_token_ids[i][0] - self.input_batch.token_ids_cpu[i, seq_len] = token_id - req_state.output_token_ids.append(token_id) - self.input_batch.num_tokens[i] += 1 - - else: - valid_mask = selected_token_ids != INVALID_TOKEN_ID - gen_lens = valid_mask.sum(dim=1).tolist() - valid_sampled_token_ids = [ - seq.tolist() - for seq in selected_token_ids[valid_mask].split(gen_lens) - ] - self.input_batch.num_tokens[:num_reqs] += gen_lens - for i, req_state, seq_len in request_seq_lens: - target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1) - self.input_batch.token_ids_cpu[ - i, target_slice] = valid_sampled_token_ids[i] - req_state.output_token_ids.extend(valid_sampled_token_ids[i]) - - model_runner_output = ModelRunnerOutput( - req_ids=req_ids, - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=valid_sampled_token_ids, - spec_token_ids=None, - logprobs=logprobs_lists, - prompt_logprobs_dict=prompt_logprobs_dict, - ) - - # Check there are no new graphs compiled - all the graphs should be - # captured and compiled during warm up. 
- self._verify_num_xla_graphs("execute_model") - - return model_runner_output - - def load_model(self) -> None: - self.device = self.device_config.device - - # NOTE(woosuk): While the executor assigns the TP ranks to the worker - # process, the ranks can be different from the ranks internally assigned - # by the xm runtime. Therefore, there is a mismatch in the rank - # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. - # This is not a problem in linear layers because all-reduce is - # rank-agnostic. However, it matters for all-gather as the ranks - # determine the order of concatenating the output tensors. - # As a workaround, we use the xm's rank assignment only when loading - # the embedding weights. - xm_tp_rank = xr.global_ordinal() - with patch( - "vllm.model_executor.layers.vocab_parallel_embedding." - "get_tensor_model_parallel_rank", - return_value=xm_tp_rank): - model = get_model(vllm_config=self.vllm_config) - if self.lora_config is not None: - model = self.load_lora_model(model, self.model_config, - self.scheduler_config, - self.lora_config, self.device) - - # Sync all pending XLA execution during model initialization and weight - # loading. 
- xm.mark_step() - xm.wait_device_ops() - self.model = model - self.sampler = TPUSampler() - - @torch.no_grad() - def _dummy_run(self, num_tokens: int) -> None: - if self.is_multimodal_model: - input_ids = None - inputs_embeds = torch.zeros((num_tokens, self.hidden_size), - dtype=self.dtype, - device=self.device) - else: - input_ids = torch.zeros((num_tokens), - dtype=torch.int32, - device=self.device) - inputs_embeds = None - actual_num_reqs = min(num_tokens, self.max_num_reqs) - position_ids = torch.zeros(num_tokens, - dtype=torch.int32, - device=self.device) - slot_mapping = torch.zeros(num_tokens, - dtype=torch.int64, - device=self.device) - block_tables = torch.zeros( - (self.max_num_reqs, self.block_table_cpu.shape[1]), - dtype=torch.int32, - device=self.device) - query_lens = [1] * self.max_num_reqs - query_start_loc = torch.cumsum(torch.tensor([0] + query_lens, - dtype=torch.int32), - dim=0, - dtype=torch.int32).to(self.device) - context_lens = torch.ones((self.max_num_reqs, ), - dtype=torch.int32, - device=self.device) - num_seqs = torch.tensor([actual_num_reqs], - dtype=torch.int32, - device=self.device) - attn_metadata = PallasMetadata( - slot_mapping=slot_mapping, - block_tables=block_tables, - context_lens=context_lens, - query_start_loc=query_start_loc, - num_seqs=num_seqs, - ) - - if self.is_multimodal_model: - torch._dynamo.mark_dynamic(inputs_embeds, 0) - else: - torch._dynamo.mark_dynamic(input_ids, 0) - torch._dynamo.mark_dynamic(position_ids, 0) - torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) - - layer_names = get_layers_from_vllm_config(self.vllm_config, - Attention).keys() - per_layer_attn_metadata = { - layer_name: attn_metadata - for layer_name in layer_names - } - - with self.maybe_dummy_run_with_lora( - self.lora_config, - np.array([num_tokens], dtype=np.int32)), set_forward_context( - per_layer_attn_metadata, self.vllm_config, 0): - out = self.model(input_ids=input_ids, - positions=position_ids, - inputs_embeds=inputs_embeds) 
- self._hidden_states_dtype = out.dtype - - def _precompile_mm_encoder(self) -> None: - # Pre-compile MM encoder for all supported data modalities. - hf_config = self.vllm_config.model_config.hf_config - for mode, max_items_by_mode in \ - self.max_num_mm_items_by_modality.items(): - logger.info( - "Compiling Multimodal %s Encoder with different input" - " shapes.", mode) - start = time.perf_counter() - # No padding for MM encoder just yet. - for num_items in range(1, max_items_by_mode + 1): - logger.info(" -- mode: %s items: %d", mode, num_items) - batched_dummy_mm_inputs = self._get_mm_dummy_batch( - mode, num_items) - # Run multimodal encoder. - xm.mark_step() - mm_embeds = self.model.\ - get_multimodal_embeddings(**batched_dummy_mm_inputs) - xm.mark_step() - num_patches = mm_embeds[0].shape[0] - items_size = num_patches * num_items - - # NOTE (NickLucche) pre-compile `get_input_embeddings` when mm - # embeddings are present. We assume `--disable-mm-chunked`, - # hence only whole items can be scheduled. This implies we just - # need to compile when `num_items` fit the (padded) `input_ids` - for num_tokens in self.num_tokens_paddings: - if num_tokens >= items_size: - # XLA Workaround: if torch.zeros(..device) is used, XLA - # compiles a scalar+expansion op, which won't match - # the graph generated at runtime. CPU->TPU must be used - placeholders_ids = torch.zeros(num_tokens, - dtype=torch.int32, - device="cpu") - # Align placeholders and actual num mm_embeddings. - placeholders_ids[:items_size] = \ - hf_config.image_token_index - - placeholders_ids = placeholders_ids.to(self.device) - # Assign outputs or the graph will be cut short. - a, b = self._get_model_inputs(placeholders_ids, - [mm_embeds]) - assert a is None - xm.mark_step() - - # Pre-compile `get_input_embeddings` when mm_embeddings are not - # present. Chunk is only made of text, no mm_placeholders. 
- for num_tokens in self.num_tokens_paddings: - placeholders_ids = torch.zeros(num_tokens, - dtype=torch.int32, - device="cpu") - placeholders_ids = placeholders_ids.to(self.device) - a, b = self._get_model_inputs(placeholders_ids, []) - assert a is None - xm.mark_step() - - xm.wait_device_ops() - end = time.perf_counter() - logger.info( - "Multimodal %s Encoder compilation finished in in %.2f " - "[secs].", mode, end - start) - - def _precompile_backbone(self) -> None: - logger.info("Compiling the model with different input shapes.") - start = time.perf_counter() - for num_tokens in self.num_tokens_paddings: - logger.info(" -- num_tokens: %d", num_tokens) - self._dummy_run(num_tokens) - xm.wait_device_ops() - end = time.perf_counter() - logger.info("Compilation finished in %.2f [secs].", end - start) - self._update_num_xla_graphs("model backbone") - - def _precompile_select_hidden_states(self) -> None: - # Compile hidden state selection function for bucketed - # n_tokens x max_num_reqs. Graph is really small so this is fine. - logger.info( - "Compiling select_hidden_states with different input shapes.") - start = time.perf_counter() - hsize = self.model_config.get_hidden_size() - for num_tokens in self.num_tokens_paddings: - dummy_hidden = torch.zeros((num_tokens, hsize), - device=self.device, - dtype=self._hidden_states_dtype) - torch._dynamo.mark_dynamic(dummy_hidden, 0) - for num_reqs in self.num_reqs_paddings: - indices = torch.zeros(num_reqs, - dtype=torch.int32, - device=self.device) - torch._dynamo.mark_dynamic(indices, 0) - self.select_hidden_states(dummy_hidden, indices) - logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens, - num_reqs) - # Requests can't be more than tokens. But do compile for the - # next bigger value in case num_tokens uses bucketed padding. 
- if num_reqs >= min(num_tokens, self.max_num_reqs): - break - xm.wait_device_ops() - end = time.perf_counter() - logger.info("Compilation finished in %.2f [secs].", end - start) - self._update_num_xla_graphs("select_hidden_states") - - def _precompile_compute_logits(self) -> None: - logger.info("Compiling compute_logits with different input shapes.") - start = time.perf_counter() - hsize = self.model_config.get_hidden_size() - for num_reqs in self.num_reqs_paddings: - dummy_hidden = torch.zeros((num_reqs, hsize), - device=self.device, - dtype=self._hidden_states_dtype) - torch._dynamo.mark_dynamic(dummy_hidden, 0) - self.compute_logits(dummy_hidden) - logger.info(" -- num_seqs: %d", num_reqs) - xm.wait_device_ops() - end = time.perf_counter() - logger.info("Compilation finished in %.2f [secs].", end - start) - self._update_num_xla_graphs("compute_logits") - - def _precompile_structured_decoding(self) -> None: - logger.info( - "Compiling structured_decoding with different input shapes.") - start = time.perf_counter() - for num_reqs in self.num_reqs_paddings: - dummy_logits = torch.zeros((num_reqs, self.vocab_size), - device=self.device, - dtype=self._hidden_states_dtype) - dummy_require_struct_decoding = \ - self.require_structured_out_cpu[:num_reqs].to(self.device) - dummy_grammar_bitmask = \ - self.grammar_bitmask_cpu[:num_reqs].to(self.device) - # The first dimension of the above 3 dummy tensors cannot be - # mark_dynamic because some operations in structured_decode require - # them to be static. 
- arange = self.structured_decode_arange.to(self.device) - self.structured_decode(dummy_require_struct_decoding, - dummy_grammar_bitmask, dummy_logits, arange) - logger.info(" -- num_seqs: %d", num_reqs) - xm.wait_device_ops() - end = time.perf_counter() - logger.info("Compilation finished in %.2f [secs].", end - start) - self._update_num_xla_graphs("structured_decoding") - - def _precompile_sample_from_logits(self) -> None: - logger.info( - "Compiling sample_from_logits with different input shapes.") - start = time.perf_counter() - for num_reqs in self.num_reqs_paddings: - dummy_logits = torch.zeros((num_reqs, self.vocab_size), - device=self.device, - dtype=self._hidden_states_dtype) - # The first dimension of dummy_logits cannot be mark_dynamic - # because some operations in the sampler require it to be static. - for all_greedy in [False, True]: - generate_params_if_all_greedy = not all_greedy - sampling_metadata = ( - TPUSupportedSamplingMetadata.from_input_batch( - self.input_batch, - num_reqs, - self.device, - generate_params_if_all_greedy, - )) - sampling_metadata.all_greedy = all_greedy - self.sample_from_logits(dummy_logits, sampling_metadata) - logger.info(" -- num_seqs: %d", num_reqs) - xm.wait_device_ops() - end = time.perf_counter() - logger.info("Compilation finished in %.2f [secs].", end - start) - self._update_num_xla_graphs("sample_from_logits") - - def _precompile_gather_logprobs(self) -> None: - logger.info("Compiling gather_logprobs with different input shapes.") - start = time.perf_counter() - for num_reqs in self.num_reqs_paddings: - dummy_logits = torch.zeros((num_reqs, self.vocab_size), - device=self.device, - dtype=self._hidden_states_dtype) - dummy_tokens = torch.zeros((num_reqs, 1), - dtype=torch.int64).to(self.device) - self.gather_logprobs(dummy_logits, dummy_tokens) - logger.info(" -- num_seqs: %d", num_reqs) - xm.wait_device_ops() - end = time.perf_counter() - logger.info("Compilation finished in %.2f [secs].", end - start) - 
self._update_num_xla_graphs("gather_logprobs") - - def capture_model(self) -> None: - """ - Precompile all the subgraphs with possible input shapes. - """ - self._precompile_mm_encoder() - self._precompile_backbone() - self._precompile_select_hidden_states() - self._precompile_compute_logits() - self._precompile_structured_decoding() - self._precompile_sample_from_logits() - self._precompile_gather_logprobs() - - def profile_run( - self, - num_tokens: int, - ) -> None: - # Profile with multimodal encoder & encoder cache. - # TODO: handle encoder-decoder models once we support them. - if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0 - and self.encoder_cache_size > 0): - - # NOTE: Currently model is profiled with a single non-text - # modality with the max possible input tokens even when - # it supports multiple. - dummy_data_modality, max_num_mm_items = max( - self.max_num_mm_items_by_modality.items(), key=lambda t: t[1]) - - encoder_budget = min(self.max_num_encoder_input_tokens, - self.encoder_cache_size) - - logger.info( - "Encoder cache will be initialized with a budget of %d tokens," - " and profiled with %s %s items of the maximum feature size.", - encoder_budget, max_num_mm_items, dummy_data_modality) - - # Create dummy batch of multimodal inputs. - batched_dummy_mm_inputs = self._get_mm_dummy_batch( - dummy_data_modality, max_num_mm_items) - - # Run multimodal encoder. - # Isolate encoder graph from post-processing to minimize - # impact of recompilation until it's fixed. 
- start = time.perf_counter() - xm.mark_step() - dummy_encoder_outputs = self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) - xm.mark_step() - xm.wait_device_ops() - end = time.perf_counter() - logger.info( - "Multimodal Encoder profiling finished in in %.2f [secs].", - end - start) - - assert len(dummy_encoder_outputs) == max_num_mm_items, ( - "Expected dimension 0 of encoder outputs to match the number " - f"of multimodal data items: {max_num_mm_items}, got " - f"{len(dummy_encoder_outputs)=} instead. This is most likely " - "due to the 'get_multimodal_embeddings' method of the model " - "not implemented correctly.") - - # Cache the dummy encoder outputs. - self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) - - # Trigger compilation for general shape. - self._dummy_run(num_tokens) - - xm.mark_step() - xm.wait_device_ops() - self.encoder_cache.clear() - gc.collect() - - def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: - """ - Initialize KV cache based on `kv_cache_config`. 
- Args: - kv_cache_config: Configuration for the KV cache, including the KV - cache size of each layer - """ - if len(kv_cache_config.kv_cache_groups) > 1: - raise NotImplementedError( - "Hybrid models with more than one KV cache type are not " - "supported yet.") - - kv_caches: dict[str, torch.Tensor] = {} - - for kv_cache_group in kv_cache_config.kv_cache_groups: - kv_cache_spec = kv_cache_group.kv_cache_spec - for layer_name in kv_cache_group.layer_names: - tensor_config = kv_cache_config.tensors[layer_name] - assert tensor_config.size % kv_cache_spec.page_size_bytes == 0 - num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes - if isinstance(kv_cache_spec, AttentionSpec): - kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) - dtype = kv_cache_spec.dtype - - tpu_kv_cache = torch.zeros(kv_cache_shape, - dtype=dtype, - device=self.device) - - kv_caches[layer_name] = tpu_kv_cache - else: - raise NotImplementedError - - bind_kv_cache( - kv_caches, - self.vllm_config.compilation_config.static_forward_context, - self.kv_caches) - - def reset_dynamo_cache(self): - if self.is_multimodal_model: - compiled_model = self.model.get_language_model().model - else: - compiled_model = self.model.model - if isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher): - logger.info("Clear dynamo cache and cached dynamo bytecode.") - torch._dynamo.eval_frame.remove_from_cache( - compiled_model.original_code_object) - compiled_model.compiled_codes.clear() - - @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def select_hidden_states(self, hidden_states, indices_do_sample): - return hidden_states[indices_do_sample] - - @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def compute_logits(self, - sample_hidden_states: torch.Tensor) -> torch.Tensor: - return self.model.compute_logits(sample_hidden_states, None) - - 
@torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def sample_from_logits( - self, logits: torch.Tensor, - sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor: - """ - Sample with xla-friendly function. This function is to be traced - separately from `forward` for lighter compilation overhead. - """ - if sampling_metadata.all_greedy: - out_tokens = torch.argmax(logits, dim=-1, keepdim=True) - else: - out_tokens = self.sampler(logits, - sampling_metadata).sampled_token_ids - return out_tokens - - @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def gather_logprobs(self, logits: torch.Tensor, - sampled_tokens: torch.Tensor) -> LogprobsTensors: - """ - Gather the top_logprobs with corresponding tokens. Use a fixed number - of logprobs as an alternative to having multiple pre-compiled graphs. - Select the number of logprobs actually demanded by each request on CPU. - """ - logprobs = self.sampler.compute_logprobs(logits) - return self.sampler.gather_logprobs( - logprobs, - self.model_config.max_logprobs, - token_ids=sampled_tokens.squeeze(-1)) - - @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def structured_decode(self, require_struct_decoding: torch.Tensor, - grammar_bitmask: torch.Tensor, logits: torch.Tensor, - arange: torch.Tensor) -> torch.Tensor: - return torch.where( - require_struct_decoding, - self.apply_grammar_bitmask(logits, grammar_bitmask, arange), - logits) - - def apply_grammar_bitmask(self, logits: torch.Tensor, - grammar_bitmask: torch.Tensor, - arange: torch.Tensor): - assert (logits.shape[0] == grammar_bitmask.shape[0]) - logits_cloned = logits.clone() - for i in range(logits.shape[0]): - unpacked_bitmask = (torch.bitwise_right_shift( - grammar_bitmask[i][:, None], arange[None, :]) & 1) == 0 - unpacked_bitmask = unpacked_bitmask.reshape(-1)[:self.vocab_size] - logits_cloned[i] = logits_cloned[i].masked_fill( - unpacked_bitmask, -float("inf")) - return logits_cloned - - def 
get_multimodal_embeddings(self, *args, **kwargs): - return self.model.get_multimodal_embeddings(*args, **kwargs) - - def get_input_embeddings(self, *args, **kwargs): - return self.model.get_input_embeddings(*args, **kwargs) - - def prepare_structured_decoding_input( - self, logits: torch.Tensor, scheduler_output: "SchedulerOutput" - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - grammar_bitmask = scheduler_output.grammar_bitmask - assert grammar_bitmask is not None - num_reqs, _ = logits.shape - - # Reset pre-allocated tensors - self.grammar_bitmask_cpu.zero_() - self.require_structured_out_cpu.zero_() - - # We receive the structured output bitmask from the scheduler, but the - # indices of the requests in the batch may not match the indices of - # the bitmask since the scheduler doesn't know how the tpu runner is - # ordering the requests in the batch. We need to match the order of - # bitmask with the order of requests - struct_out_indices: list[int] = [] - mask_indices: list[int] = [] - for req_id in self.input_batch.req_ids: - mask_index = scheduler_output.structured_output_request_ids.get( - req_id) - if mask_index is None: - continue - batch_index = self.input_batch.req_id_to_index[req_id] - struct_out_indices.append(batch_index) - mask_indices.append(mask_index) - self.grammar_bitmask_cpu[struct_out_indices] = torch.from_numpy( - grammar_bitmask[mask_indices]) - # It's not guaranteed that all requests in this batch require - # structured output, so create a bool tensor to represent - # the requests that need structured output. 
- struct_out_indices = torch.tensor(struct_out_indices, dtype=torch.long) - self.require_structured_out_cpu[struct_out_indices] = True - return self.require_structured_out_cpu[:num_reqs].to(logits.device), \ - self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \ - self.structured_decode_arange.to(logits.device) - - def _get_mm_dummy_batch(self, modality: str, - batch_size: int) -> BatchedTensorInputs: - # Dummy data for pre-compiling multimodal models. - dummy_request_data = self.mm_registry.get_decoder_dummy_data( - model_config=self.model_config, - seq_len=self.max_num_tokens, - ) - dummy_mm_data = dummy_request_data.multi_modal_data - - # Dummy data definition in V0 may contain multiple multimodal items - # (e.g, multiple images) for a single request, therefore here we - # always replicate first item by max_num_mm_items times since in V1 - # they are scheduled to be processed separately. - assert isinstance(dummy_mm_data, MultiModalKwargs), ( - "Expected dummy multimodal data to be of type " - f"MultiModalKwargs, got {type(dummy_mm_data)=} instead. " - "This is most likely due to the model not having a merged " - "processor.") - - # When models have a merged processor, their dummy data is - # already batched `MultiModalKwargs`, therefore we take the first - # `MultiModalKwargsItem` from the desired modality to profile on. 
- dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0) - dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item]) - - batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] * - batch_size) - return MultiModalKwargs.as_kwargs(batched_dummy_mm_inputs, - device=self.device) - - -def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]: - logger.info("Preparing request paddings:") - # assert min_req_size is power of 2 - assert (min_req_size & (min_req_size - 1) == 0) and min_req_size > 0 - paddings: list = [] - num = max(MIN_NUM_SEQS, min_req_size) - while num <= max_req_size and (len(paddings) == 0 or paddings[-1] != num): - paddings.append(num) - logger.info(" %d", num) - num = _get_padded_num_reqs_with_upper_limit(num + 1, max_req_size) - return paddings - - -def _get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int: - res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length() - return min(res, upper_limit) - - -def _get_token_paddings(min_token_size: int, max_token_size: int, - padding_gap: int) -> list[int]: - """Generate a list of padding size, starting from min_token_size, - ending with a number that can cover max_token_size - - If padding_gap == 0 then: - increase 2X each time (exponential) - else: - first increase the size to twice, - then increase the padding size by padding_gap. 
- """ - # assert min_token_size is power of 2 - assert (min_token_size & (min_token_size - 1) == 0) and min_token_size > 0 - paddings = [] - num = min_token_size - - if padding_gap == 0: - logger.info("Using exponential token paddings:") - while True: - logger.info(" %d", num) - paddings.append(num) - if num >= max_token_size: - break - num *= 2 - else: - logger.info("Using incremental token paddings:") - while num <= padding_gap: - logger.info(" %d", num) - paddings.append(num) - num *= 2 - num //= 2 - while num < max_token_size: - num += padding_gap - logger.info(" %d", num) - paddings.append(num) - - return paddings - - -def _get_padded_token_len(paddings: list[int], x: int) -> int: - """Return the first element in paddings list greater or equal to x. - """ - index = bisect.bisect_left(paddings, x) - assert index < len(paddings) - return paddings[index] diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 9eea26d85249..2c5f4d552683 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -9,6 +9,9 @@ import torch_xla.core.xla_model as xm import torch_xla.debug.profiler as xp import torch_xla.runtime as xr +# from vllm.v1.worker.tpu_model_runner import TPUModelRunner +# import from tpu_commons +from tpu_commons.runner.tpu_torch_xla_runner import TPUModelRunner import vllm.envs as envs from vllm.config import ParallelConfig, VllmConfig @@ -23,7 +26,6 @@ KVCacheSpec) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.utils import bind_kv_cache, report_usage_stats -from vllm.v1.worker.tpu_model_runner import TPUModelRunner logger = init_logger(__name__) From d81d960a7c18628322784afbf4b55dd93bc23970 Mon Sep 17 00:00:00 2001 From: Hongmin Fan Date: Tue, 13 May 2025 16:33:48 +0000 Subject: [PATCH 04/15] Small clean up for the sys.path manipulation. The same system path workaround can be achieved by `PYTHONPATH=$HOME python ...` without the hardcode in the source. 
Signed-off-by: Siyuan Liu --- examples/offline_inference/tpu.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/examples/offline_inference/tpu.py b/examples/offline_inference/tpu.py index 5433db6df575..e1d9f864c5f1 100644 --- a/examples/offline_inference/tpu.py +++ b/examples/offline_inference/tpu.py @@ -1,13 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -import sys - from vllm import LLM, SamplingParams -# Using absolute path to import from tpu_commons in home directory -# Change to yours -sys.path.append('/home/lsiyuan') - prompts = [ "A robot may not injure a human being", "It is only with the heart that one can see rightly;", From 082813f03fc431483b9e0bb623f3aaff247a8d69 Mon Sep 17 00:00:00 2001 From: Hongmin Fan Date: Tue, 13 May 2025 17:19:18 +0000 Subject: [PATCH 05/15] Use the TPUWorker implementation from tpu_commons. Signed-off-by: Siyuan Liu --- vllm/v1/worker/tpu_worker.py | 268 +---------------------------------- 1 file changed, 2 insertions(+), 266 deletions(-) diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 2c5f4d552683..6ad42a159e90 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -1,270 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """A TPU worker class.""" -import os -from typing import Optional -import torch -import torch.distributed -import torch.nn as nn -import torch_xla.core.xla_model as xm -import torch_xla.debug.profiler as xp -import torch_xla.runtime as xr -# from vllm.v1.worker.tpu_model_runner import TPUModelRunner -# import from tpu_commons -from tpu_commons.runner.tpu_torch_xla_runner import TPUModelRunner +import tpu_commons.worker.tpu_torch_xla_worker as tpu_torch_xla_worker -import vllm.envs as envs -from vllm.config import ParallelConfig, VllmConfig -from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment) -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor import 
set_random_seed -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE -from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig, - KVCacheSpec) -from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.utils import bind_kv_cache, report_usage_stats - -logger = init_logger(__name__) - - -class TPUWorker: - - def __init__( - self, - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - is_driver_worker: bool = False, - ): - self.is_driver_worker = is_driver_worker - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = vllm_config.parallel_config - self.scheduler_config = vllm_config.scheduler_config - self.device_config = vllm_config.device_config - self.speculative_config = vllm_config.speculative_config - self.prompt_adapter_config = vllm_config.prompt_adapter_config - self.observability_config = vllm_config.observability_config - - self.parallel_config.rank = rank - self.local_rank = local_rank - self.rank = rank - self.distributed_init_method = distributed_init_method - - if self.cache_config.cache_dtype == "auto": - self.cache_dtype = self.model_config.dtype - else: - self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ - self.cache_config.cache_dtype] - - if self.model_config.trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules - init_cached_hf_modules() - - # Delay profiler initialization to the start of the profiling. - # This is because in vLLM V1, MP runtime is initialized before the - # TPU Worker is initialized. The profiler server needs to start after - # MP runtime is initialized. 
- self.profiler = None - self.profile_dir = None - if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1: - # For TPU, we can only have 1 active profiler session for 1 profiler - # server. So we only profile on rank0. - self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR - logger.info("Profiling enabled. Traces will be saved to: %s", - self.profile_dir) - - if self.model_config.seed is None: - self.model_config.seed = 0 - - if vllm_config.lora_config is not None: - raise NotImplementedError( - "The V1 TPU backend doesn't support LoRA serving") - - def init_device(self): - os.environ["PJRT_DEVICE"] = "TPU" - # Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D - # ring, the xla tpu compiler flag - # `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to - # fix this. It will be removed after the bug in XLA compiler is fixed. - os.environ["LIBTPU_INIT_ARGS"] = ( - "--xla_tpu_force_1d_allreduce_at_chunk_count=1") - torch.set_grad_enabled(False) - torch.set_default_dtype(self.model_config.dtype) - - # Initialize the distributed environment. - init_tpu_worker_distributed_environment(self.parallel_config, - self.rank, - self.distributed_init_method, - self.local_rank) - - # Device initialization should happen after initializing - # the distributed runtime. - self.device = xm.xla_device() - self.device_config.device = self.device - - # Set random seed. - set_random_seed(self.model_config.seed) - if self.model_config.seed is not None: - xm.set_rng_state(self.model_config.seed, self.device) - - # Increase the cache size limit, which is the maximum number of - # dynamo graphs that can be compiled. - # TODO (NickLucche) On gsm we compile 80+ graphs. - # Re-evaluate limit, with MM we may get close to this limit. - torch._dynamo.config.cache_size_limit = 128 - # Use persistent cache to avoid XLA recompilation. - # NOTE(woosuk): Set per-rank cache path since different ranks - # can have slightly different XLA graphs. 
- world_size = self.parallel_config.world_size - rank = xr.global_ordinal() - # The PyTorch/XLA compilation cache uses the Torch IR to generate keys. - # Consequently, changes in optimization flags, which affect compilation - # results, don't change the cache key. This can result in the wrong - # compilation being used. To prevent this, disabling the XLA compilation - # cache during development is recommended.We can disable it by - # `export VLLM_XLA_CACHE_PATH=` - if envs.VLLM_XLA_CACHE_PATH: - per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, - f"tp{world_size}_rank{rank}") - xr.initialize_cache(per_rank_path, readonly=False) - - # Init ModelRunner here, so that we have access to self.device. - self.model_runner = TPUModelRunner(self.vllm_config, self.device) - - if rank == 0: - # If usage stat is enabled, collect relevant info. - report_usage_stats(self.vllm_config) - - def determine_available_memory(self) -> int: - kv_caches: dict[str, torch.Tensor] = {} - kv_cache_spec = self.model_runner.get_kv_cache_spec() - for layer_name, layer_spec in kv_cache_spec.items(): - if isinstance(layer_spec, AttentionSpec): - dtype = layer_spec.dtype - - # Use an empty tensor instead of `None`` to force Dynamo to pass - # it by reference, rather by specializing on the value ``None``. - tpu_kv_cache = torch.tensor([], - dtype=dtype, - device=self.device) - kv_caches[layer_name] = tpu_kv_cache - else: - raise NotImplementedError( - f"Unsupported KV cache spec '{type(layer_spec)}'") - - runner_kv_caches: list[torch.Tensor] = [] - bind_kv_cache( - kv_caches, - self.vllm_config.compilation_config.static_forward_context, - runner_kv_caches) - - # `max_num_tokens >= max_num_batched_tokens` due to padding. - self.model_runner.profile_run(self.model_runner.max_num_tokens) - - # Synchronize before measuring the memory usage. - xm.wait_device_ops() - - # During the profiling run, the model runs without KV cache. After - # the profiling run, the model always runs with KV cache. 
Here we clear - # the dynamo cache and cached bytecode to ensure the model always has - # one compiled bytecode. Having one FX graph/cached bytecode per - # compiled model is required for `support_torch_compile` decorator to - # skip dynamo guard. - self.model_runner.reset_dynamo_cache() - - # Get the maximum amount of memory used by the model weights and - # intermediate activations. - m = xm.get_memory_info(self.device) - total_memory_size = m["bytes_limit"] - current_mem = m["bytes_used"] - # Ideally we would use profiled = m["peak_bytes_used"] to - # get weights + activations. But there is memory used during - # compilation / weight loading that impacts the peak and - # there is no way to reset peak memory in XLA, So we - # use the heuristic of 2% of weights. - profiled = current_mem * 1.02 - - # Calculate the TPU KV cache size based on profiling. - usable_memory_size = int(total_memory_size * - self.cache_config.gpu_memory_utilization) - tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) - - return int(tpu_kv_cache_bytes) - - def execute_model( - self, - scheduler_output: "SchedulerOutput", - ) -> Optional[ModelRunnerOutput]: - output = self.model_runner.execute_model(scheduler_output) - return output if self.is_driver_worker else None - - def profile(self, is_start: bool = True): - if self.rank < 1: - if self.profile_dir is None: - raise RuntimeError("Profiler is not enabled.") - if is_start: - if self.profiler is None: - self.profiler = xp.start_server(9012) - xp.start_trace(self.profile_dir) - else: - xp.stop_trace() - - def add_lora(self, lora_request: LoRARequest) -> bool: - return self.model_runner.add_lora(lora_request) - - def load_model(self) -> None: - self.model_runner.load_model() - - def compile_or_warm_up_model(self) -> None: - if not self.model_config.enforce_eager: - self.model_runner.capture_model() - - # Reset the seed to ensure that the random state is not affected by - # the model initialization and profiling. 
- set_random_seed(self.model_config.seed) - - def get_model(self) -> nn.Module: - return self.model_runner.get_model() - - def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: - return self.model_runner.get_kv_cache_spec() - - def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: - """Allocate GPU KV cache with the specified kv_cache_config.""" - self.model_runner.initialize_kv_cache(kv_cache_config) - - def check_health(self) -> None: - # worker will always be healthy as long as it's running. - return - - -def init_tpu_worker_distributed_environment( - parallel_config: ParallelConfig, - rank: int, - distributed_init_method: Optional[str] = None, - local_rank: int = -1, -) -> None: - """Initialize the distributed environment.""" - - # NOTE(woosuk): This is just to initialize the TP group and broadcast - # the input objects on CPU. The all-reduce and all-gather ops on TPU - # are invoked by `xm.all_reduce` and `xm.all_gather` which use their - # own context. - init_distributed_environment( - world_size=parallel_config.world_size, - rank=rank, - local_rank=local_rank, - distributed_init_method=distributed_init_method, - backend="gloo", - ) - ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size) +TPUWorker = tpu_torch_xla_worker.TPUWorker From 01c26b248e6b240f09cae3c7e0329abab1421fae Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Tue, 13 May 2025 22:28:26 +0000 Subject: [PATCH 06/15] import workers from tpu commons Signed-off-by: Siyuan Liu --- vllm/v1/worker/tpu_worker.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 6ad42a159e90..7189399efbe6 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -1,6 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 """A TPU worker class.""" -import tpu_commons.worker.tpu_torch_xla_worker as tpu_torch_xla_worker - -TPUWorker = 
tpu_torch_xla_worker.TPUWorker +from tpu_commons.worker import TPUWorker # noqa: F401 From c8af4d726ca3542439a55e1c3a4e42de3a5cbd8c Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Wed, 14 May 2025 04:44:40 +0000 Subject: [PATCH 07/15] import tpu platform from tpu commons Signed-off-by: Siyuan Liu --- vllm/platforms/tpu.py | 195 +----------------------------------------- 1 file changed, 1 insertion(+), 194 deletions(-) diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index d0a5af3587c4..4121b691bc49 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,196 +1,3 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import TYPE_CHECKING, Optional, Tuple, Union, cast - -import torch -from tpu_info import device - -import vllm.envs as envs -from vllm.inputs import ProcessorInputs, PromptType -from vllm.logger import init_logger -from vllm.sampling_params import SamplingParams, SamplingType - -from .interface import Platform, PlatformEnum, _Backend - -if TYPE_CHECKING: - from vllm.config import BlockSize, ModelConfig, VllmConfig - from vllm.pooling_params import PoolingParams -else: - BlockSize = None - ModelConfig = None - VllmConfig = None - PoolingParams = None - -logger = init_logger(__name__) - - -class TpuPlatform(Platform): - _enum = PlatformEnum.TPU - device_name: str = "tpu" - device_type: str = "tpu" - dispatch_key: str = "XLA" - ray_device_key: str = "TPU" - device_control_env_var: str = "TPU_VISIBLE_CHIPS" - simple_compile_backend: str = "openxla" - - supported_quantization: list[str] = ["tpu_int8", "compressed-tensors"] - - additional_env_vars: list[str] = [ - "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS" - ] - - @classmethod - def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, - dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool, - use_mla: bool) -> str: - if (selected_backend != _Backend.PALLAS - and selected_backend != _Backend.PALLAS_VLLM_V1): - logger.info("Cannot use %s 
backend on TPU.", selected_backend) - - if use_v1: - logger.info("Using Pallas V1 backend.") - return "vllm.v1.attention.backends.pallas.PallasAttentionBackend" - else: - logger.info("Using Pallas backend.") - return "vllm.attention.backends.pallas.PallasAttentionBackend" - - @classmethod - def get_device_name(cls, device_id: int = 0) -> str: - chip_type, _ = device.get_local_chips() - return f"TPU {chip_type.name}" - - @classmethod - def get_device_total_memory(cls, device_id: int = 0) -> int: - raise NotImplementedError - - @classmethod - def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - return not envs.VLLM_USE_V1 - - @classmethod - def get_punica_wrapper(cls) -> str: - return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU" - - @classmethod - def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]: - return torch.finfo(dtype).min, torch.finfo(dtype).max - - @classmethod - def can_update_inplace(cls): - return False - - @classmethod - def get_lora_vocab_padding_size(cls) -> int: - return 1 - - @classmethod - def inference_mode(cls): - return torch.no_grad() - - @classmethod - def check_and_update_config(cls, vllm_config: VllmConfig) -> None: - from vllm.config import CompilationLevel - - cache_config = vllm_config.cache_config - # For v0, the default block size is 16. 
- if cache_config and cache_config.block_size is None: - cache_config.block_size = cast(BlockSize, 16) - compilation_config = vllm_config.compilation_config - - # TPU only supports DYNAMO_ONCE compilation level - if compilation_config.level != CompilationLevel.DYNAMO_ONCE: - logger.info("[TPU] Forcing DYNAMO_ONCE compilation level") - compilation_config.level = CompilationLevel.DYNAMO_ONCE - - if compilation_config.backend == "": - compilation_config.backend = "openxla" - - assert vllm_config.speculative_config is None, \ - "TPU does not support speculative decoding" - - if vllm_config.model_config.dtype in (torch.float16, torch.float32): - logger.warning( - "The TPU backend currently does not support %s. " - "Using bfloat16 instead.", vllm_config.model_config.dtype) - vllm_config.model_config.dtype = torch.bfloat16 - - if envs.VLLM_USE_V1: - from vllm.v1.attention.backends.pallas import ( - PallasAttentionBackend) - cache_config.block_size = PallasAttentionBackend.get_page_size( - vllm_config) # type: ignore[assignment] - min_page_size = PallasAttentionBackend.get_min_page_size( - vllm_config) - if min_page_size > cache_config.block_size: - logger.warning( - "Increase the page size from %s to %s to make sure there's" - "no SMEM OOM", - cache_config.block_size, - min_page_size, - ) - cache_config.block_size = min_page_size # type: ignore[assignment] - - parallel_config = vllm_config.parallel_config - scheduler_config = vllm_config.scheduler_config - if parallel_config.worker_cls == "auto": - if scheduler_config.is_multi_step: - if envs.VLLM_USE_V1: - raise NotImplementedError( - "Multi-step scheduling is not supported (and not " - "needed) on vLLM V1. 
Please launch without " - "--num-scheduler-steps.") - else: - parallel_config.worker_cls = \ - "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker" - else: - if envs.VLLM_USE_V1: - parallel_config.worker_cls = \ - "vllm.v1.worker.tpu_worker.TPUWorker" - else: - parallel_config.worker_cls = \ - "vllm.worker.tpu_worker.TPUWorker" - - assert not vllm_config.speculative_config, ( - "Speculative decoding is not yet supported for TPU backend") - - if scheduler_config.is_multimodal_model and not \ - scheduler_config.disable_chunked_mm_input: - logger.warning("TPU does not support running Multimodal models"\ - " without setting `--disable_chunked_mm_input`. " \ - "Forcing --disable_chunked_mm_input.") - scheduler_config.disable_chunked_mm_input = True - - @classmethod - def is_pin_memory_available(cls): - logger.warning("Pin memory is not supported on TPU.") - return False - - @classmethod - def get_device_communicator_cls(cls) -> str: - return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator" # noqa - - @classmethod - def use_all_gather(cls) -> bool: - return True - - @classmethod - def supports_v1(cls, model_config: ModelConfig) -> bool: - # V1 support on TPU is experimental - return True - - @classmethod - def validate_request( - cls, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - processed_inputs: ProcessorInputs, - ) -> None: - """Raises if this request is unsupported on this platform""" - if isinstance(params, SamplingParams): - if params.guided_decoding is not None and not envs.VLLM_USE_V1: - raise ValueError("Structured output is not supported on " - f"{cls.device_name} V0.") - if params.sampling_type == SamplingType.RANDOM_SEED: - raise ValueError( - "Torch XLA does not support per-request seed.") +from tpu_commons.platforms import TpuPlatform # noqa: F401 From 54d6f35f9fc1bea2dda520c4dc6ee511fe364695 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Wed, 14 May 2025 04:54:54 +0000 Subject: [PATCH 08/15] import tpu 
communicators from tpu commons Signed-off-by: Siyuan Liu --- .../device_communicators/tpu_communicator.py | 93 +------------------ 1 file changed, 2 insertions(+), 91 deletions(-) diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index de66ceaeef6f..4f958652b586 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -1,93 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -import os -from typing import Optional - -import torch -from torch.distributed import ProcessGroup - -from vllm.config import get_current_vllm_config -from vllm.logger import init_logger -from vllm.platforms import current_platform - -from .base_device_communicator import DeviceCommunicatorBase - -USE_RAY = parallel_config = get_current_vllm_config( -).parallel_config.distributed_executor_backend == "ray" - -logger = init_logger(__name__) - -if current_platform.is_tpu(): - import torch_xla - import torch_xla.core.xla_model as xm - import torch_xla.runtime as xr - from torch_xla._internal import pjrt - from torch_xla.distributed.xla_multiprocessing import ( - create_optimized_replica_groups) - - if USE_RAY: - from vllm.executor import ray_utils - - -class TpuCommunicator(DeviceCommunicatorBase): - - def __init__(self, - cpu_group: ProcessGroup, - device: Optional[torch.device] = None, - device_group: Optional[ProcessGroup] = None, - unique_name: str = ""): - super().__init__(cpu_group, device, device_group, unique_name) - - # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node - # must be used together. Therefore, the local rank and world size can - # be simply calculated as follows. - global_rank = self.global_rank - global_world_size = self.global_world_size - - if USE_RAY: - logger.info("TpuCommunicator initialized with RAY") - # Calculate how many TPU nodes are in the current deployment. 
This - # is the Ray placement group if it is deployed with Ray. Default - # to the number of TPU nodes in the Ray cluster. The number of TPU - # nodes is computed by the total number of TPUs divided by the - # number of TPU accelerators per node, to account for clusters - # with both CPUs and TPUs. - num_nodes = ray_utils.get_num_tpu_nodes() - num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group() - if num_nodes_in_pg > 0: - num_nodes = num_nodes_in_pg - - local_world_size = global_world_size // num_nodes - local_rank = global_rank % local_world_size - else: - logger.info("TpuCommunicator initialized with MP") - # Sanity: Verify we run on a single host - num_hosts = torch_xla.tpu.num_tpu_workers() - assert num_hosts == 1 - - # Get the current number of TPUs (we have locally) - local_world_size = torch_xla.tpu.num_available_chips() - - # Get current rank - local_rank = global_rank % local_world_size - - # Ensure environment variables are set for multihost deployments. - # On GKE, this is needed for libtpu and TPU driver to know which TPU - # chip is actually visible. Otherwise the TPU driver will fail to - # initialize because the number of devices would be different from - # the number of visible worker addresses. - os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank) - os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank) - - pjrt.initialize_multiprocess(local_rank, local_world_size) - xr._init_world_size_ordinal() - self.groups = create_optimized_replica_groups() - - def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: - # TODO: Remove the groups specification after XLA compiler can support - # auto-reordering the ring order for all-reduce. - return xm.all_reduce(xm.REDUCE_SUM, input_, groups=self.groups) - - def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: - assert dim == -1, "TPUs only support dim=-1 for all-gather." 
- return xm.all_gather(input_, dim=dim) +from tpu_commons.distributed.device_communicators import ( # noqa: F401 + TpuCommunicator) From 237bca74a7832f17023c8688c9bdbf577ce5f25a Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Wed, 14 May 2025 05:10:52 +0000 Subject: [PATCH 09/15] remove sampler Signed-off-by: Siyuan Liu --- vllm/v1/sample/tpu/__init__.py | 0 vllm/v1/sample/tpu/metadata.py | 123 ---------------------------- vllm/v1/sample/tpu/sampler.py | 144 --------------------------------- 3 files changed, 267 deletions(-) delete mode 100644 vllm/v1/sample/tpu/__init__.py delete mode 100644 vllm/v1/sample/tpu/metadata.py delete mode 100644 vllm/v1/sample/tpu/sampler.py diff --git a/vllm/v1/sample/tpu/__init__.py b/vllm/v1/sample/tpu/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py deleted file mode 100644 index a1c7dcdb111f..000000000000 --- a/vllm/v1/sample/tpu/metadata.py +++ /dev/null @@ -1,123 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -from dataclasses import dataclass, field -from typing import Optional - -import torch - -from vllm.v1.worker.gpu_input_batch import InputBatch - -DEFAULT_SAMPLING_PARAMS = dict( - temperature=-1.0, - min_p=0.0, - # strictly disabled for now - top_k=0, - top_p=1.0, - # frequency_penalties=0.0, - # presence_penalties=0.0, - # repetition_penalties=0.0, -) - - -@dataclass -class TPUSupportedSamplingMetadata: - # This class exposes a more xla-friendly interface than SamplingMetadata - # on TPU, in particular all arguments should be traceable and no optionals - # are allowed, to avoid graph recompilation on Nones. - temperature: torch.Tensor = None - - min_p: torch.Tensor = None - top_k: torch.Tensor = None - top_p: torch.Tensor = None - - all_greedy: bool = True - - # Whether logprobs are to be gathered in this batch of request. 
To balance - # out compile time and runtime, a fixed `max_number_logprobs` value is used - # when gathering logprobs, regardless of the values specified in the batch. - logprobs: bool = False - - # TODO No penalties for now - no_penalties: bool = True - prompt_token_ids = None - frequency_penalties = None - presence_penalties = None - repetition_penalties = None - # should use tensor - output_token_ids: list[list[int]] = field(default_factory=lambda: list()) - - min_tokens = None # impl is not vectorized - - logit_bias: list[Optional[dict[int, float]]] = field( - default_factory=lambda: list()) - - allowed_token_ids_mask = None - bad_words_token_ids = None - - # Generator not supported by xla - _generators: dict[int, - torch.Generator] = field(default_factory=lambda: dict()) - - @property - def generators(self) -> dict[int, torch.Generator]: - # Generator not supported by torch/xla. This field must be immutable. - return self._generators - - @classmethod - def from_input_batch( - cls, - input_batch: InputBatch, - padded_num_reqs: int, - xla_device: torch.device, - generate_params_if_all_greedy: bool = False - ) -> "TPUSupportedSamplingMetadata": - """ - Copy sampling tensors slices from `input_batch` to on device tensors. - - `InputBatch._make_sampling_metadata` causes recompilation on XLA as it - slices dynamic shapes on device tensors. This impl moves the dynamic - ops to CPU and produces tensors of fixed `padded_num_reqs` size. - - Args: - input_batch: The input batch containing sampling parameters. - padded_num_reqs: The padded number of requests. - xla_device: The XLA device. - generate_params_if_all_greedy: If True, generate sampling parameters - even if all requests are greedy. this is useful for cases where - we want to pre-compile a graph with sampling parameters, even if - they are not strictly needed for greedy decoding. 
- """ - needs_logprobs = input_batch.max_num_logprobs>0 if \ - input_batch.max_num_logprobs else False - # Early return to avoid unnecessary cpu to tpu copy - if (input_batch.all_greedy is True - and generate_params_if_all_greedy is False): - return cls(all_greedy=True, logprobs=needs_logprobs) - - num_reqs = input_batch.num_reqs - - def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor: - # Pad value is the default one. - cpu_tensor[num_reqs:padded_num_reqs] = fill_val - - fill_slice(input_batch.temperature_cpu_tensor, - DEFAULT_SAMPLING_PARAMS["temperature"]) - fill_slice(input_batch.min_p_cpu_tensor, - DEFAULT_SAMPLING_PARAMS["min_p"]) - fill_slice(input_batch.top_k_cpu_tensor, - DEFAULT_SAMPLING_PARAMS["top_k"]) - fill_slice(input_batch.top_p_cpu_tensor, - DEFAULT_SAMPLING_PARAMS["top_p"]) - - # Slice persistent device tensors to a fixed pre-compiled padded shape. - return cls( - temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs]. - to(xla_device), - all_greedy=input_batch.all_greedy, - # TODO enable more and avoid returning None values - top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to( - xla_device), - top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to( - xla_device), - min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to( - xla_device), - logprobs=needs_logprobs) diff --git a/vllm/v1/sample/tpu/sampler.py b/vllm/v1/sample/tpu/sampler.py deleted file mode 100644 index 7c31a2984b30..000000000000 --- a/vllm/v1/sample/tpu/sampler.py +++ /dev/null @@ -1,144 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Sampler layer implementing TPU supported operations.""" - -import torch -import torch.nn as nn - -from vllm.v1.outputs import LogprobsTensors, SamplerOutput -from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler -from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata - -_SAMPLING_EPS = 1e-5 - - -class Sampler(nn.Module): - - def __init__(self): - super().__init__() - self.topk_topp_sampler = 
TopKTopPSampler() - - def forward( - self, - logits: torch.Tensor, - sampling_metadata: TPUSupportedSamplingMetadata, - ) -> SamplerOutput: - # Use float32 for the logits. - logits = logits.to(torch.float32) - # Sample the next token. - sampled = self.sample(logits, sampling_metadata) - - # These are TPU tensors. - sampler_output = SamplerOutput( - # The sampled tokens are expanded to 2D tensor with shape - # [num_requests, 1], where each row represents one generated - # token per request. - sampled_token_ids=sampled.unsqueeze(-1), - logprobs_tensors=None) - return sampler_output - - def apply_temperature( - self, - logits: torch.Tensor, - temp: torch.Tensor, - ) -> torch.Tensor: - return logits.div_(temp.unsqueeze(dim=1)) - - def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor: - return logits.argmax(dim=-1).view(-1) - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: TPUSupportedSamplingMetadata, - ) -> torch.Tensor: - greedy_sampled = self.greedy_sample(logits) - - assert sampling_metadata.temperature is not None - - # Apply temperature. - logits = self.apply_temperature(logits, sampling_metadata.temperature) - - # Apply min_p. - if sampling_metadata.min_p is not None: - logits = self.apply_min_p(logits, sampling_metadata.min_p) - - # Apply top_k and/or top_p. - random_sampled = self.topk_topp_sampler( - logits, - sampling_metadata.generators, - sampling_metadata.top_k, - sampling_metadata.top_p, - ) - - sampled = torch.where(sampling_metadata.temperature < _SAMPLING_EPS, - greedy_sampled, random_sampled) - return sampled - - def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor: - return logits.log_softmax(dim=-1, dtype=torch.float32) - - def gather_logprobs( - self, - logprobs: torch.Tensor, - num_logprobs: int, - token_ids: torch.Tensor, - ) -> LogprobsTensors: - """ - Gather logprobs for topk and sampled/prompt token. 
- - Args: - logits: (num tokens) x (vocab) tensor - num_logprobs: minimum number of logprobs to - retain per token - token_ids: prompt tokens (if prompt logprobs) - or sampled tokens (if sampled - logprobs); 1D token ID tensor - with (num tokens) elements - - Returns: - Top-k int indices tensor, (num tokens) x (num_logprobs + 1) - Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1) - Sampled token rank tensor, (num tokens) - """ - # Find the topK values. - topk_logprobs, topk_indices = torch.topk(logprobs, - num_logprobs, - dim=-1) - - # Get with the logprob of the prompt or sampled token. - token_ids = token_ids.unsqueeze(-1) - token_logprobs = logprobs.gather(-1, token_ids) - - # Compute the ranks of the actual token. - token_ranks = (logprobs >= token_logprobs).sum(-1) - - # Concatenate together with the topk. - indices = torch.cat((token_ids, topk_indices), dim=1) - logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1) - - # Use int32 to reduce the tensor size. - indices = indices.to(torch.int32) - - return LogprobsTensors(indices, logprobs, token_ranks) - - def apply_min_p( - self, - logits: torch.Tensor, - min_p: torch.Tensor, - ) -> torch.Tensor: - """ - Filters logits using adaptive probability thresholding. 
- """ - # Convert logits to probability distribution - probability_values = torch.nn.functional.softmax(logits, dim=-1) - # Calculate maximum probabilities per sequence - max_probabilities = torch.amax(probability_values, - dim=-1, - keepdim=True) - # Reshape min_p for broadcasting - adjusted_min_p = min_p.unsqueeze(1) * max_probabilities - # Identify valid tokens using threshold comparison - valid_token_mask = probability_values >= adjusted_min_p - # Apply mask using boolean indexing (xla friendly) - logits.masked_fill_(~valid_token_mask, -float("inf")) - return logits From c82130ed7f81158b44f05eb6fb489f060997b14b Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Thu, 15 May 2025 17:56:19 +0000 Subject: [PATCH 10/15] fallback to tpu commons worker import Signed-off-by: Siyuan Liu --- vllm/v1/worker/tpu_model_runner.py | 1501 ++++++++++++++++++++++++++++ vllm/v1/worker/tpu_worker.py | 273 ++++- 2 files changed, 1773 insertions(+), 1 deletion(-) create mode 100644 vllm/v1/worker/tpu_model_runner.py diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py new file mode 100644 index 000000000000..687dabee2290 --- /dev/null +++ b/vllm/v1/worker/tpu_model_runner.py @@ -0,0 +1,1501 @@ +# SPDX-License-Identifier: Apache-2.0 +import bisect +import gc +import time +from typing import TYPE_CHECKING, Optional, cast +from unittest.mock import patch + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn +# TPU XLA related +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + +import vllm.envs as envs +from vllm.attention.backends.abstract import AttentionType +from vllm.attention.layer import Attention +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model +from 
vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs, + PlaceholderRange) +from vllm.multimodal.utils import group_mm_inputs_by_modality +from vllm.sequence import IntermediateTensors +from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available +from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, + PallasMetadata) +from vllm.v1.core.encoder_cache_manager import compute_encoder_budget +from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, + KVCacheConfig, KVCacheSpec, + SlidingWindowSpec) +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, + ModelRunnerOutput) +from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata +from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler +from vllm.v1.utils import bind_kv_cache +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin + +from .utils import sanity_check_mm_encoder_outputs + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + +logger = init_logger(__name__) + +# Here we utilize the behavior that out-of-bound index is ignored. +# FIXME(woosuk): Find a more reliable way to prevent possible bugs. +_PAD_SLOT_ID = 1_000_000_000 +INVALID_TOKEN_ID = -1 +# Smallest output size +MIN_NUM_SEQS = 8 + + +######################################################### +# Ways to avoid recompilation +######################################################### +# +# The model executor has two primary components: +# 1. preparing the model and sampler inputs +# 2. executing the model and sampler. +# The core idea is to avoid any TPU computation during input preparation. For +# better compilation tracking and increased flexibility, the model execution and +# sampler are divided into several distinct components. 
+# +# Below are the detailed steps: +# +# Step 1 +# It is recommended to avoid TPU operations when preparing the model and sampler +# inputs. CPU tensors can be prepared and transferred to the XLA device using +# cpu_tensor.to(xla_device), which only triggers CPU to TPU transfers and avoids +# compilation. +# +# Step 2 +# The TPU execution should be decomposed into subgraphs (4 at the moment): +# 1. the main model +# 2. selecting hidden states for each request +# 3. sampler +# 4. encoder. +# Each subgraph should be decorated in a torch.compile. This is used to make +# sure that we have the same subgraph topology in both dummy_run and +# xecute_model. The results from these subgraphs should either be passed to +# other subgraphs, or transferred from TPU to CPU using xla_tensor.cpu() for +# subsequent processing on the CPU. +# +# Step 3 +# The dummy_run should be comprehensive, ensuring all potential input shapes and +# branch predictions are included as subgraph inputs to facilitate +# pre-compilation. 
+class TPUModelRunner(LoRAModelRunnerMixin): + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + self.device_config = vllm_config.device_config + + model_config = self.model_config + cache_config = self.cache_config + scheduler_config = self.scheduler_config + parallel_config = self.parallel_config + self.device = device + self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION + + self.enforce_eager = model_config.enforce_eager + + self.num_xla_graphs = 0 + self._update_num_xla_graphs("init") + + self.pin_memory = is_pin_memory_available() + self.dtype = self.model_config.dtype + self._hidden_states_dtype = self.dtype + + self.is_multimodal_model = model_config.is_multimodal_model + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.max_model_len = model_config.max_model_len + self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) + # InputBatch needs to work with sampling tensors greater than padding + # to avoid dynamic shapes. Also, avoid suboptimal alignment. + self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS) + self.num_tokens_paddings = _get_token_paddings( + min_token_size=16, + max_token_size=scheduler_config.max_num_batched_tokens, + padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP) + # In case `max_num_tokens < max(num_tokens_paddings)` use the actual + # padded max value to pre-allocate data structures and pre-compile. 
+ self.max_num_tokens = self.num_tokens_paddings[-1] + + # Model-related. + self.num_attn_layers = model_config.get_num_layers_by_block_type( + parallel_config, LayerBlockType.attention) + self.num_query_heads = model_config.get_num_attention_heads( + parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + self.head_size = model_config.get_head_size() + self.hidden_size = model_config.get_hidden_size() + self.vocab_size = model_config.get_vocab_size() + + # Multi-modal data support + self.mm_registry = MULTIMODAL_REGISTRY + self.uses_mrope = model_config.uses_mrope + # TODO: Support M-RoPE (e.g, Qwen2-VL) + assert not self.uses_mrope, "TPU does not support M-RoPE yet." + + encoder_compute_budget, encoder_cache_size = compute_encoder_budget( + model_config=model_config, + scheduler_config=scheduler_config, + mm_registry=self.mm_registry, + ) + self.max_num_encoder_input_tokens = encoder_compute_budget + self.encoder_cache_size = encoder_cache_size + + # Lazy initialization + # self.model: nn.Module # Set after load_model + self.kv_caches: list[torch.Tensor] = [] + # req_id -> (input_id -> encoder_output) + self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} + + # Request states. + self.requests: dict[str, CachedRequestState] = {} + # Persistent batch. + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_blocks_per_req=self.max_num_blocks_per_req, + max_num_batched_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.vocab_size, + ) + + # Cached torch/numpy tensor + # The pytorch tensor and numpy array share the same buffer. + # Sometimes the numpy op is faster so we create both. 
+ self.input_ids_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device="cpu") + self.input_ids_np = self.input_ids_cpu.numpy() + + self.positions_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device="cpu") + self.positions_np = self.positions_cpu.numpy() + + self.block_table_cpu = torch.zeros( + (self.max_num_reqs, self.max_num_blocks_per_req), + dtype=self.input_batch.block_table.get_cpu_tensor().dtype, + device="cpu") + + self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.query_start_loc_np = self.query_start_loc_cpu.numpy() + + self.seq_lens_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.seq_lens_np = self.seq_lens_cpu.numpy() + + # Range tensor with values [0 .. self.max_num_tokens - 1]. + # Used to initialize positions / context_lens / seq_lens + # Keep in int64 to avoid overflow with long context + self.arange_np = np.arange(self.max_num_tokens, dtype=np.int64) + self.num_reqs_paddings = _get_req_paddings( + min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs) + + # tensors for structured decoding + self.grammar_bitmask_cpu = torch.zeros( + (self.max_num_reqs, cdiv(self.vocab_size, 32)), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.require_structured_out_cpu = torch.zeros( + (self.max_num_reqs, 1), + dtype=torch.bool, + device="cpu", + pin_memory=self.pin_memory) + self.structured_decode_arange = torch.arange( + 0, 32, device="cpu", pin_memory=self.pin_memory) + + # Get maximum number of mm items per modality (batch size). + self.max_num_mm_items_by_modality = dict() + if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0 + and self.encoder_cache_size > 0): + max_tokens_by_modality_dict = ( + MULTIMODAL_REGISTRY. 
+ get_max_tokens_per_item_by_nonzero_modality(self.model_config)) + for modality, max_tokens in max_tokens_by_modality_dict.items(): + # Check how many items of this modality can be supported by + # the encoder budget. + encoder_budget = min(self.max_num_encoder_input_tokens, + self.encoder_cache_size) + + max_num_mm_items_encoder_budget = cdiv(encoder_budget, + max_tokens) + + # Check how many items of this modality can be supported by + # the decoder budget. + max_mm_items_per_req = self.mm_registry.\ + get_mm_limits_per_prompt(self.model_config)[modality] + + # NOTE: We do not consider max_num_batched_tokens on purpose + # because the multimodal embeddings can be generated in advance + # and chunked prefilled. + max_num_mm_items_decoder_budget = self.max_num_reqs * \ + max_mm_items_per_req + + max_num_mm_items = min(max_num_mm_items_encoder_budget, + max_num_mm_items_decoder_budget) + self.max_num_mm_items_by_modality[modality] = max_num_mm_items + + def _update_num_xla_graphs(self, case_str): + check_comp = self.check_recompilation and not self.enforce_eager + if not check_comp: + return + + total_cached_graphs = xr.get_num_cached_compilation_graph() + new_compiled_graphs = total_cached_graphs - self.num_xla_graphs + if new_compiled_graphs == 0: + return + + logger.info("Add new %d compiled XLA graphs due to %s", + new_compiled_graphs, case_str) + self.num_xla_graphs += new_compiled_graphs + + def _verify_num_xla_graphs(self, case_str): + check_comp = self.check_recompilation and not self.enforce_eager + if not check_comp: + return + + curr_cached_graph = xr.get_num_cached_compilation_graph() + assert self.num_xla_graphs == curr_cached_graph, ( + "Recompilation after warm up is detected during {}." + " num_xla_graphs = {} curr_cached_graph = {}".format( + case_str, self.num_xla_graphs, curr_cached_graph)) + + def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: + """Update the cached states and the persistent batch with the scheduler + output. 
+ + The updated states are used by the `_prepare_inputs` function to create + the input GPU tensors for the model. + + Returns: + True if there is a new/resumed/paused/finished request. + If False, we can skip copying SamplingMetadata to the GPU. + """ + # Remove finished requests from the cached states. + for req_id in scheduler_output.finished_req_ids: + self.requests.pop(req_id, None) + self.encoder_cache.pop(req_id, None) + + # Remove the finished requests from the persistent batch. + # NOTE(woosuk): There could be an edge case where finished_req_ids and + # scheduled_req_ids overlap. This happens when a request is aborted and + # then resubmitted with the same ID. In this case, we treat them as two + # distinct requests - clearing the cached states for the first request + # and handling the second as a new request. + removed_req_indices: list[int] = [] + for req_id in scheduler_output.finished_req_ids: + req_index = self.input_batch.remove_request(req_id) + if req_index is not None: + removed_req_indices.append(req_index) + + # Free the cached encoder outputs. + for req_id, input_id in scheduler_output.free_encoder_input_ids: + encoder_outputs = self.encoder_cache.get(req_id) + if encoder_outputs is not None: + encoder_outputs.pop(input_id, None) + if not encoder_outputs: + self.encoder_cache.pop(req_id, None) + + # Remove the unscheduled requests from the persistent batch. + # NOTE(woosuk): The unscheduled requests are either preempted requests + # or running requests that are not scheduled in this step. We remove + # them from the persistent batch but keep their cached states since + # they will be scheduled again sometime in the future. + scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() + cached_req_ids = self.input_batch.req_id_to_index.keys() + unscheduled_req_ids = cached_req_ids - scheduled_req_ids + # NOTE(woosuk): The persistent batch optimization assumes that + # consecutive batches contain mostly the same requests. 
If batches + # have low request overlap (e.g., alternating between two distinct + # sets of requests), this optimization becomes very inefficient. + for req_id in unscheduled_req_ids: + req_index = self.input_batch.remove_request(req_id) + assert req_index is not None + removed_req_indices.append(req_index) + + req_ids_to_add: list[str] = [] + # Add new requests to the cached states. + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + sampling_params = new_req_data.sampling_params + + self.requests[req_id] = CachedRequestState( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + mm_inputs=new_req_data.mm_inputs, + mm_positions=new_req_data.mm_positions, + sampling_params=sampling_params, + generator=None, + block_ids=new_req_data.block_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + output_token_ids=[], + lora_request=new_req_data.lora_request, + ) + + req_ids_to_add.append(req_id) + + # Update the states of the running/resumed requests. + for req_data in scheduler_output.scheduled_cached_reqs: + req_id = req_data.req_id + req_state = self.requests[req_id] + + # Update the cached states. + req_state.num_computed_tokens = req_data.num_computed_tokens + if not req_data.resumed_from_preemption: + # Append the new blocks to the existing block IDs. + req_state.block_ids.extend(req_data.new_block_ids) + else: + # The request is resumed from preemption. + # Replace the existing block IDs with the new ones. + req_state.block_ids = req_data.new_block_ids + + req_index = self.input_batch.req_id_to_index.get(req_id) + if req_index is None: + # The request is not in the persistent batch. + # The request was either preempted and resumed later, or was not + # scheduled in the previous step and needs to be added again. + req_ids_to_add.append(req_id) + continue + + # Update the persistent batch. 
+ self.input_batch.num_computed_tokens_cpu[req_index] = ( + req_data.num_computed_tokens) + self.input_batch.block_table.append_row(req_data.new_block_ids, + req_index) + + # Add the new or resumed requests to the persistent batch. + # The smaller empty indices are filled first. + removed_req_indices = sorted(removed_req_indices, reverse=True) + for req_id in req_ids_to_add: + req_state = self.requests[req_id] + if removed_req_indices: + # Fill the empty index. + req_index = removed_req_indices.pop() + else: + # Append to the end. + req_index = None + self.input_batch.add_request(req_state, req_index) + + # Condense the batched states if there are empty indices. + if removed_req_indices: + self.input_batch.condense(removed_req_indices) + + return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0 + + def get_model(self) -> nn.Module: + assert self.model is not None + return self.model + + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: + """ + Generates the KVCacheSpec by parsing the kv cache format from each + Attention module in the static forward context. + Returns: + KVCacheSpec: A dictionary mapping layer names to their KV cache + format. Layers that do not need KV cache are not included. 
+ """ + + layers = get_layers_from_vllm_config(self.vllm_config, Attention) + block_size = self.vllm_config.cache_config.block_size + kv_cache_spec: dict[str, KVCacheSpec] = {} + for layer_name, attn_module in layers.items(): + if attn_module.attn_type == AttentionType.DECODER: + if attn_module.sliding_window is not None: + kv_cache_spec[layer_name] = SlidingWindowSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=attn_module.dtype, + sliding_window=attn_module.sliding_window, + use_mla=False, + ) + else: + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=attn_module.dtype, + use_mla=False, + ) + elif attn_module.attn_type in (AttentionType.ENCODER, + AttentionType.ENCODER_ONLY): + # encoder-only attention does not need KV cache. + continue + elif attn_module.attn_type == AttentionType.ENCODER_DECODER: + raise NotImplementedError + else: + raise ValueError( + f"Unknown attention type: {attn_module.attn_type}") + + return kv_cache_spec + + def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = self.input_batch.num_reqs + assert num_reqs > 0 + + # Get the number of scheduled tokens for each request. + num_scheduled_tokens_per_req = [] + max_num_scheduled_tokens_all_reqs = 0 + for req_id in self.input_batch.req_ids[:num_reqs]: + assert req_id is not None + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_scheduled_tokens_per_req.append(num_tokens) + max_num_scheduled_tokens_all_reqs = max( + max_num_scheduled_tokens_all_reqs, num_tokens) + num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req, + dtype=np.int32) + assert max_num_scheduled_tokens_all_reqs > 0 + + # Get request indices. 
+ # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + # For each scheduled token, what are the corresponding req index. + req_indices = np.repeat(self.arange_np[:num_reqs], + num_scheduled_tokens_per_req) + + # Get batched arange. + # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # For each scheduled token, what is its position in corresponding req. + arange = np.concatenate( + [self.arange_np[:n] for n in num_scheduled_tokens_per_req]) + + # Get positions. + positions_np = self.positions_np[:total_num_scheduled_tokens] + np.add(self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np) + + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. + token_indices = (positions_np + + req_indices * self.input_batch.token_ids_cpu.shape[1]) + + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. + torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices), + out=self.input_ids_cpu[:total_num_scheduled_tokens]) + + # Calculate the slot mapping. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` here + # because M (max_model_len) is not necessarily divisible by block_size. + # req_indices: # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + block_table_indices = (req_indices * self.max_num_blocks_per_req + + positions_np // self.block_size) + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. 
+ block_table_cpu = self.input_batch.block_table.get_cpu_tensor() + block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() + block_offsets = positions_np % self.block_size + np.add(block_numbers * self.block_size, + block_offsets, + out=self.input_batch.block_table. + slot_mapping_np[:total_num_scheduled_tokens]) + + # Prepare the attention metadata. + self.query_start_loc_np[0] = 0 + np.cumsum(num_scheduled_tokens_per_req, + out=self.query_start_loc_np[1:num_reqs + 1]) + self.query_start_loc_np[num_reqs + 1:] = 1 + + self.seq_lens_np[:num_reqs] = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens_per_req) + + # Do the padding and copy the tensors to the TPU. + padded_total_num_scheduled_tokens = _get_padded_token_len( + self.num_tokens_paddings, total_num_scheduled_tokens) + # Zero out to avoid spurious values from prev iteration (last cp chunk) + self.input_ids_cpu[ + total_num_scheduled_tokens:padded_total_num_scheduled_tokens] = 0 + self.input_ids = self.input_ids_cpu[: + padded_total_num_scheduled_tokens].to( + self.device) + self.position_ids = self.positions_cpu[: + padded_total_num_scheduled_tokens].to( + self.device) + self.input_batch.block_table.slot_mapping_cpu[ + total_num_scheduled_tokens:] = _PAD_SLOT_ID + slot_mapping = ( + self.input_batch.block_table. 
+ slot_mapping_cpu[:padded_total_num_scheduled_tokens].to( + self.device)) + block_tables = self.block_table_cpu[:self.max_num_reqs] + block_tables[:num_reqs, :self.max_num_blocks_per_req] = ( + self.input_batch.block_table.get_cpu_tensor()[:num_reqs]) + block_tables = block_tables.to(self.device) + query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to( + self.device) + seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device) + + if self.lora_config is not None: + # We need to respect padding when activating LoRA adapters + padded_num_scheduled_tokens_per_req = np.copy( + num_scheduled_tokens_per_req + ) # Copying to avoid accidental state corruption bugs + padded_num_scheduled_tokens_per_req[-1] += \ + padded_total_num_scheduled_tokens - total_num_scheduled_tokens + + self.set_active_loras(self.input_batch, + padded_num_scheduled_tokens_per_req) + + attn_metadata = PallasMetadata( + slot_mapping=slot_mapping, + block_tables=block_tables, + context_lens=seq_lens, + query_start_loc=query_start_loc, + num_seqs=torch.tensor([num_reqs], + dtype=torch.int32, + device=self.device), + ) + # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial + # request in the batch. While we should not sample any token from this + # partial request, we do so for simplicity. We will ignore the sampled + # token from the partial request. + # TODO: Support prompt logprobs. + padded_num_reqs = _get_padded_num_reqs_with_upper_limit( + num_reqs, self.max_num_reqs) + # Indices at which we sample (positions of last token in the sequence). + # Padded to avoid recompiling when `num_reqs` varies. 
+ logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1 + logits_indices = logits_indices.to(self.device) + + layer_names = get_layers_from_vllm_config(self.vllm_config, + Attention).keys() + per_layer_attn_metadata = { + layer_name: attn_metadata + for layer_name in layer_names + } + return per_layer_attn_metadata, logits_indices, padded_num_reqs + + def _scatter_placeholders( + self, + embeds: torch.Tensor, + is_embed: Optional[torch.Tensor], + ) -> torch.Tensor: + if is_embed is None: + return embeds + + placeholders = embeds.new_full( + (is_embed.shape[0], embeds.shape[-1]), + fill_value=torch.nan, + ) + placeholders[is_embed] = embeds + return placeholders + + def _gather_placeholders( + self, + placeholders: torch.Tensor, + is_embed: Optional[torch.Tensor], + ) -> torch.Tensor: + if is_embed is None: + return placeholders + + return placeholders[is_embed] + + def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): + scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs + if not scheduled_encoder_inputs: + return + + # Batch the multi-modal inputs. + mm_inputs = list[MultiModalKwargs]() + req_ids_pos = list[tuple[str, int, PlaceholderRange]]() + for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): + req_state = self.requests[req_id] + + for mm_input_id in encoder_input_ids: + mm_inputs.append(req_state.mm_inputs[mm_input_id]) + req_ids_pos.append( + (req_id, mm_input_id, req_state.mm_positions[mm_input_id])) + + # Batch mm inputs as much as we can: if a request in the batch has + # multiple modalities or a different modality than the previous one, + # we process it separately to preserve item order. + # FIXME(ywang96): This is a hacky way to deal with multiple modalities + # in the same batch while still being able to benefit from batching + # multimodal inputs. The proper solution should be reordering the + # encoder outputs. 
+ grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs) + + encoder_outputs = [] + for grouped_mm_inputs in grouped_mm_inputs_list: + batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs) + batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs, + device=self.device) + + # Run the encoder. + # `curr_group_outputs` is either of the following: + # 1. A tensor of shape (num_items, feature_size, hidden_size) + # in case feature_size is fixed across all multimodal items. + # 2. A list or tuple (length: num_items) of tensors, each of shape + # (feature_size, hidden_size) in case the feature size is dynamic + # depending on the input multimodal items. + xm.mark_step() + curr_group_outputs = self.model.get_multimodal_embeddings( + **batched_mm_inputs) + xm.mark_step() + + sanity_check_mm_encoder_outputs( + curr_group_outputs, + expected_num_items=len(grouped_mm_inputs), + ) + + if isinstance(curr_group_outputs, torch.Tensor): + encoder_outputs.append(curr_group_outputs) + else: + assert isinstance(curr_group_outputs, (list, tuple)) + for output in curr_group_outputs: + encoder_outputs.append(output) + + # Cache the encoder outputs. + # NOTE (NickLucche) here we diverge from logic in other runners, as we + # assume to only have whole mm items to process. Hence we avoid the + # intrinsic dynamism that `scatter_mm_placeholders` introduces. + for (req_id, input_id, pos_info), output in zip( + req_ids_pos, + encoder_outputs, + ): + if req_id not in self.encoder_cache: + self.encoder_cache[req_id] = {} + assert pos_info.is_embed is None, "Expected all positions to be"\ + " contiguous and embeddings." 
+ self.encoder_cache[req_id][input_id] = output + + def _gather_mm_embeddings( + self, + scheduler_output: "SchedulerOutput", + ) -> list[torch.Tensor]: + mm_embeds: list[torch.Tensor] = [] + for req_id in self.input_batch.req_ids: + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] + req_state = self.requests[req_id] + num_computed_tokens = req_state.num_computed_tokens + mm_positions = req_state.mm_positions + # TODO unroll loop and assume/enforce --disable_chunked_mm_input + # NOTE (NickLucche) here we diverge from logic in other runners, as + # we assume to only have whole mm items to process. Hence we avoid + # the intrinsic dynamism that `gather_mm_placeholders` introduces. + for i, pos_info in enumerate(mm_positions): + start_pos = pos_info.offset + num_encoder_tokens = pos_info.length + + # The encoder output is needed if the two ranges overlap: + # [num_computed_tokens, + # num_computed_tokens + num_scheduled_tokens) and + # [start_pos, start_pos + num_encoder_tokens) + if start_pos >= num_computed_tokens + num_scheduled_tokens: + # The encoder output is not needed in this step. + break + if start_pos + num_encoder_tokens <= num_computed_tokens: + # The encoder output is already processed and stored + # in the decoder's KV cache. + continue + + assert req_id in self.encoder_cache + assert i in self.encoder_cache[req_id] + assert pos_info.is_embed is None, "Expected all positions to"\ + " be contiguous and embeddings." + encoder_output = self.encoder_cache[req_id][i] + mm_embeds.append(encoder_output) + return mm_embeds + + def _get_model_inputs(self, input_ids: torch.Tensor, + mm_embeds: list[torch.Tensor]): + if self.is_multimodal_model: + # NOTE(woosuk): To unify token ids and soft tokens (vision + # embeddings), we always use embeddings (rather than token ids) + # as input to the multimodal model, even when the input is text. 
+ if mm_embeds: + inputs_embeds = self.model.get_input_embeddings( + input_ids, mm_embeds) + else: + inputs_embeds = self.model.get_input_embeddings(input_ids) + return None, inputs_embeds + else: + # For text-only models, we use token ids as input. + # While it is possible to use embeddings as input just like the + # multimodal models, it is not desirable for performance since + # then the embedding layer is not included in the CUDA graph. + return input_ids, None + + @torch.no_grad() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> ModelRunnerOutput: + # Update cached state + self._update_states(scheduler_output) + if not scheduler_output.total_num_scheduled_tokens: + # Return empty ModelRunnerOutput if there's no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + + if self.is_multimodal_model: + # Run the multimodal encoder if any. + self._execute_mm_encoder(scheduler_output) + mm_embeds = self._gather_mm_embeddings(scheduler_output) + else: + mm_embeds = [] + xm.mark_step() + # Prepare inputs + attn_metadata, logits_indices, padded_num_reqs = self._prepare_inputs( + scheduler_output) + input_ids, inputs_embeds = self._get_model_inputs( + self.input_ids, mm_embeds) + xm.mark_step() + num_reqs = self.input_batch.num_reqs + # Run the decoder + with set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=scheduler_output.total_num_scheduled_tokens): + hidden_states = self.model( + input_ids=input_ids, + positions=self.position_ids, + inputs_embeds=inputs_embeds, + ) + hidden_states = self.select_hidden_states(hidden_states, + logits_indices) + logits = self.compute_logits(hidden_states) + tpu_sampling_metadata = TPUSupportedSamplingMetadata.\ + from_input_batch(self.input_batch, padded_num_reqs, self.device) + if scheduler_output.grammar_bitmask is not None: + require_struct_decoding, grammar_bitmask_padded, arange = \ + 
self.prepare_structured_decoding_input(logits, scheduler_output) + logits = self.structured_decode(require_struct_decoding, + grammar_bitmask_padded, logits, + arange) + selected_token_ids = self.sample_from_logits(logits, + tpu_sampling_metadata) + + # NOTE (NickLucche) Use the original logits (before any penalties or + # temperature scaling) for the top-k logprobs. We can't enforce it due + # to recompilations outside torch.compiled code, so just make sure + # `sample_from_logits` does not modify the logits in-place. + logprobs = self.gather_logprobs(logits, selected_token_ids) \ + if tpu_sampling_metadata.logprobs else None + + # Remove padding on cpu and keep dynamic op outside of xla graph. + selected_token_ids = selected_token_ids.cpu()[:num_reqs] + logprobs_lists = logprobs.tolists() \ + if tpu_sampling_metadata.logprobs else None + + # Update the cache state concurrently. Code above will not block until + # we use `selected_token_ids`. Add mark_step if post-processing changes + request_seq_lens: list[tuple[int, CachedRequestState, int]] = [] + discard_sampled_tokens_req_indices = [] + for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): + assert req_id is not None + req_state = self.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + if seq_len >= req_state.num_tokens: + request_seq_lens.append((i, req_state, seq_len)) + else: + # Ignore the sampled token from the partial request. + # Rewind the generator state as if the token was not sampled. + generator = self.input_batch.generators.get(i) + if generator is not None: + # This relies on cuda-specific torch-internal impl details + generator.set_offset(generator.get_offset() - 4) + + # Record the index of the request that should not be sampled, + # so that we could clear the sampled tokens before returning. 
+ discard_sampled_tokens_req_indices.append(i) + + assert all( + req_id is not None for req_id in + self.input_batch.req_ids[:num_reqs]), "req_ids contains None" + req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs]) + + prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} + for req_id in self.input_batch.req_ids[:num_reqs]: + prompt_logprobs_dict[req_id] = None + + max_gen_len = selected_token_ids.shape[-1] + if max_gen_len == 1: + valid_sampled_token_ids = selected_token_ids.tolist() + + # Mask out the sampled tokens that should not be sampled. + # TODO: Keep in sync with gpu_model_runner.py, in particular + # the "else" case here + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[i].clear() + + # Append sampled tokens + for i, req_state, seq_len in request_seq_lens: + token_id = valid_sampled_token_ids[i][0] + self.input_batch.token_ids_cpu[i, seq_len] = token_id + req_state.output_token_ids.append(token_id) + self.input_batch.num_tokens[i] += 1 + + else: + valid_mask = selected_token_ids != INVALID_TOKEN_ID + gen_lens = valid_mask.sum(dim=1).tolist() + valid_sampled_token_ids = [ + seq.tolist() + for seq in selected_token_ids[valid_mask].split(gen_lens) + ] + self.input_batch.num_tokens[:num_reqs] += gen_lens + for i, req_state, seq_len in request_seq_lens: + target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1) + self.input_batch.token_ids_cpu[ + i, target_slice] = valid_sampled_token_ids[i] + req_state.output_token_ids.extend(valid_sampled_token_ids[i]) + + model_runner_output = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=valid_sampled_token_ids, + spec_token_ids=None, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + ) + + # Check there are no new graphs compiled - all the graphs should be + # captured and compiled during warm up. 
+ self._verify_num_xla_graphs("execute_model") + + return model_runner_output + + def load_model(self) -> None: + self.device = self.device_config.device + + # NOTE(woosuk): While the executor assigns the TP ranks to the worker + # process, the ranks can be different from the ranks internally assigned + # by the xm runtime. Therefore, there is a mismatch in the rank + # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. + # This is not a problem in linear layers because all-reduce is + # rank-agnostic. However, it matters for all-gather as the ranks + # determine the order of concatenating the output tensors. + # As a workaround, we use the xm's rank assignment only when loading + # the embedding weights. + xm_tp_rank = xr.global_ordinal() + with patch( + "vllm.model_executor.layers.vocab_parallel_embedding." + "get_tensor_model_parallel_rank", + return_value=xm_tp_rank): + model = get_model(vllm_config=self.vllm_config) + if self.lora_config is not None: + model = self.load_lora_model(model, self.model_config, + self.scheduler_config, + self.lora_config, self.device) + + # Sync all pending XLA execution during model initialization and weight + # loading. 
+ xm.mark_step() + xm.wait_device_ops() + self.model = model + self.sampler = TPUSampler() + + @torch.no_grad() + def _dummy_run(self, num_tokens: int) -> None: + if self.is_multimodal_model: + input_ids = None + inputs_embeds = torch.zeros((num_tokens, self.hidden_size), + dtype=self.dtype, + device=self.device) + else: + input_ids = torch.zeros((num_tokens), + dtype=torch.int32, + device=self.device) + inputs_embeds = None + actual_num_reqs = min(num_tokens, self.max_num_reqs) + position_ids = torch.zeros(num_tokens, + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros(num_tokens, + dtype=torch.int64, + device=self.device) + block_tables = torch.zeros( + (self.max_num_reqs, self.block_table_cpu.shape[1]), + dtype=torch.int32, + device=self.device) + query_lens = [1] * self.max_num_reqs + query_start_loc = torch.cumsum(torch.tensor([0] + query_lens, + dtype=torch.int32), + dim=0, + dtype=torch.int32).to(self.device) + context_lens = torch.ones((self.max_num_reqs, ), + dtype=torch.int32, + device=self.device) + num_seqs = torch.tensor([actual_num_reqs], + dtype=torch.int32, + device=self.device) + attn_metadata = PallasMetadata( + slot_mapping=slot_mapping, + block_tables=block_tables, + context_lens=context_lens, + query_start_loc=query_start_loc, + num_seqs=num_seqs, + ) + + if self.is_multimodal_model: + torch._dynamo.mark_dynamic(inputs_embeds, 0) + else: + torch._dynamo.mark_dynamic(input_ids, 0) + torch._dynamo.mark_dynamic(position_ids, 0) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) + + layer_names = get_layers_from_vllm_config(self.vllm_config, + Attention).keys() + per_layer_attn_metadata = { + layer_name: attn_metadata + for layer_name in layer_names + } + + with self.maybe_dummy_run_with_lora( + self.lora_config, + np.array([num_tokens], dtype=np.int32)), set_forward_context( + per_layer_attn_metadata, self.vllm_config, 0): + out = self.model(input_ids=input_ids, + positions=position_ids, + inputs_embeds=inputs_embeds) 
+ self._hidden_states_dtype = out.dtype + + def _precompile_mm_encoder(self) -> None: + # Pre-compile MM encoder for all supported data modalities. + hf_config = self.vllm_config.model_config.hf_config + for mode, max_items_by_mode in \ + self.max_num_mm_items_by_modality.items(): + logger.info( + "Compiling Multimodal %s Encoder with different input" + " shapes.", mode) + start = time.perf_counter() + # No padding for MM encoder just yet. + for num_items in range(1, max_items_by_mode + 1): + logger.info(" -- mode: %s items: %d", mode, num_items) + batched_dummy_mm_inputs = self._get_mm_dummy_batch( + mode, num_items) + # Run multimodal encoder. + xm.mark_step() + mm_embeds = self.model.\ + get_multimodal_embeddings(**batched_dummy_mm_inputs) + xm.mark_step() + num_patches = mm_embeds[0].shape[0] + items_size = num_patches * num_items + + # NOTE (NickLucche) pre-compile `get_input_embeddings` when mm + # embeddings are present. We assume `--disable-mm-chunked`, + # hence only whole items can be scheduled. This implies we just + # need to compile when `num_items` fit the (padded) `input_ids` + for num_tokens in self.num_tokens_paddings: + if num_tokens >= items_size: + # XLA Workaround: if torch.zeros(..device) is used, XLA + # compiles a scalar+expansion op, which won't match + # the graph generated at runtime. CPU->TPU must be used + placeholders_ids = torch.zeros(num_tokens, + dtype=torch.int32, + device="cpu") + # Align placeholders and actual num mm_embeddings. + placeholders_ids[:items_size] = \ + hf_config.image_token_index + + placeholders_ids = placeholders_ids.to(self.device) + # Assign outputs or the graph will be cut short. + a, b = self._get_model_inputs(placeholders_ids, + [mm_embeds]) + assert a is None + xm.mark_step() + + # Pre-compile `get_input_embeddings` when mm_embeddings are not + # present. Chunk is only made of text, no mm_placeholders. 
+ for num_tokens in self.num_tokens_paddings: + placeholders_ids = torch.zeros(num_tokens, + dtype=torch.int32, + device="cpu") + placeholders_ids = placeholders_ids.to(self.device) + a, b = self._get_model_inputs(placeholders_ids, []) + assert a is None + xm.mark_step() + + xm.wait_device_ops() + end = time.perf_counter() + logger.info( + "Multimodal %s Encoder compilation finished in in %.2f " + "[secs].", mode, end - start) + + def _precompile_backbone(self) -> None: + logger.info("Compiling the model with different input shapes.") + start = time.perf_counter() + for num_tokens in self.num_tokens_paddings: + logger.info(" -- num_tokens: %d", num_tokens) + self._dummy_run(num_tokens) + xm.wait_device_ops() + end = time.perf_counter() + logger.info("Compilation finished in %.2f [secs].", end - start) + self._update_num_xla_graphs("model backbone") + + def _precompile_select_hidden_states(self) -> None: + # Compile hidden state selection function for bucketed + # n_tokens x max_num_reqs. Graph is really small so this is fine. + logger.info( + "Compiling select_hidden_states with different input shapes.") + start = time.perf_counter() + hsize = self.model_config.get_hidden_size() + for num_tokens in self.num_tokens_paddings: + dummy_hidden = torch.zeros((num_tokens, hsize), + device=self.device, + dtype=self._hidden_states_dtype) + torch._dynamo.mark_dynamic(dummy_hidden, 0) + for num_reqs in self.num_reqs_paddings: + indices = torch.zeros(num_reqs, + dtype=torch.int32, + device=self.device) + torch._dynamo.mark_dynamic(indices, 0) + self.select_hidden_states(dummy_hidden, indices) + logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens, + num_reqs) + # Requests can't be more than tokens. But do compile for the + # next bigger value in case num_tokens uses bucketed padding. 
+ if num_reqs >= min(num_tokens, self.max_num_reqs): + break + xm.wait_device_ops() + end = time.perf_counter() + logger.info("Compilation finished in %.2f [secs].", end - start) + self._update_num_xla_graphs("select_hidden_states") + + def _precompile_compute_logits(self) -> None: + logger.info("Compiling compute_logits with different input shapes.") + start = time.perf_counter() + hsize = self.model_config.get_hidden_size() + for num_reqs in self.num_reqs_paddings: + dummy_hidden = torch.zeros((num_reqs, hsize), + device=self.device, + dtype=self._hidden_states_dtype) + torch._dynamo.mark_dynamic(dummy_hidden, 0) + self.compute_logits(dummy_hidden) + logger.info(" -- num_seqs: %d", num_reqs) + xm.wait_device_ops() + end = time.perf_counter() + logger.info("Compilation finished in %.2f [secs].", end - start) + self._update_num_xla_graphs("compute_logits") + + def _precompile_structured_decoding(self) -> None: + logger.info( + "Compiling structured_decoding with different input shapes.") + start = time.perf_counter() + for num_reqs in self.num_reqs_paddings: + dummy_logits = torch.zeros((num_reqs, self.vocab_size), + device=self.device, + dtype=self._hidden_states_dtype) + dummy_require_struct_decoding = \ + self.require_structured_out_cpu[:num_reqs].to(self.device) + dummy_grammar_bitmask = \ + self.grammar_bitmask_cpu[:num_reqs].to(self.device) + # The first dimension of the above 3 dummy tensors cannot be + # mark_dynamic because some operations in structured_decode require + # them to be static. 
+ arange = self.structured_decode_arange.to(self.device) + self.structured_decode(dummy_require_struct_decoding, + dummy_grammar_bitmask, dummy_logits, arange) + logger.info(" -- num_seqs: %d", num_reqs) + xm.wait_device_ops() + end = time.perf_counter() + logger.info("Compilation finished in %.2f [secs].", end - start) + self._update_num_xla_graphs("structured_decoding") + + def _precompile_sample_from_logits(self) -> None: + logger.info( + "Compiling sample_from_logits with different input shapes.") + start = time.perf_counter() + for num_reqs in self.num_reqs_paddings: + dummy_logits = torch.zeros((num_reqs, self.vocab_size), + device=self.device, + dtype=self._hidden_states_dtype) + # The first dimension of dummy_logits cannot be mark_dynamic + # because some operations in the sampler require it to be static. + for all_greedy in [False, True]: + generate_params_if_all_greedy = not all_greedy + sampling_metadata = ( + TPUSupportedSamplingMetadata.from_input_batch( + self.input_batch, + num_reqs, + self.device, + generate_params_if_all_greedy, + )) + sampling_metadata.all_greedy = all_greedy + self.sample_from_logits(dummy_logits, sampling_metadata) + logger.info(" -- num_seqs: %d", num_reqs) + xm.wait_device_ops() + end = time.perf_counter() + logger.info("Compilation finished in %.2f [secs].", end - start) + self._update_num_xla_graphs("sample_from_logits") + + def _precompile_gather_logprobs(self) -> None: + logger.info("Compiling gather_logprobs with different input shapes.") + start = time.perf_counter() + for num_reqs in self.num_reqs_paddings: + dummy_logits = torch.zeros((num_reqs, self.vocab_size), + device=self.device, + dtype=self._hidden_states_dtype) + dummy_tokens = torch.zeros((num_reqs, 1), + dtype=torch.int64).to(self.device) + self.gather_logprobs(dummy_logits, dummy_tokens) + logger.info(" -- num_seqs: %d", num_reqs) + xm.wait_device_ops() + end = time.perf_counter() + logger.info("Compilation finished in %.2f [secs].", end - start) + 
self._update_num_xla_graphs("gather_logprobs") + + def capture_model(self) -> None: + """ + Precompile all the subgraphs with possible input shapes. + """ + self._precompile_mm_encoder() + self._precompile_backbone() + self._precompile_select_hidden_states() + self._precompile_compute_logits() + self._precompile_structured_decoding() + self._precompile_sample_from_logits() + self._precompile_gather_logprobs() + + def profile_run( + self, + num_tokens: int, + ) -> None: + # Profile with multimodal encoder & encoder cache. + # TODO: handle encoder-decoder models once we support them. + if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0 + and self.encoder_cache_size > 0): + + # NOTE: Currently model is profiled with a single non-text + # modality with the max possible input tokens even when + # it supports multiple. + dummy_data_modality, max_num_mm_items = max( + self.max_num_mm_items_by_modality.items(), key=lambda t: t[1]) + + encoder_budget = min(self.max_num_encoder_input_tokens, + self.encoder_cache_size) + + logger.info( + "Encoder cache will be initialized with a budget of %d tokens," + " and profiled with %s %s items of the maximum feature size.", + encoder_budget, max_num_mm_items, dummy_data_modality) + + # Create dummy batch of multimodal inputs. + batched_dummy_mm_inputs = self._get_mm_dummy_batch( + dummy_data_modality, max_num_mm_items) + + # Run multimodal encoder. + # Isolate encoder graph from post-processing to minimize + # impact of recompilation until it's fixed. 
+ start = time.perf_counter() + xm.mark_step() + dummy_encoder_outputs = self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs) + xm.mark_step() + xm.wait_device_ops() + end = time.perf_counter() + logger.info( + "Multimodal Encoder profiling finished in in %.2f [secs].", + end - start) + + assert len(dummy_encoder_outputs) == max_num_mm_items, ( + "Expected dimension 0 of encoder outputs to match the number " + f"of multimodal data items: {max_num_mm_items}, got " + f"{len(dummy_encoder_outputs)=} instead. This is most likely " + "due to the 'get_multimodal_embeddings' method of the model " + "not implemented correctly.") + + # Cache the dummy encoder outputs. + self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) + + # Trigger compilation for general shape. + self._dummy_run(num_tokens) + + xm.mark_step() + xm.wait_device_ops() + self.encoder_cache.clear() + gc.collect() + + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: + """ + Initialize KV cache based on `kv_cache_config`. 
+ Args: + kv_cache_config: Configuration for the KV cache, including the KV + cache size of each layer + """ + if len(kv_cache_config.kv_cache_groups) > 1: + raise NotImplementedError( + "Hybrid models with more than one KV cache type are not " + "supported yet.") + + kv_caches: dict[str, torch.Tensor] = {} + + for kv_cache_group in kv_cache_config.kv_cache_groups: + kv_cache_spec = kv_cache_group.kv_cache_spec + for layer_name in kv_cache_group.layer_names: + tensor_config = kv_cache_config.tensors[layer_name] + assert tensor_config.size % kv_cache_spec.page_size_bytes == 0 + num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes + if isinstance(kv_cache_spec, AttentionSpec): + kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + dtype = kv_cache_spec.dtype + + tpu_kv_cache = torch.zeros(kv_cache_shape, + dtype=dtype, + device=self.device) + + kv_caches[layer_name] = tpu_kv_cache + else: + raise NotImplementedError + + bind_kv_cache( + kv_caches, + self.vllm_config.compilation_config.static_forward_context, + self.kv_caches) + + def reset_dynamo_cache(self): + if self.is_multimodal_model: + compiled_model = self.model.get_language_model().model + else: + compiled_model = self.model.model + if isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher): + logger.info("Clear dynamo cache and cached dynamo bytecode.") + torch._dynamo.eval_frame.remove_from_cache( + compiled_model.original_code_object) + compiled_model.compiled_codes.clear() + + @torch.compile(backend="openxla", fullgraph=True, dynamic=False) + def select_hidden_states(self, hidden_states, indices_do_sample): + return hidden_states[indices_do_sample] + + @torch.compile(backend="openxla", fullgraph=True, dynamic=False) + def compute_logits(self, + sample_hidden_states: torch.Tensor) -> torch.Tensor: + return self.model.compute_logits(sample_hidden_states, None) + + 
@torch.compile(backend="openxla", fullgraph=True, dynamic=False) + def sample_from_logits( + self, logits: torch.Tensor, + sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor: + """ + Sample with xla-friendly function. This function is to be traced + separately from `forward` for lighter compilation overhead. + """ + if sampling_metadata.all_greedy: + out_tokens = torch.argmax(logits, dim=-1, keepdim=True) + else: + out_tokens = self.sampler(logits, + sampling_metadata).sampled_token_ids + return out_tokens + + @torch.compile(backend="openxla", fullgraph=True, dynamic=False) + def gather_logprobs(self, logits: torch.Tensor, + sampled_tokens: torch.Tensor) -> LogprobsTensors: + """ + Gather the top_logprobs with corresponding tokens. Use a fixed number + of logprobs as an alternative to having multiple pre-compiled graphs. + Select the number of logprobs actually demanded by each request on CPU. + """ + logprobs = self.sampler.compute_logprobs(logits) + return self.sampler.gather_logprobs( + logprobs, + self.model_config.max_logprobs, + token_ids=sampled_tokens.squeeze(-1)) + + @torch.compile(backend="openxla", fullgraph=True, dynamic=False) + def structured_decode(self, require_struct_decoding: torch.Tensor, + grammar_bitmask: torch.Tensor, logits: torch.Tensor, + arange: torch.Tensor) -> torch.Tensor: + return torch.where( + require_struct_decoding, + self.apply_grammar_bitmask(logits, grammar_bitmask, arange), + logits) + + def apply_grammar_bitmask(self, logits: torch.Tensor, + grammar_bitmask: torch.Tensor, + arange: torch.Tensor): + assert (logits.shape[0] == grammar_bitmask.shape[0]) + logits_cloned = logits.clone() + for i in range(logits.shape[0]): + unpacked_bitmask = (torch.bitwise_right_shift( + grammar_bitmask[i][:, None], arange[None, :]) & 1) == 0 + unpacked_bitmask = unpacked_bitmask.reshape(-1)[:self.vocab_size] + logits_cloned[i] = logits_cloned[i].masked_fill( + unpacked_bitmask, -float("inf")) + return logits_cloned + + def 
get_multimodal_embeddings(self, *args, **kwargs): + return self.model.get_multimodal_embeddings(*args, **kwargs) + + def get_input_embeddings(self, *args, **kwargs): + return self.model.get_input_embeddings(*args, **kwargs) + + def prepare_structured_decoding_input( + self, logits: torch.Tensor, scheduler_output: "SchedulerOutput" + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + grammar_bitmask = scheduler_output.grammar_bitmask + assert grammar_bitmask is not None + num_reqs, _ = logits.shape + + # Reset pre-allocated tensors + self.grammar_bitmask_cpu.zero_() + self.require_structured_out_cpu.zero_() + + # We receive the structured output bitmask from the scheduler, but the + # indices of the requests in the batch may not match the indices of + # the bitmask since the scheduler doesn't know how the tpu runner is + # ordering the requests in the batch. We need to match the order of + # bitmask with the order of requests + struct_out_indices: list[int] = [] + mask_indices: list[int] = [] + for req_id in self.input_batch.req_ids: + mask_index = scheduler_output.structured_output_request_ids.get( + req_id) + if mask_index is None: + continue + batch_index = self.input_batch.req_id_to_index[req_id] + struct_out_indices.append(batch_index) + mask_indices.append(mask_index) + self.grammar_bitmask_cpu[struct_out_indices] = torch.from_numpy( + grammar_bitmask[mask_indices]) + # It's not guaranteed that all requests in this batch require + # structured output, so create a bool tensor to represent + # the requests that need structured output. 
+ struct_out_indices = torch.tensor(struct_out_indices, dtype=torch.long) + self.require_structured_out_cpu[struct_out_indices] = True + return self.require_structured_out_cpu[:num_reqs].to(logits.device), \ + self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \ + self.structured_decode_arange.to(logits.device) + + def _get_mm_dummy_batch(self, modality: str, + batch_size: int) -> BatchedTensorInputs: + # Dummy data for pre-compiling multimodal models. + dummy_request_data = self.mm_registry.get_decoder_dummy_data( + model_config=self.model_config, + seq_len=self.max_num_tokens, + ) + dummy_mm_data = dummy_request_data.multi_modal_data + + # Dummy data definition in V0 may contain multiple multimodal items + # (e.g, multiple images) for a single request, therefore here we + # always replicate first item by max_num_mm_items times since in V1 + # they are scheduled to be processed separately. + assert isinstance(dummy_mm_data, MultiModalKwargs), ( + "Expected dummy multimodal data to be of type " + f"MultiModalKwargs, got {type(dummy_mm_data)=} instead. " + "This is most likely due to the model not having a merged " + "processor.") + + # When models have a merged processor, their dummy data is + # already batched `MultiModalKwargs`, therefore we take the first + # `MultiModalKwargsItem` from the desired modality to profile on. 
+ dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0) + dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item]) + + batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] * + batch_size) + return MultiModalKwargs.as_kwargs(batched_dummy_mm_inputs, + device=self.device) + + +def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]: + logger.info("Preparing request paddings:") + # assert min_req_size is power of 2 + assert (min_req_size & (min_req_size - 1) == 0) and min_req_size > 0 + paddings: list = [] + num = max(MIN_NUM_SEQS, min_req_size) + while num <= max_req_size and (len(paddings) == 0 or paddings[-1] != num): + paddings.append(num) + logger.info(" %d", num) + num = _get_padded_num_reqs_with_upper_limit(num + 1, max_req_size) + return paddings + + +def _get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int: + res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length() + return min(res, upper_limit) + + +def _get_token_paddings(min_token_size: int, max_token_size: int, + padding_gap: int) -> list[int]: + """Generate a list of padding size, starting from min_token_size, + ending with a number that can cover max_token_size + + If padding_gap == 0 then: + increase 2X each time (exponential) + else: + first increase the size to twice, + then increase the padding size by padding_gap. 
+ """ + # assert min_token_size is power of 2 + assert (min_token_size & (min_token_size - 1) == 0) and min_token_size > 0 + paddings = [] + num = min_token_size + + if padding_gap == 0: + logger.info("Using exponential token paddings:") + while True: + logger.info(" %d", num) + paddings.append(num) + if num >= max_token_size: + break + num *= 2 + else: + logger.info("Using incremental token paddings:") + while num <= padding_gap: + logger.info(" %d", num) + paddings.append(num) + num *= 2 + num //= 2 + while num < max_token_size: + num += padding_gap + logger.info(" %d", num) + paddings.append(num) + + return paddings + + +def _get_padded_token_len(paddings: list[int], x: int) -> int: + """Return the first element in paddings list greater or equal to x. + """ + index = bisect.bisect_left(paddings, x) + assert index < len(paddings) + return paddings[index] diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 7189399efbe6..ba6139bf992f 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -1,4 +1,275 @@ # SPDX-License-Identifier: Apache-2.0 """A TPU worker class.""" +import os +from typing import Optional -from tpu_commons.worker import TPUWorker # noqa: F401 +import torch +import torch.distributed +import torch.nn as nn +import torch_xla.core.xla_model as xm +import torch_xla.debug.profiler as xp +import torch_xla.runtime as xr + +import vllm.envs as envs +from vllm.config import ParallelConfig, VllmConfig +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor import set_random_seed +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig, + KVCacheSpec) +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.utils import bind_kv_cache, 
report_usage_stats +from vllm.v1.worker.tpu_model_runner import TPUModelRunner + +logger = init_logger(__name__) + + +class TPUWorker: + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, + ): + self.is_driver_worker = is_driver_worker + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + self.parallel_config.rank = rank + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + + if self.cache_config.cache_dtype == "auto": + self.cache_dtype = self.model_config.dtype + else: + self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + self.cache_config.cache_dtype] + + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + # Delay profiler initialization to the start of the profiling. + # This is because in vLLM V1, MP runtime is initialized before the + # TPU Worker is initialized. The profiler server needs to start after + # MP runtime is initialized. + self.profiler = None + self.profile_dir = None + if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1: + # For TPU, we can only have 1 active profiler session for 1 profiler + # server. So we only profile on rank0. + self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR + logger.info("Profiling enabled. 
Traces will be saved to: %s", + self.profile_dir) + + if self.model_config.seed is None: + self.model_config.seed = 0 + + if vllm_config.lora_config is not None: + raise NotImplementedError( + "The V1 TPU backend doesn't support LoRA serving") + + def init_device(self): + os.environ["PJRT_DEVICE"] = "TPU" + # Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D + # ring, the xla tpu compiler flag + # `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to + # fix this. It will be removed after the bug in XLA compiler is fixed. + os.environ["LIBTPU_INIT_ARGS"] = ( + "--xla_tpu_force_1d_allreduce_at_chunk_count=1") + torch.set_grad_enabled(False) + torch.set_default_dtype(self.model_config.dtype) + + # Initialize the distributed environment. + init_tpu_worker_distributed_environment(self.parallel_config, + self.rank, + self.distributed_init_method, + self.local_rank) + + # Device initialization should happen after initializing + # the distributed runtime. + self.device = xm.xla_device() + self.device_config.device = self.device + + # Set random seed. + set_random_seed(self.model_config.seed) + if self.model_config.seed is not None: + xm.set_rng_state(self.model_config.seed, self.device) + + # Increase the cache size limit, which is the maximum number of + # dynamo graphs that can be compiled. + # TODO (NickLucche) On gsm we compile 80+ graphs. + # Re-evaluate limit, with MM we may get close to this limit. + torch._dynamo.config.cache_size_limit = 128 + # Use persistent cache to avoid XLA recompilation. + # NOTE(woosuk): Set per-rank cache path since different ranks + # can have slightly different XLA graphs. + world_size = self.parallel_config.world_size + rank = xr.global_ordinal() + # The PyTorch/XLA compilation cache uses the Torch IR to generate keys. + # Consequently, changes in optimization flags, which affect compilation + # results, don't change the cache key. This can result in the wrong + # compilation being used. 
To prevent this, disabling the XLA compilation
+ # cache during development is recommended. We can disable it by
+ # `export VLLM_XLA_CACHE_PATH=`
+ if envs.VLLM_XLA_CACHE_PATH:
+ per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
+ f"tp{world_size}_rank{rank}")
+ xr.initialize_cache(per_rank_path, readonly=False)
+
+ # Init ModelRunner here, so that we have access to self.device.
+ self.model_runner = TPUModelRunner(self.vllm_config, self.device)
+
+ if rank == 0:
+ # If usage stat is enabled, collect relevant info.
+ report_usage_stats(self.vllm_config)
+
+ def determine_available_memory(self) -> int:
+ kv_caches: dict[str, torch.Tensor] = {}
+ kv_cache_spec = self.model_runner.get_kv_cache_spec()
+ for layer_name, layer_spec in kv_cache_spec.items():
+ if isinstance(layer_spec, AttentionSpec):
+ dtype = layer_spec.dtype
+
+ # Use an empty tensor instead of `None` to force Dynamo to pass
+ # it by reference, rather than by specializing on the value `None`.
+ tpu_kv_cache = torch.tensor([],
+ dtype=dtype,
+ device=self.device)
+ kv_caches[layer_name] = tpu_kv_cache
+ else:
+ raise NotImplementedError(
+ f"Unsupported KV cache spec '{type(layer_spec)}'")
+
+ runner_kv_caches: list[torch.Tensor] = []
+ bind_kv_cache(
+ kv_caches,
+ self.vllm_config.compilation_config.static_forward_context,
+ runner_kv_caches)
+
+ # `max_num_tokens >= max_num_batched_tokens` due to padding.
+ self.model_runner.profile_run(self.model_runner.max_num_tokens)
+
+ # Synchronize before measuring the memory usage.
+ xm.wait_device_ops()
+
+ # During the profiling run, the model runs without KV cache. After
+ # the profiling run, the model always runs with KV cache. Here we clear
+ # the dynamo cache and cached bytecode to ensure the model always has
+ # one compiled bytecode. Having one FX graph/cached bytecode per
+ # compiled model is required for `support_torch_compile` decorator to
+ # skip dynamo guard. 
+ self.model_runner.reset_dynamo_cache() + + # Get the maximum amount of memory used by the model weights and + # intermediate activations. + m = xm.get_memory_info(self.device) + total_memory_size = m["bytes_limit"] + current_mem = m["bytes_used"] + # Ideally we would use profiled = m["peak_bytes_used"] to + # get weights + activations. But there is memory used during + # compilation / weight loading that impacts the peak and + # there is no way to reset peak memory in XLA, So we + # use the heuristic of 2% of weights. + profiled = current_mem * 1.02 + + # Calculate the TPU KV cache size based on profiling. + usable_memory_size = int(total_memory_size * + self.cache_config.gpu_memory_utilization) + tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) + + return int(tpu_kv_cache_bytes) + + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> Optional[ModelRunnerOutput]: + output = self.model_runner.execute_model(scheduler_output) + return output if self.is_driver_worker else None + + def profile(self, is_start: bool = True): + if self.rank < 1: + if self.profile_dir is None: + raise RuntimeError("Profiler is not enabled.") + if is_start: + if self.profiler is None: + self.profiler = xp.start_server(9012) + xp.start_trace(self.profile_dir) + else: + xp.stop_trace() + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_runner.add_lora(lora_request) + + def load_model(self) -> None: + self.model_runner.load_model() + + def compile_or_warm_up_model(self) -> None: + if not self.model_config.enforce_eager: + self.model_runner.capture_model() + + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. 
+ set_random_seed(self.model_config.seed) + + def get_model(self) -> nn.Module: + return self.model_runner.get_model() + + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: + return self.model_runner.get_kv_cache_spec() + + def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: + """Allocate GPU KV cache with the specified kv_cache_config.""" + self.model_runner.initialize_kv_cache(kv_cache_config) + + def check_health(self) -> None: + # worker will always be healthy as long as it's running. + return + + +def init_tpu_worker_distributed_environment( + parallel_config: ParallelConfig, + rank: int, + distributed_init_method: Optional[str] = None, + local_rank: int = -1, +) -> None: + """Initialize the distributed environment.""" + + # NOTE(woosuk): This is just to initialize the TP group and broadcast + # the input objects on CPU. The all-reduce and all-gather ops on TPU + # are invoked by `xm.all_reduce` and `xm.all_gather` which use their + # own context. + init_distributed_environment( + world_size=parallel_config.world_size, + rank=rank, + local_rank=local_rank, + distributed_init_method=distributed_init_method, + backend="gloo", + ) + ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) + + +try: + from tpu_commons.worker import TPUWorker as TPUCommonsWorker + TPUWorker = TPUCommonsWorker # type: ignore +except ImportError: + pass From 00710d9ae7285104a2829593a506097096dd46d6 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Thu, 15 May 2025 18:32:16 +0000 Subject: [PATCH 11/15] Revert "remove sampler" This reverts commit 3b4fc66e7daa3a79eb6669e2e97c359bc8c19539. 
Signed-off-by: Siyuan Liu --- vllm/v1/sample/tpu/__init__.py | 0 vllm/v1/sample/tpu/metadata.py | 123 ++++++++++++++++++++++++++++ vllm/v1/sample/tpu/sampler.py | 144 +++++++++++++++++++++++++++++++++ 3 files changed, 267 insertions(+) create mode 100644 vllm/v1/sample/tpu/__init__.py create mode 100644 vllm/v1/sample/tpu/metadata.py create mode 100644 vllm/v1/sample/tpu/sampler.py diff --git a/vllm/v1/sample/tpu/__init__.py b/vllm/v1/sample/tpu/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py new file mode 100644 index 000000000000..a1c7dcdb111f --- /dev/null +++ b/vllm/v1/sample/tpu/metadata.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field +from typing import Optional + +import torch + +from vllm.v1.worker.gpu_input_batch import InputBatch + +DEFAULT_SAMPLING_PARAMS = dict( + temperature=-1.0, + min_p=0.0, + # strictly disabled for now + top_k=0, + top_p=1.0, + # frequency_penalties=0.0, + # presence_penalties=0.0, + # repetition_penalties=0.0, +) + + +@dataclass +class TPUSupportedSamplingMetadata: + # This class exposes a more xla-friendly interface than SamplingMetadata + # on TPU, in particular all arguments should be traceable and no optionals + # are allowed, to avoid graph recompilation on Nones. + temperature: torch.Tensor = None + + min_p: torch.Tensor = None + top_k: torch.Tensor = None + top_p: torch.Tensor = None + + all_greedy: bool = True + + # Whether logprobs are to be gathered in this batch of request. To balance + # out compile time and runtime, a fixed `max_number_logprobs` value is used + # when gathering logprobs, regardless of the values specified in the batch. 
+ logprobs: bool = False + + # TODO No penalties for now + no_penalties: bool = True + prompt_token_ids = None + frequency_penalties = None + presence_penalties = None + repetition_penalties = None + # should use tensor + output_token_ids: list[list[int]] = field(default_factory=lambda: list()) + + min_tokens = None # impl is not vectorized + + logit_bias: list[Optional[dict[int, float]]] = field( + default_factory=lambda: list()) + + allowed_token_ids_mask = None + bad_words_token_ids = None + + # Generator not supported by xla + _generators: dict[int, + torch.Generator] = field(default_factory=lambda: dict()) + + @property + def generators(self) -> dict[int, torch.Generator]: + # Generator not supported by torch/xla. This field must be immutable. + return self._generators + + @classmethod + def from_input_batch( + cls, + input_batch: InputBatch, + padded_num_reqs: int, + xla_device: torch.device, + generate_params_if_all_greedy: bool = False + ) -> "TPUSupportedSamplingMetadata": + """ + Copy sampling tensors slices from `input_batch` to on device tensors. + + `InputBatch._make_sampling_metadata` causes recompilation on XLA as it + slices dynamic shapes on device tensors. This impl moves the dynamic + ops to CPU and produces tensors of fixed `padded_num_reqs` size. + + Args: + input_batch: The input batch containing sampling parameters. + padded_num_reqs: The padded number of requests. + xla_device: The XLA device. + generate_params_if_all_greedy: If True, generate sampling parameters + even if all requests are greedy. this is useful for cases where + we want to pre-compile a graph with sampling parameters, even if + they are not strictly needed for greedy decoding. 
+ """ + needs_logprobs = input_batch.max_num_logprobs>0 if \ + input_batch.max_num_logprobs else False + # Early return to avoid unnecessary cpu to tpu copy + if (input_batch.all_greedy is True + and generate_params_if_all_greedy is False): + return cls(all_greedy=True, logprobs=needs_logprobs) + + num_reqs = input_batch.num_reqs + + def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor: + # Pad value is the default one. + cpu_tensor[num_reqs:padded_num_reqs] = fill_val + + fill_slice(input_batch.temperature_cpu_tensor, + DEFAULT_SAMPLING_PARAMS["temperature"]) + fill_slice(input_batch.min_p_cpu_tensor, + DEFAULT_SAMPLING_PARAMS["min_p"]) + fill_slice(input_batch.top_k_cpu_tensor, + DEFAULT_SAMPLING_PARAMS["top_k"]) + fill_slice(input_batch.top_p_cpu_tensor, + DEFAULT_SAMPLING_PARAMS["top_p"]) + + # Slice persistent device tensors to a fixed pre-compiled padded shape. + return cls( + temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs]. + to(xla_device), + all_greedy=input_batch.all_greedy, + # TODO enable more and avoid returning None values + top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to( + xla_device), + top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to( + xla_device), + min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to( + xla_device), + logprobs=needs_logprobs) diff --git a/vllm/v1/sample/tpu/sampler.py b/vllm/v1/sample/tpu/sampler.py new file mode 100644 index 000000000000..7c31a2984b30 --- /dev/null +++ b/vllm/v1/sample/tpu/sampler.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Sampler layer implementing TPU supported operations.""" + +import torch +import torch.nn as nn + +from vllm.v1.outputs import LogprobsTensors, SamplerOutput +from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler +from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata + +_SAMPLING_EPS = 1e-5 + + +class Sampler(nn.Module): + + def __init__(self): + super().__init__() + self.topk_topp_sampler = 
TopKTopPSampler() + + def forward( + self, + logits: torch.Tensor, + sampling_metadata: TPUSupportedSamplingMetadata, + ) -> SamplerOutput: + # Use float32 for the logits. + logits = logits.to(torch.float32) + # Sample the next token. + sampled = self.sample(logits, sampling_metadata) + + # These are TPU tensors. + sampler_output = SamplerOutput( + # The sampled tokens are expanded to 2D tensor with shape + # [num_requests, 1], where each row represents one generated + # token per request. + sampled_token_ids=sampled.unsqueeze(-1), + logprobs_tensors=None) + return sampler_output + + def apply_temperature( + self, + logits: torch.Tensor, + temp: torch.Tensor, + ) -> torch.Tensor: + return logits.div_(temp.unsqueeze(dim=1)) + + def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor: + return logits.argmax(dim=-1).view(-1) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: TPUSupportedSamplingMetadata, + ) -> torch.Tensor: + greedy_sampled = self.greedy_sample(logits) + + assert sampling_metadata.temperature is not None + + # Apply temperature. + logits = self.apply_temperature(logits, sampling_metadata.temperature) + + # Apply min_p. + if sampling_metadata.min_p is not None: + logits = self.apply_min_p(logits, sampling_metadata.min_p) + + # Apply top_k and/or top_p. + random_sampled = self.topk_topp_sampler( + logits, + sampling_metadata.generators, + sampling_metadata.top_k, + sampling_metadata.top_p, + ) + + sampled = torch.where(sampling_metadata.temperature < _SAMPLING_EPS, + greedy_sampled, random_sampled) + return sampled + + def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor: + return logits.log_softmax(dim=-1, dtype=torch.float32) + + def gather_logprobs( + self, + logprobs: torch.Tensor, + num_logprobs: int, + token_ids: torch.Tensor, + ) -> LogprobsTensors: + """ + Gather logprobs for topk and sampled/prompt token. 
+
+ Args:
+ logprobs: (num tokens) x (vocab) tensor
+ num_logprobs: minimum number of logprobs to
+ retain per token
+ token_ids: prompt tokens (if prompt logprobs)
+ or sampled tokens (if sampled
+ logprobs); 1D token ID tensor
+ with (num tokens) elements
+
+ Returns:
+ Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
+ Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
+ Sampled token rank tensor, (num tokens)
+ """
+ # Find the topK values.
+ topk_logprobs, topk_indices = torch.topk(logprobs,
+ num_logprobs,
+ dim=-1)
+
+ # Get the logprob of the prompt or sampled token.
+ token_ids = token_ids.unsqueeze(-1)
+ token_logprobs = logprobs.gather(-1, token_ids)
+
+ # Compute the ranks of the actual token.
+ token_ranks = (logprobs >= token_logprobs).sum(-1)
+
+ # Concatenate together with the topk.
+ indices = torch.cat((token_ids, topk_indices), dim=1)
+ logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)
+
+ # Use int32 to reduce the tensor size.
+ indices = indices.to(torch.int32)
+
+ return LogprobsTensors(indices, logprobs, token_ranks)
+
+ def apply_min_p(
+ self,
+ logits: torch.Tensor,
+ min_p: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Filters logits using adaptive probability thresholding. 
+ """ + # Convert logits to probability distribution + probability_values = torch.nn.functional.softmax(logits, dim=-1) + # Calculate maximum probabilities per sequence + max_probabilities = torch.amax(probability_values, + dim=-1, + keepdim=True) + # Reshape min_p for broadcasting + adjusted_min_p = min_p.unsqueeze(1) * max_probabilities + # Identify valid tokens using threshold comparison + valid_token_mask = probability_values >= adjusted_min_p + # Apply mask using boolean indexing (xla friendly) + logits.masked_fill_(~valid_token_mask, -float("inf")) + return logits From 847cb3444aea5afa02d47e97024a541be6304756 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Thu, 15 May 2025 18:32:30 +0000 Subject: [PATCH 12/15] Revert "import tpu platform from tpu commons" This reverts commit 996c8b3d3ea835fedc9af4eebe433f276ff9a286. Signed-off-by: Siyuan Liu --- vllm/platforms/tpu.py | 195 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 194 insertions(+), 1 deletion(-) diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 4121b691bc49..d0a5af3587c4 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,3 +1,196 @@ # SPDX-License-Identifier: Apache-2.0 -from tpu_commons.platforms import TpuPlatform # noqa: F401 +from typing import TYPE_CHECKING, Optional, Tuple, Union, cast + +import torch +from tpu_info import device + +import vllm.envs as envs +from vllm.inputs import ProcessorInputs, PromptType +from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams, SamplingType + +from .interface import Platform, PlatformEnum, _Backend + +if TYPE_CHECKING: + from vllm.config import BlockSize, ModelConfig, VllmConfig + from vllm.pooling_params import PoolingParams +else: + BlockSize = None + ModelConfig = None + VllmConfig = None + PoolingParams = None + +logger = init_logger(__name__) + + +class TpuPlatform(Platform): + _enum = PlatformEnum.TPU + device_name: str = "tpu" + device_type: str = "tpu" + dispatch_key: str = 
"XLA" + ray_device_key: str = "TPU" + device_control_env_var: str = "TPU_VISIBLE_CHIPS" + simple_compile_backend: str = "openxla" + + supported_quantization: list[str] = ["tpu_int8", "compressed-tensors"] + + additional_env_vars: list[str] = [ + "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS" + ] + + @classmethod + def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, + dtype: torch.dtype, kv_cache_dtype: Optional[str], + block_size: int, use_v1: bool, + use_mla: bool) -> str: + if (selected_backend != _Backend.PALLAS + and selected_backend != _Backend.PALLAS_VLLM_V1): + logger.info("Cannot use %s backend on TPU.", selected_backend) + + if use_v1: + logger.info("Using Pallas V1 backend.") + return "vllm.v1.attention.backends.pallas.PallasAttentionBackend" + else: + logger.info("Using Pallas backend.") + return "vllm.attention.backends.pallas.PallasAttentionBackend" + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + chip_type, _ = device.get_local_chips() + return f"TPU {chip_type.name}" + + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + raise NotImplementedError + + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return not envs.VLLM_USE_V1 + + @classmethod + def get_punica_wrapper(cls) -> str: + return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU" + + @classmethod + def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]: + return torch.finfo(dtype).min, torch.finfo(dtype).max + + @classmethod + def can_update_inplace(cls): + return False + + @classmethod + def get_lora_vocab_padding_size(cls) -> int: + return 1 + + @classmethod + def inference_mode(cls): + return torch.no_grad() + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + from vllm.config import CompilationLevel + + cache_config = vllm_config.cache_config + # For v0, the default block size is 16. 
+ if cache_config and cache_config.block_size is None: + cache_config.block_size = cast(BlockSize, 16) + compilation_config = vllm_config.compilation_config + + # TPU only supports DYNAMO_ONCE compilation level + if compilation_config.level != CompilationLevel.DYNAMO_ONCE: + logger.info("[TPU] Forcing DYNAMO_ONCE compilation level") + compilation_config.level = CompilationLevel.DYNAMO_ONCE + + if compilation_config.backend == "": + compilation_config.backend = "openxla" + + assert vllm_config.speculative_config is None, \ + "TPU does not support speculative decoding" + + if vllm_config.model_config.dtype in (torch.float16, torch.float32): + logger.warning( + "The TPU backend currently does not support %s. " + "Using bfloat16 instead.", vllm_config.model_config.dtype) + vllm_config.model_config.dtype = torch.bfloat16 + + if envs.VLLM_USE_V1: + from vllm.v1.attention.backends.pallas import ( + PallasAttentionBackend) + cache_config.block_size = PallasAttentionBackend.get_page_size( + vllm_config) # type: ignore[assignment] + min_page_size = PallasAttentionBackend.get_min_page_size( + vllm_config) + if min_page_size > cache_config.block_size: + logger.warning( + "Increase the page size from %s to %s to make sure there's" + "no SMEM OOM", + cache_config.block_size, + min_page_size, + ) + cache_config.block_size = min_page_size # type: ignore[assignment] + + parallel_config = vllm_config.parallel_config + scheduler_config = vllm_config.scheduler_config + if parallel_config.worker_cls == "auto": + if scheduler_config.is_multi_step: + if envs.VLLM_USE_V1: + raise NotImplementedError( + "Multi-step scheduling is not supported (and not " + "needed) on vLLM V1. 
Please launch without " + "--num-scheduler-steps.") + else: + parallel_config.worker_cls = \ + "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker" + else: + if envs.VLLM_USE_V1: + parallel_config.worker_cls = \ + "vllm.v1.worker.tpu_worker.TPUWorker" + else: + parallel_config.worker_cls = \ + "vllm.worker.tpu_worker.TPUWorker" + + assert not vllm_config.speculative_config, ( + "Speculative decoding is not yet supported for TPU backend") + + if scheduler_config.is_multimodal_model and not \ + scheduler_config.disable_chunked_mm_input: + logger.warning("TPU does not support running Multimodal models"\ + " without setting `--disable_chunked_mm_input`. " \ + "Forcing --disable_chunked_mm_input.") + scheduler_config.disable_chunked_mm_input = True + + @classmethod + def is_pin_memory_available(cls): + logger.warning("Pin memory is not supported on TPU.") + return False + + @classmethod + def get_device_communicator_cls(cls) -> str: + return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator" # noqa + + @classmethod + def use_all_gather(cls) -> bool: + return True + + @classmethod + def supports_v1(cls, model_config: ModelConfig) -> bool: + # V1 support on TPU is experimental + return True + + @classmethod + def validate_request( + cls, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + processed_inputs: ProcessorInputs, + ) -> None: + """Raises if this request is unsupported on this platform""" + if isinstance(params, SamplingParams): + if params.guided_decoding is not None and not envs.VLLM_USE_V1: + raise ValueError("Structured output is not supported on " + f"{cls.device_name} V0.") + if params.sampling_type == SamplingType.RANDOM_SEED: + raise ValueError( + "Torch XLA does not support per-request seed.") From 1c7103c067d83abbc4abcef0dc8d065c452c5f98 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Thu, 15 May 2025 18:32:41 +0000 Subject: [PATCH 13/15] Revert "import tpu communicators from tpu commons" This reverts commit 
b767f6f90959e07d175f3b3070ff793688180c45. Signed-off-by: Siyuan Liu --- .../device_communicators/tpu_communicator.py | 93 ++++++++++++++++++- 1 file changed, 91 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index 4f958652b586..de66ceaeef6f 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -1,4 +1,93 @@ # SPDX-License-Identifier: Apache-2.0 -from tpu_commons.distributed.device_communicators import ( # noqa: F401 - TpuCommunicator) +import os +from typing import Optional + +import torch +from torch.distributed import ProcessGroup + +from vllm.config import get_current_vllm_config +from vllm.logger import init_logger +from vllm.platforms import current_platform + +from .base_device_communicator import DeviceCommunicatorBase + +USE_RAY = parallel_config = get_current_vllm_config( +).parallel_config.distributed_executor_backend == "ray" + +logger = init_logger(__name__) + +if current_platform.is_tpu(): + import torch_xla + import torch_xla.core.xla_model as xm + import torch_xla.runtime as xr + from torch_xla._internal import pjrt + from torch_xla.distributed.xla_multiprocessing import ( + create_optimized_replica_groups) + + if USE_RAY: + from vllm.executor import ray_utils + + +class TpuCommunicator(DeviceCommunicatorBase): + + def __init__(self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = ""): + super().__init__(cpu_group, device, device_group, unique_name) + + # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node + # must be used together. Therefore, the local rank and world size can + # be simply calculated as follows. 
+ global_rank = self.global_rank + global_world_size = self.global_world_size + + if USE_RAY: + logger.info("TpuCommunicator initialized with RAY") + # Calculate how many TPU nodes are in the current deployment. This + # is the Ray placement group if it is deployed with Ray. Default + # to the number of TPU nodes in the Ray cluster. The number of TPU + # nodes is computed by the total number of TPUs divided by the + # number of TPU accelerators per node, to account for clusters + # with both CPUs and TPUs. + num_nodes = ray_utils.get_num_tpu_nodes() + num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group() + if num_nodes_in_pg > 0: + num_nodes = num_nodes_in_pg + + local_world_size = global_world_size // num_nodes + local_rank = global_rank % local_world_size + else: + logger.info("TpuCommunicator initialized with MP") + # Sanity: Verify we run on a single host + num_hosts = torch_xla.tpu.num_tpu_workers() + assert num_hosts == 1 + + # Get the current number of TPUs (we have locally) + local_world_size = torch_xla.tpu.num_available_chips() + + # Get current rank + local_rank = global_rank % local_world_size + + # Ensure environment variables are set for multihost deployments. + # On GKE, this is needed for libtpu and TPU driver to know which TPU + # chip is actually visible. Otherwise the TPU driver will fail to + # initialize because the number of devices would be different from + # the number of visible worker addresses. + os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank) + os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank) + + pjrt.initialize_multiprocess(local_rank, local_world_size) + xr._init_world_size_ordinal() + self.groups = create_optimized_replica_groups() + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + # TODO: Remove the groups specification after XLA compiler can support + # auto-reordering the ring order for all-reduce. 
+ return xm.all_reduce(xm.REDUCE_SUM, input_, groups=self.groups) + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + assert dim == -1, "TPUs only support dim=-1 for all-gather." + return xm.all_gather(input_, dim=dim) From 20fe20de00245429295f67d4625bd7b77261ca00 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Thu, 15 May 2025 18:45:32 +0000 Subject: [PATCH 14/15] fall back for communicator and platform Signed-off-by: Siyuan Liu --- .../distributed/device_communicators/tpu_communicator.py | 9 +++++++++ vllm/platforms/tpu.py | 8 ++++++++ vllm/v1/worker/tpu_worker.py | 1 + 3 files changed, 18 insertions(+) diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index de66ceaeef6f..a1775279661d 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -91,3 +91,12 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: assert dim == -1, "TPUs only support dim=-1 for all-gather." 
return xm.all_gather(input_, dim=dim) + + +try: + from tpu_commons.distributed.device_communicators import ( + TpuCommunicator as TpuCommonsCommunicator) + TpuCommunicator = TpuCommonsCommunicator # type: ignore +except ImportError: + logger.info("tpu_commons not found, using vLLM's TpuCommunicator") + pass diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index d0a5af3587c4..d2bba9521dff 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -194,3 +194,11 @@ def validate_request( if params.sampling_type == SamplingType.RANDOM_SEED: raise ValueError( "Torch XLA does not support per-request seed.") + + +try: + from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform + TpuPlatform = TpuCommonsPlatform # type: ignore +except ImportError: + logger.info("tpu_commons not found, using vLLM's TpuPlatform") + pass diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index ba6139bf992f..fa4eb30ccd9a 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -272,4 +272,5 @@ def init_tpu_worker_distributed_environment( from tpu_commons.worker import TPUWorker as TPUCommonsWorker TPUWorker = TPUCommonsWorker # type: ignore except ImportError: + logger.info("tpu_commons not found, using vLLM's TPUWorker.") pass From 206dad4a26b897b368064e1b4382987b5b349a35 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Fri, 16 May 2025 17:14:40 +0000 Subject: [PATCH 15/15] recover test script Signed-off-by: Siyuan Liu --- examples/offline_inference/tpu.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/examples/offline_inference/tpu.py b/examples/offline_inference/tpu.py index e1d9f864c5f1..71cd88f2788a 100644 --- a/examples/offline_inference/tpu.py +++ b/examples/offline_inference/tpu.py @@ -20,13 +20,10 @@ def main(): # Set `enforce_eager=True` to avoid ahead-of-time compilation. # In real workloads, `enforace_eager` should be `False`. 
- llm = LLM( - model="Qwen/Qwen2-1.5B-Instruct", - max_num_batched_tokens=64, - max_num_seqs=4, - max_model_len=128, - enforce_eager=True, - ) + llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", + max_num_batched_tokens=64, + max_num_seqs=4, + max_model_len=128) outputs = llm.generate(prompts, sampling_params) print("-" * 50) for output, answer in zip(outputs, answers):