diff --git a/examples/offline_inference/offline_inference.py b/examples/offline_inference/offline_inference.py index 23cc6e853943..d71e3f373bd8 100644 --- a/examples/offline_inference/offline_inference.py +++ b/examples/offline_inference/offline_inference.py @@ -8,10 +8,10 @@ "The future of AI is", ] # Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) +sampling_params = SamplingParams() #temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=256, max_num_seqs=16, tensor_parallel_size=4) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) @@ -19,4 +19,4 @@ for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/llava_tpu.py b/llava_tpu.py new file mode 100644 index 000000000000..debc76dd8a68 --- /dev/null +++ b/llava_tpu.py @@ -0,0 +1,42 @@ +from vllm import LLM, SamplingParams +from vllm.assets.image import ImageAsset + +# Initialize LLaVA model +llm = LLM( + model="llava-hf/llava-1.5-7b-hf", + max_model_len=2048, + max_num_seqs=2, + dtype="bfloat16", +) + +# Load sample image +image = ImageAsset("cherry_blossom").pil_image.convert("RGB") + +# Create two different prompts +prompts = [ + "What do you see in this image?", + "What colors are most prominent in this image?", +] + +# Format prompts according to LLaVA's requirements +formatted_inputs = [ + { + "prompt": f"USER: \n{prompt}\nASSISTANT:", + "multi_modal_data": {"image": image} + } + for prompt in prompts +] + +# Set up sampling parameters +sampling_params = SamplingParams( + temperature=0.0, + max_tokens=64, +) + +# Generate responses 
+outputs = llm.generate(formatted_inputs, sampling_params=sampling_params) + +# Print results +for i, output in enumerate(outputs): + print(f"\nPrompt {i + 1}: {prompts[i]}") + print(f"Response: {output.outputs[0].text}") diff --git a/requirements-tpu.txt b/requirements-tpu.txt index 8ab18b3770ae..a94e191d3640 100644 --- a/requirements-tpu.txt +++ b/requirements-tpu.txt @@ -9,6 +9,7 @@ setuptools-scm>=8 wheel jinja2 ray[default] +ray[adag] # TODO: Remove this # Install torch_xla --pre diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index b1d4461d164a..976c8f3473d1 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -66,14 +66,21 @@ def run_test(more_args): ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" -@pytest.mark.skipif(not current_platform.is_cuda(), - reason="V1 currently only supported on CUDA") +@pytest.mark.skipif(not current_platform.is_cuda() + and not current_platform.is_tpu(), + reason="V1 currently only supported on CUDA and TPU") def test_lm_eval_accuracy_v1_engine(monkeypatch): """Run with the V1 Engine.""" with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - run_test([]) + more_args = [] + + # Limit compilation time for V1 + if current_platform.is_tpu(): + more_args = ["--max-num-seqs", "64"] + + run_test(more_args) @pytest.mark.parametrize("more_args", MORE_ARGS_LIST) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index d26383970569..fa5517fd2ce4 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -163,6 +163,13 @@ def _cached_get_attn_backend( logger.info("Using Pallas backend.") from vllm.attention.backends.pallas import PallasAttentionBackend return PallasAttentionBackend + elif backend == _Backend.PALLAS_VLLM_V1: + logger.info("Using Pallas backend.") + from vllm.v1.attention.backends.pallas import (PallasAttentionBackend + as + PallasAttentionBackendV1 + ) + return 
PallasAttentionBackendV1 elif backend == _Backend.NO_ATTENTION: from vllm.attention.backends.placeholder_attn import ( PlaceholderAttentionBackend) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index ddccaa2ce014..37079a80411d 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -32,6 +32,7 @@ class _Backend(enum.Enum): FLASHINFER = enum.auto() HPU_ATTN = enum.auto() PALLAS = enum.auto() + PALLAS_VLLM_V1 = enum.auto() IPEX = enum.auto() NO_ATTENTION = enum.auto() diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 77f5c8401424..ce66706cc55a 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -11,6 +11,8 @@ else: VllmConfig = None +import vllm.envs as envs + logger = init_logger(__name__) @@ -23,9 +25,15 @@ class TpuPlatform(Platform): @classmethod def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: - if selected_backend != _Backend.PALLAS: - logger.info("Cannot use %s backend on TPU.", selected_backend) - return _Backend.PALLAS + if envs.VLLM_USE_V1: + if selected_backend != _Backend.PALLAS_VLLM_V1: + logger.info("[V1] Cannot use %s backend on TPU.", + selected_backend) + return _Backend.PALLAS_VLLM_V1 + else: + if selected_backend != _Backend.PALLAS: + logger.info("Cannot use %s backend on TPU.", selected_backend) + return _Backend.PALLAS @classmethod def get_device_name(cls, device_id: int = 0) -> str: @@ -37,7 +45,7 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: @classmethod def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - return True + return not envs.VLLM_USE_V1 @classmethod def inference_mode(cls): @@ -52,11 +60,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config.block_size = 16 compilation_config = vllm_config.compilation_config - if compilation_config.level == CompilationLevel.NO_COMPILATION: - # TPU does not support NO_COMPILATION + + # TPU only supports DYNAMO_ONCE compilation 
level + if (compilation_config.level == CompilationLevel.NO_COMPILATION + or compilation_config.level == CompilationLevel.PIECEWISE): + logger.info("[TPU] Forcing DYNAMO_ONCE compilation level") compilation_config.level = CompilationLevel.DYNAMO_ONCE + assert compilation_config.level < CompilationLevel.PIECEWISE,\ - "TPU does not support Inductor." + ("Current compilation level = {} but needs to be less" + " than {}".format( + compilation_config.level, + CompilationLevel.PIECEWISE)) if compilation_config.backend == "": compilation_config.backend = "openxla" @@ -67,8 +82,27 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config if parallel_config.worker_cls == "auto": - if scheduler_config.is_multi_step: - parallel_config.worker_cls = \ - "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker" + if envs.VLLM_USE_V1: + parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker" else: - parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker" + if scheduler_config.is_multi_step: + parallel_config.worker_cls = \ + "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker" + else: + parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker" + + # Adjust scheduler config for V1 + # TODO: Add support for these + if envs.VLLM_USE_V1: + if vllm_config.cache_config.enable_prefix_caching: + logger.info("[V1][TPU] Disable prefix caching") + vllm_config.cache_config.enable_prefix_caching = False + + if vllm_config.scheduler_config.chunked_prefill_enabled: + logger.info("[V1][TPU] Disable chunked prefill") + vllm_config.scheduler_config.chunked_prefill_enabled = False + + @classmethod + def is_pin_memory_available(cls): + logger.warning("Pin memory is not supported on TPU.") + return False diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py new file mode 100644 index 000000000000..070c238acd63 --- /dev/null +++ 
b/vllm/v1/attention/backends/pallas.py @@ -0,0 +1,345 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +import torch_xla.experimental.custom_kernel # Required to register custom ops. + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import CommonAttentionState + + +class PallasAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "PALLAS_VLLM_V1" + + @staticmethod + def get_impl_cls() -> Type["PallasAttentionBackendImpl"]: + return PallasAttentionBackendImpl + + @staticmethod + def get_metadata_cls() -> Type["PallasMetadata"]: + return PallasMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_kv_heads, num_blocks, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + raise RuntimeError("swap_blocks is not used for the TPU backend.") + + @torch.compile(backend="openxla") + @staticmethod + def copy_blocks( + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + src_to_dists: Tuple[torch.Tensor, torch.Tensor], + ) -> None: + src_indices, dst_indices = src_to_dists + for k_cache, v_cache in kv_caches: + torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True) + k_cache[:, dst_indices] = k_cache[:, src_indices] + torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True) + v_cache[:, dst_indices] = v_cache[:, src_indices] + + +@dataclass +class PallasMetadata(AttentionMetadata): + + # Currently, input sequences can only contain all prefills + # or all decoding. 
+ block_tables: Optional[torch.Tensor] = None + context_lens: Optional[torch.Tensor] = None + effective_query_lens: Optional[torch.Tensor] = None + + @property + def prefill_metadata(self) -> Optional["PallasMetadata"]: + if self.num_prefills == 0: + return None + + assert self.num_decode_tokens == 0 + return self + + @property + def decode_metadata(self) -> Optional["PallasMetadata"]: + if self.num_decode_tokens == 0: + return None + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.block_tables is not None + assert self.context_lens is not None + return self + + +class PallasAttentionBackendImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + if head_size % 128 != 0: + raise NotImplementedError("Head size must be a multiple of 128.") + if alibi_slopes is not None: + raise NotImplementedError("Alibi slopes is not supported.") + if sliding_window is not None: + raise NotImplementedError("Sliding window is not supported.") + if kv_cache_dtype != "auto": + raise NotImplementedError("FP8 KV cache dtype is not supported.") + if blocksparse_params is not None: + raise NotImplementedError("Blocksparse is not supported.") + if logits_soft_cap is not None: + raise NotImplementedError( + "Attention logits soft-capping is not supported.") + + if torch_xla.tpu.version() < 4: + raise NotImplementedError("TPU version must be 4 or higher.") + + self.megacore_mode = 
None + tpu_env = torch_xla.tpu.get_tpu_env() + tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None) + or tpu_env.get("TYPE", None) + or tpu_env.get("TPU_ACCELERATOR_TYPE", None)) + assert tpu_type is not None + tpu_type = tpu_type.lower() + + if (("lite" not in tpu_type) and ("v6" not in tpu_type)): + if self.num_kv_heads % 2 == 0: + self.megacore_mode = "kv_head" + else: + # NOTE(woosuk): If the batch size is not a multiple of 2, the + # megacore mode will be None. + self.megacore_mode = "batch" + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionBackendImpl") + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor], + attn_metadata: PallasMetadata, + k_scale: float = 1.0, + v_scale: float = 1.0, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with Pallas attention. + + Args: + query: shape = [batch_size, seq_len, num_heads * head_size] + key: shape = [batch_size, seq_len, num_kv_heads * head_size] + value: shape = [batch_size, seq_len, num_kv_heads * head_size] + kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size] + kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size] + NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor + with shape [0] for profiling run. + attn_metadata: Metadata for attention. 
+ Returns: + shape = [batch_size, seq_len, num_heads * head_size] + """ + assert k_scale == 1.0 and v_scale == 1.0 + batch_size, seq_len, hidden_size = query.shape + query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) + value = value.view(batch_size, seq_len, self.num_kv_heads, + self.head_size) + + if kv_cache[0].numel() > 0: + slot_mapping = attn_metadata.slot_mapping + key_cache, value_cache = kv_cache + write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) + + query = query * self.scale + if attn_metadata.num_prefills > 0: + if attn_metadata.block_tables is None: + # Prefill without paged KV cache. + assert seq_len % 16 == 0, ( + "Pallas FlashAttention kernel requires seq_len to be a " + f"multiple of 16 but got {seq_len}") + + # Handle GQA/MQA. + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, + dim=-2) + key = key.view(batch_size, seq_len, self.num_heads, + self.head_size) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=-2) + value = value.view(batch_size, seq_len, self.num_heads, + self.head_size) + # FlashAttention kernel requires the input shape to be + # [batch_size, num_heads, seq_len, d_model] + # while the input is [batch_size, seq_len, num_heads, d_model]. + # Permute the input to match the required format. + output = torch.ops.xla.flash_attention( + query.permute(0, 2, 1, 3), + key.permute(0, 2, 1, 3), + value.permute(0, 2, 1, 3), + True, + ) + output = output.permute(0, 2, 1, 3) + else: + # Prefill with paged KV cache. + # TODO(woosuk): Tune the below knobs. 
+ num_kv_pages_per_compute_block = 16 + num_queries_per_compute_block = 16 + assert seq_len % num_queries_per_compute_block == 0 + output = torch.ops.xla.multi_queries_paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + attn_metadata.effective_query_lens, + num_kv_pages_per_compute_block, + num_queries_per_compute_block, + use_kernel=True, + ) + else: + # Decoding run. + assert kv_cache[0].numel() > 0 + query = query.squeeze(dim=1) + pages_per_compute_block = 16 # TODO(woosuk): Tune this value. + + assert attn_metadata.block_tables is not None + assert attn_metadata.context_lens is not None + # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire + # block table in SMEM. Therefore, if the block table is too large, + # the kernel compilation will fail. To avoid this, we split the + # batch dimension into smaller chunks and run the kernel multiple + # times. + MAX_SMEM_USAGE = 512 * 1024 + size_per_seq = 4 * attn_metadata.block_tables.shape[1] + max_num_seq = MAX_SMEM_USAGE // size_per_seq + + if batch_size <= max_num_seq: + output = paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + pages_per_compute_block, + self.megacore_mode, + ) + else: + chunk_size = max_num_seq + # Make sure the chunk size is a multiple of 2. + chunk_size = chunk_size // 2 * 2 + num_chunks = (batch_size + chunk_size - 1) // chunk_size + + output = torch.empty_like(query) + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size + chunk_end = chunk_start + chunk_size + # NOTE(woosuk): We skip this line because it causes Dynamo + # compilation error. Instead, we rely on the slice operation + # to handle the out-of-bound case. 
+ # chunk_end = min(chunk_end, batch_size) + chunk_output = paged_attention( + query[chunk_start:chunk_end], + key_cache, + value_cache, + attn_metadata.context_lens[chunk_start:chunk_end], + attn_metadata.block_tables[chunk_start:chunk_end], + pages_per_compute_block, + self.megacore_mode, + ) + output[chunk_start:chunk_end] = chunk_output + + # Reshape the output tensor. + return output.reshape(batch_size, seq_len, hidden_size) + + +def write_to_kv_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, +) -> None: + torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True) + torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True) + + key = key.flatten(0, 2) + value = value.flatten(0, 2) + key_cache = key_cache.flatten(0, 2) + value_cache = value_cache.flatten(0, 2) + key_cache.index_copy_(0, slot_mapping, key) + value_cache.index_copy_(0, slot_mapping, value) + + +def paged_attention( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + pages_per_compute_block: int, + megacore_mode: Optional[str], +) -> torch.Tensor: + batch_size = query.shape[0] + if megacore_mode == "batch" and batch_size % 2 != 0: + megacore_mode = None + else: + megacore_mode = megacore_mode + + # NOTE(woosuk): A temporary workaround to avoid the error: + # "xla::paged_attention() Expected a value of type 'str' for + # argument 'megacore_mode' but instead found type 'NoneType'." 
+ if megacore_mode is not None: + output = torch.ops.xla.paged_attention( + query, + key_cache, + value_cache, + context_lens, + block_tables, + pages_per_compute_block, + megacore_mode=megacore_mode, + ) + else: + output = torch.ops.xla.paged_attention( + query, + key_cache, + value_cache, + context_lens, + block_tables, + pages_per_compute_block, + ) + return output diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index b26716f5c02e..eff3dbdfee90 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -205,6 +205,13 @@ def schedule(self) -> "SchedulerOutput": num_computed_tokens -= self.block_size num_new_tokens = self.block_size computed_blocks.pop() + + # If chunked prefill is not enabled, then breakout of the loop + # when above budget. + if (not self.scheduler_config.chunked_prefill_enabled + and num_new_tokens > token_budget): + break + num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 975ce11fe8af..c95511862026 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -122,6 +122,7 @@ def step(self) -> List[EngineCoreOutput]: scheduler_output = self.scheduler.schedule() output = self.model_executor.execute_model(scheduler_output) + assert output is not None # TODO: Move into update_from_output below engine_core_outputs = self.scheduler.update_from_output( scheduler_output, output) return engine_core_outputs diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 5d74d4b01f50..3623904ce647 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Tuple, Type +from typing import Tuple, Type, Optional from vllm.config import VllmConfig from vllm.v1.outputs import ModelRunnerOutput @@ -41,7 +41,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: def execute_model( self, scheduler_output, - ) -> 
ModelRunnerOutput: + ) -> Optional[ModelRunnerOutput]: raise NotImplementedError @abstractmethod diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 41e6abbd6795..b25e537f3ff0 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -163,7 +163,7 @@ def collective_rpc(self, def execute_model( self, scheduler_output, - ) -> ModelRunnerOutput: + ) -> Optional[ModelRunnerOutput]: model_output = self.collective_rpc("execute_model", args=(scheduler_output, ))[0] return model_output diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py index 79acc60001c9..8af36909473c 100644 --- a/vllm/v1/executor/ray_executor.py +++ b/vllm/v1/executor/ray_executor.py @@ -7,6 +7,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils import get_distributed_init_method, get_ip, get_open_port +from vllm.platforms import current_platform from vllm.v1.executor.abstract import Executor from vllm.v1.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster, ray) @@ -27,6 +28,7 @@ def __init__(self, vllm_config: VllmConfig) -> None: self.vllm_config = vllm_config self.parallel_config = vllm_config.parallel_config self.model_config = vllm_config.model_config + self.forward_dag: Optional[ray.dag.CompiledDAG] = None # Disable Ray usage stats collection. @@ -34,6 +36,9 @@ def __init__(self, vllm_config: VllmConfig) -> None: if ray_usage != "1": os.environ["RAY_USAGE_STATS_ENABLED"] = "0" + self.device_str = "TPU" if current_platform.is_tpu() else "GPU" + self.use_dag = current_platform.is_cuda() + initialize_ray_cluster(self.parallel_config) placement_group = self.parallel_config.placement_group @@ -42,16 +47,16 @@ def __init__(self, vllm_config: VllmConfig) -> None: def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): - # A list of workers to run a model. 
- self.workers: List[RayWorkerWrapper] = [] - if self.parallel_config.ray_workers_use_nsight: + if (current_platform.is_cuda() + and self.parallel_config.ray_workers_use_nsight): ray_remote_kwargs = self._configure_ray_workers_use_nsight( ray_remote_kwargs) # Create the workers. + self.workers: List[RayWorkerWrapper] = [] driver_ip = get_ip() for bundle_id, bundle in enumerate(placement_group.bundle_specs): - if not bundle.get("GPU", 0): + if not bundle.get(self.device_str, 0): # Skip bundles that don't have GPUs, # as each worker needs one GPU. continue @@ -63,7 +68,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", worker = ray.remote( num_cpus=0, - num_gpus=1, + resources={self.device_str: 1}, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, )(RayWorkerWrapper).remote(vllm_config=self.vllm_config) @@ -279,11 +284,14 @@ def execute_model( self, scheduler_output, ) -> ModelRunnerOutput: - if self.forward_dag is None: - self.forward_dag = self._compiled_ray_dag() - # Only the first worker (with rank 0) returns the execution result. - # Others return None. 
- output = ray.get(self.forward_dag.execute(scheduler_output))[0] + if self.use_dag: + if self.forward_dag is None: + self.forward_dag = self._compiled_ray_dag() + + output = ray.get(self.forward_dag.execute(scheduler_output))[0] + else: + output = self._run_workers("execute_model", scheduler_output)[0] + return output def profile(self, is_start=True): diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py index be058318de58..1ef79108e98d 100644 --- a/vllm/v1/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -4,9 +4,16 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils import get_distributed_init_method, get_ip, get_open_port +from vllm.platforms import current_platform from vllm.v1.executor.abstract import Executor from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.worker.gpu_worker import Worker + +if current_platform.is_tpu(): + from vllm.v1.worker.tpu_worker import TPUWorker as WorkerClass # type: ignore +elif current_platform.is_cuda(): + from vllm.v1.worker.gpu_worker import GPUWorker as WorkerClass # type: ignore +else: + raise NotImplementedError("V1 executor supports CUDA or TPU") logger = init_logger(__name__) @@ -26,7 +33,8 @@ def __init__(self, vllm_config: VllmConfig) -> None: self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config - self.worker: Worker = self._create_worker() + self.worker: WorkerClass = self._create_worker() + self.worker.initialize() self.worker.load_model() @@ -34,15 +42,17 @@ def _create_worker( self, local_rank: int = 0, rank: int = 0, - distributed_init_method: Optional[str] = None) -> Worker: + distributed_init_method: Optional[str] = None) -> WorkerClass: """Return worker init args for a given rank.""" - # see https://github.com/NVIDIA/nccl/issues/1234 - os.environ['NCCL_CUMEM_ENABLE'] = '0' + if current_platform.is_cuda(): + # see 
https://github.com/NVIDIA/nccl/issues/1234 + os.environ['NCCL_CUMEM_ENABLE'] = '0' if distributed_init_method is None: distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) - return Worker( + + return WorkerClass( vllm_config=self.vllm_config, local_rank=local_rank, rank=rank, @@ -68,7 +78,7 @@ def initialize(self, num_gpu_blocks: int) -> None: def execute_model( self, scheduler_output, - ) -> ModelRunnerOutput: + ) -> Optional[ModelRunnerOutput]: output = self.worker.execute_model(scheduler_output) return output diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 40494e64b22f..2d907bcc04d6 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -69,7 +69,7 @@ def __init__( self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) - self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) + self.num_computed_tokens_cpu = np.zeros(max_num_reqs, dtype=np.int32) # Block table. 
self.block_table = BlockTable( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a1d4f9b13578..506f51016442 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,6 +1,6 @@ import gc import time -from typing import TYPE_CHECKING, Dict, List, Tuple, cast +from typing import TYPE_CHECKING, Dict, List, Tuple, cast, Optional import numpy as np import torch @@ -23,60 +23,38 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch - +from vllm.v1.worker.model_runner_base import ModelRunnerBase, ExecutionMode if TYPE_CHECKING: from vllm.v1.core.scheduler import SchedulerOutput logger = init_logger(__name__) -class GPUModelRunner: +class GPUModelRunner(ModelRunnerBase): def __init__( self, vllm_config: VllmConfig, device: torch.device, ): - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = vllm_config.parallel_config - self.scheduler_config = vllm_config.scheduler_config - self.speculative_config = vllm_config.speculative_config - self.prompt_adapter_config = vllm_config.prompt_adapter_config - self.observability_config = vllm_config.observability_config - - model_config = self.model_config - cache_config = self.cache_config - scheduler_config = self.scheduler_config - parallel_config = self.parallel_config - self.device = device - self.pin_memory = is_pin_memory_available() - self.dtype = self.model_config.dtype - if cache_config.cache_dtype == "auto": - self.kv_cache_dtype = self.dtype - else: - self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ - cache_config.cache_dtype] - - self.is_multimodal_model = model_config.is_multimodal_model - self.sliding_window = 
model_config.get_sliding_window() - self.block_size = cache_config.block_size - self.max_model_len = model_config.max_model_len - self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) - self.max_num_tokens = scheduler_config.max_num_batched_tokens - self.max_num_reqs = scheduler_config.max_num_seqs - - # Model-related. - self.num_attn_layers = model_config.get_num_layers_by_block_type( - parallel_config, LayerBlockType.attention) - self.num_query_heads = model_config.get_num_attention_heads( - parallel_config) - self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) - self.head_size = model_config.get_head_size() - self.hidden_size = model_config.get_hidden_size() + super().__init__(vllm_config, device) + + # Persistent batch. + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_blocks_per_req=self.max_num_blocks_per_req, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + ) + + # Request states. + self.requests: Dict[str, CachedRequestState] = {} + + # Lazy initialization + # self.model: nn.Module # Set after load_model + self.kv_caches: List[torch.Tensor] = [] # Multi-modal data support self.input_registry = INPUT_REGISTRY @@ -90,24 +68,9 @@ def __init__( self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens # noqa: E501 self.encoder_cache_size = self.scheduler_config.encoder_cache_size - # Lazy initialization - # self.model: nn.Module # Set after load_model - self.kv_caches: List[torch.Tensor] = [] # req_id -> (input_id -> encoder_output) self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {} - # Request states. - self.requests: Dict[str, CachedRequestState] = {} - # Persistent batch. 
- self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, - max_num_blocks_per_req=self.max_num_blocks_per_req, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=model_config.get_vocab_size(), - ) - self.use_cuda_graph = (self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager) @@ -684,22 +647,23 @@ def load_model(self) -> None: self.model_memory_usage / float(2**30)) @torch.inference_mode() - def _dummy_run( + def dummy_run( self, - model: nn.Module, - num_tokens: int, - kv_caches: List[torch.Tensor], + kv_caches, + batch_size: int, + seq_len: Optional[int] = None, + exec_mode: Optional[ExecutionMode] = None, ) -> torch.Tensor: if self.is_multimodal_model: input_ids = None - inputs_embeds = self.inputs_embeds[:num_tokens] + inputs_embeds = self.inputs_embeds[:batch_size] else: - input_ids = self.input_ids[:num_tokens] + input_ids = self.input_ids[:batch_size] inputs_embeds = None with set_forward_context(None, self.vllm_config): - hidden_states = model( + hidden_states = self.model( input_ids=input_ids, - positions=self.positions[:num_tokens], + positions=self.positions[:batch_size], kv_caches=kv_caches, attn_metadata=None, inputs_embeds=inputs_embeds, @@ -813,8 +777,7 @@ def profile_run(self) -> None: self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) # Trigger compilation for general shape. - hidden_states = self._dummy_run(self.model, self.max_num_tokens, - dummy_kv_caches) + hidden_states = self.dummy_run(dummy_kv_caches, self.max_num_tokens) logits = self.model.compute_logits(hidden_states, None) logits = logits[:self.max_num_tokens] # TODO(woosuk): Consider the memory usage of the sampler. @@ -840,8 +803,8 @@ def capture_model(self) -> None: for num_tokens in reversed(self.cudagraph_batch_sizes): for _ in range(self.vllm_config.compilation_config. 
cudagraph_num_of_warmups): - self._dummy_run(self.model, num_tokens, self.kv_caches) - self._dummy_run(self.model, num_tokens, self.kv_caches) + self.dummy_run(self.kv_caches, num_tokens) + self.dummy_run(self.kv_caches, num_tokens) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index af438f7d5820..1b01db4eb461 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -1,31 +1,27 @@ """A GPU worker class.""" import gc import os -from typing import TYPE_CHECKING, Optional, Tuple +from typing import Optional, Tuple import torch import torch.distributed -import vllm.envs as envs -from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.config import ParallelConfig, VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.platforms import current_platform -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size + from vllm.v1.core.scheduler import SchedulerOutput from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from vllm.v1.worker.worker_base import WorkerBase, get_cache_block_size, check_if_gpu_supports_dtype logger = init_logger(__name__) -if TYPE_CHECKING: - from vllm.v1.core.scheduler import SchedulerOutput - -class Worker: +class GPUWorker(WorkerBase): def __init__( self, @@ -34,82 +30,41 @@ def __init__( rank: int, distributed_init_method: str, ): - - # TODO: use WorkerBase.__init__(self, vllm_config=vllm_config) - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = 
vllm_config.parallel_config - self.scheduler_config = vllm_config.scheduler_config - self.device_config = vllm_config.device_config - self.speculative_config = vllm_config.speculative_config - self.prompt_adapter_config = vllm_config.prompt_adapter_config - self.observability_config = vllm_config.observability_config - - self.parallel_config.rank = rank - self.local_rank = local_rank - self.rank = rank - self.distributed_init_method = distributed_init_method - - if self.model_config.trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules - init_cached_hf_modules() - - # Torch profiler. Enabled and configured through env vars: - # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace - if envs.VLLM_TORCH_PROFILER_DIR: - torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR - logger.info("Profiling enabled. Traces will be saved to: %s", - torch_profiler_trace_dir) - self.profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - with_stack=True, - on_trace_ready=torch.profiler.tensorboard_trace_handler( - torch_profiler_trace_dir, use_gzip=True)) - else: - self.profiler = None + super().__init__(vllm_config, local_rank, rank, + distributed_init_method) def initialize(self): - if self.device_config.device.type == "cuda": - # torch.distributed.all_reduce does not free the input tensor until - # the synchronization point. This causes the memory usage to grow - # as the number of all_reduce calls increases. This env var disables - # this behavior. - # Related issue: - # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - - # This env var set by Ray causes exceptions with graph building. 
- os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) - self.device = torch.device(f"cuda:{self.local_rank}") - torch.cuda.set_device(self.device) + assert self.device_config.device.type == "cuda" + + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # This env var set by Ray causes exceptions with graph building. + os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) + self.device = torch.device(f"cuda:{self.local_rank}") + torch.cuda.set_device(self.device) + + check_if_gpu_supports_dtype(self.model_config.dtype) + gc.collect() + torch.cuda.empty_cache() + self.init_gpu_memory = torch.cuda.mem_get_info()[0] - _check_if_gpu_supports_dtype(self.model_config.dtype) - gc.collect() - torch.cuda.empty_cache() - self.init_gpu_memory = torch.cuda.mem_get_info()[0] - else: - raise RuntimeError( - f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. - init_worker_distributed_environment(self.parallel_config, self.rank, - self.distributed_init_method, - self.local_rank) + init_cuda_worker_distributed_environment(self.parallel_config, + self.rank, + self.distributed_init_method, + self.local_rank) # Set random seed. 
set_random_seed(self.model_config.seed) # Construct the model runner self.model_runner = GPUModelRunner(self.vllm_config, self.device) - def load_model(self) -> None: - self.model_runner.load_model() - @torch.inference_mode() def determine_num_available_blocks(self) -> Tuple[int, int]: """Profiles the peak memory usage of the model to determine how many @@ -123,6 +78,8 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: You may limit the usage of GPU memory by adjusting the `gpu_memory_utilization` parameter. """ + assert self.model_runner is not None + # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. torch.cuda.empty_cache() @@ -163,67 +120,31 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: # Calculate the number of blocks that can be allocated with the # profiled peak memory. - cache_block_size = _get_cache_block_size(self.cache_config, - self.model_config, - self.parallel_config) + cache_block_size = get_cache_block_size(self.cache_config, + self.model_config, + self.parallel_config) num_gpu_blocks = int(available_kv_cache_memory // cache_block_size) num_gpu_blocks = max(num_gpu_blocks, 0) return num_gpu_blocks, 0 - def initialize_cache(self, num_gpu_blocks: int) -> None: - """Allocate GPU and CPU KV cache with the specified number of blocks.""" - if num_gpu_blocks <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") - - max_seq_len = self.cache_config.block_size * num_gpu_blocks - max_model_len = self.model_config.max_model_len - if max_model_len > max_seq_len: - raise ValueError( - f"The model's max seq len ({max_model_len}) " - "is larger than the maximum number of tokens that can be " - f"stored in KV cache ({max_seq_len}). 
Try increasing " - "`gpu_memory_utilization` or decreasing `max_model_len` when " - "initializing the engine.") - - self.model_runner.initialize_kv_cache(num_gpu_blocks) - - def compile_or_warm_up_model(self) -> None: - if not self.model_config.enforce_eager: - self.model_runner.capture_model() - # Reset the seed to ensure that the random state is not affected by - # the model initialization and profiling. - set_random_seed(self.model_config.seed) - @torch.inference_mode() def execute_model( self, scheduler_output: "SchedulerOutput", - ) -> ModelRunnerOutput: + ) -> Optional[ModelRunnerOutput]: + assert self.model_runner is not None output = self.model_runner.execute_model(scheduler_output) return output if self.rank == 0 else None - def profile(self, is_start: bool = True): - if self.profiler is None: - raise RuntimeError("Profiler is not enabled.") - if is_start: - self.profiler.start() - else: - self.profiler.stop() - - def check_health(self) -> None: - # worker will always be healthy as long as it's running. - return - -def init_worker_distributed_environment( +def init_cuda_worker_distributed_environment( parallel_config: ParallelConfig, rank: int, distributed_init_method: Optional[str] = None, local_rank: int = -1, ) -> None: """Initialize the distributed environment.""" + set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) init_distributed_environment(parallel_config.world_size, rank, @@ -231,44 +152,3 @@ def init_worker_distributed_environment( ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) - - -def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): - # Check if the GPU supports the dtype. 
- if torch_dtype == torch.bfloat16: # noqa: SIM102 - if not current_platform.has_device_capability(80): - capability = current_platform.get_device_capability() - gpu_name = current_platform.get_device_name() - - if capability is None: - compute_str = "does not have a compute capability" - else: - version_str = capability.as_version_str() - compute_str = f"has compute capability {version_str}" - - raise ValueError( - "Bfloat16 is only supported on GPUs with compute capability " - f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " - "You can use float16 instead by explicitly setting the" - "`dtype` flag in CLI, for example: --dtype=half.") - - -def _get_cache_block_size( - cache_config: CacheConfig, - model_config: ModelConfig, - parallel_config: ParallelConfig, -) -> int: - head_size = model_config.get_head_size() - num_heads = model_config.get_num_kv_heads(parallel_config) - num_attention_layers = model_config.get_num_layers_by_block_type( - parallel_config, LayerBlockType.attention) - - key_cache_block = cache_config.block_size * num_heads * head_size - value_cache_block = key_cache_block - total = num_attention_layers * (key_cache_block + value_cache_block) - if cache_config.cache_dtype == "auto": - dtype = model_config.dtype - else: - dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] - dtype_size = get_dtype_size(dtype) - return dtype_size * total diff --git a/vllm/v1/worker/model_runner_base.py b/vllm/v1/worker/model_runner_base.py new file mode 100644 index 000000000000..3682f6ccf885 --- /dev/null +++ b/vllm/v1/worker/model_runner_base.py @@ -0,0 +1,104 @@ +from typing import TYPE_CHECKING, List, Optional + +import enum + +import torch +import torch.distributed +import torch.nn as nn + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv, + is_pin_memory_available) +from vllm.v1.outputs import ModelRunnerOutput + +if TYPE_CHECKING: + from 
vllm.v1.core.scheduler import SchedulerOutput + +logger = init_logger(__name__) + + +class ExecutionMode(enum.Enum): + PREFILL = enum.auto() + DECODE = enum.auto() + PREFIX_PREFILL = enum.auto() + + def is_prefill(self) -> bool: + return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL) + + +class ModelRunnerBase: + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + self.device_config = vllm_config.device_config + + model_config = self.model_config + cache_config = self.cache_config + scheduler_config = self.scheduler_config + parallel_config = self.parallel_config + self.device = device + self.pin_memory = is_pin_memory_available() + self.dtype = self.model_config.dtype + if cache_config.cache_dtype == "auto": + self.kv_cache_dtype = self.dtype + else: + self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + cache_config.cache_dtype] + + self.is_multimodal_model = model_config.is_multimodal_model + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.max_model_len = model_config.max_model_len + self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) + self.max_num_tokens = scheduler_config.max_num_batched_tokens + self.max_num_reqs = scheduler_config.max_num_seqs + + # Model-related. 
+ self.num_attn_layers = model_config.get_num_layers_by_block_type( + parallel_config, LayerBlockType.attention) + self.num_query_heads = model_config.get_num_attention_heads( + parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + self.head_size = model_config.get_head_size() + self.hidden_size = model_config.get_hidden_size() + + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> ModelRunnerOutput: + raise NotImplementedError() + + def load_model(self) -> None: + raise NotImplementedError() + + def dummy_run( + self, + kv_caches, + batch_size: int, + seq_len: Optional[int] = None, + exec_mode: Optional[ExecutionMode] = None, + ) -> torch.Tensor: + raise NotImplementedError() + + def profile_run(self) -> None: + raise NotImplementedError() + + def capture_model(self) -> None: + raise NotImplementedError() + + def initialize_kv_cache(self, num_blocks: int) -> None: + raise NotImplementedError() diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py new file mode 100644 index 000000000000..909e39c88f85 --- /dev/null +++ b/vllm/v1/worker/tpu_model_runner.py @@ -0,0 +1,896 @@ +import gc +import time +import enum +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Tuple, cast, Optional +from unittest.mock import patch + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn + +# TPU XLA related +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + +from vllm.attention import AttentionMetadata +from vllm.config import VllmConfig +from vllm.inputs import INPUT_REGISTRY +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.sampling_params import SamplingType +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv, + is_pin_memory_available) +from vllm.v1.attention.backends.pallas 
import PallasMetadata, PallasAttentionBackend +from vllm.v1.engine.mm_input_mapper import MMInputMapperClient +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.worker.model_runner_base import ModelRunnerBase, ExecutionMode +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch + +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput + +logger = init_logger(__name__) + +# Here we utilize the behavior that out-of-bound index is ignored. +# FIXME(woosuk): Find a more reliable way to prevent possible bugs. +_PAD_SLOT_ID = 1_000_000_000 +# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow. +_ENABLE_TOP_P = False +# FIXME(woosuk): A temporary hack to support `n > 1`. +# This can significantly affect the performance if too large. +_MAX_NUM_SAMPLES = 128 + + +@dataclass +class PrefillInputData: + + request_ids: List + prompt_lens: List + token_ids: List + position_ids: List + attn_metadata: List + + def zipped(self): + return zip(self.request_ids, self.prompt_lens, self.token_ids, + self.position_ids, self.attn_metadata) + + +@dataclass +class DecodeInputData: + + num_decodes: int + token_ids: Optional[torch.Tensor] = None + position_ids: Optional[torch.Tensor] = None + attn_metadata: Optional[PallasMetadata] = None + + +class TPUModelRunner(ModelRunnerBase): + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(vllm_config, device) + + # Persistent batch. + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_blocks_per_req=self.max_num_blocks_per_req, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + ) + + # Request states. 
+ self.requests: Dict[str, CachedRequestState] = {} + + # Lazy initialization + # self.model: nn.Module # Set after load_model + self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = [] + + # Multi-modal data support + self.input_registry = INPUT_REGISTRY + self.mm_registry = MULTIMODAL_REGISTRY + + # NOTE: Initialized input mapper is only used for processing dummy + # multimodal data into multimodal kwargs for GPU memory profiling. + self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config) + self.mm_input_mapper_profiling.use_cache = False + + self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens # noqa: E501 + self.encoder_cache_size = self.scheduler_config.encoder_cache_size + + # req_id -> (input_id -> encoder_output) + self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {} + + # Used to initialize positions for the individual prefills + self.prefill_positions = torch.tensor(range(self.max_model_len), + device="cpu", + dtype=torch.int32).reshape( + 1, -1) + + # Used to indicate how many prefills there are for each scheduler + # iteration + self.num_new_reqs: int = 0 + + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + # Remove stopped requests from the cached states. + # Keep the states of the pre-empted requests. + for req_id in scheduler_output.finished_req_ids: + self.requests.pop(req_id, None) + self.encoder_cache.pop(req_id, None) + + # Free the cached encoder outputs. + for req_id, input_id in scheduler_output.free_encoder_input_ids: + encoder_outputs = self.encoder_cache.get(req_id) + if encoder_outputs is not None: + encoder_outputs.pop(input_id, None) + if not encoder_outputs: + self.encoder_cache.pop(req_id, None) + + # Remove the requests from the persistent batch. 
+ stopped_req_ids = set().union( + scheduler_output.preempted_req_ids, + scheduler_output.finished_req_ids, + ) + removed_req_indices: List[int] = [] + for req_id in stopped_req_ids: + req_index = self.input_batch.remove_request(req_id) + if req_index is not None: + removed_req_indices.append(req_index) + + # Update the states of the running requests. + for req_data in scheduler_output.scheduled_running_reqs: + req_id = req_data.req_id + req_state = self.requests[req_id] + req_index = self.input_batch.req_id_to_index[req_id] + + # Update the num_computed_tokens. + req_state.num_computed_tokens = req_data.num_computed_tokens + self.input_batch.num_computed_tokens_cpu[req_index] = ( + req_data.num_computed_tokens) + + # Update the block table. + num_new_blocks = len(req_data.new_block_ids) + if num_new_blocks == 0: + continue + start_index = len(req_state.block_ids) + req_state.block_ids.extend(req_data.new_block_ids) + self.input_batch.block_table.append_row(req_index, start_index, + req_data.new_block_ids) + + req_ids_to_add: List[str] = [] + # Add new requests to the cached states. + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + sampling_params = new_req_data.sampling_params + if sampling_params.sampling_type == SamplingType.RANDOM_SEED: + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + + self.requests[req_id] = CachedRequestState( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + prompt=new_req_data.prompt, + mm_inputs=new_req_data.mm_inputs, + mm_positions=new_req_data.mm_positions, + sampling_params=sampling_params, + generator=generator, + block_ids=new_req_data.block_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + output_token_ids=[], + ) + req_ids_to_add.append(req_id) + + # Update the cached states of the resumed requests. 
+ for res_req_data in scheduler_output.scheduled_resumed_reqs: + req_id = res_req_data.req_id + req_state = self.requests[req_id] + + req_state.block_ids = res_req_data.block_ids + req_state.num_computed_tokens = res_req_data.num_computed_tokens + req_ids_to_add.append(req_id) + + # For TPU, we keep all of the decode requests before the + # prefill requests in the batch sequence. + # 1. First condense, so all decodes move to start + # 2. Then add new prefills to the end of the batch + removed_req_indices = sorted(removed_req_indices, reverse=True) + if removed_req_indices: + self.input_batch.condense(removed_req_indices) + + for req_id in req_ids_to_add: + req_state = self.requests[req_id] + self.input_batch.add_request(req_state, None) # Append last + + self.num_new_reqs = len(req_ids_to_add) + + def _prepare_prefill_inputs( + self, + num_scheduled_tokens: List[int], + ) -> PrefillInputData: + # Each prefill run separately with shape [1, padded_prompt_len]. + # So we create lists that will be used in execute_model(). + + prefill_request_ids = [] + prefill_prompt_lens = [] + prefill_token_ids = [] + prefill_position_ids = [] + prefill_attn_metadata = [] + + # DECODES are the first num_decodes REQUESTS. + # PREFILLS are the next num_reqs - num_decodes REQUESTS. + num_reqs = self.input_batch.num_reqs + num_decodes = num_reqs - self.num_new_reqs + for idx in range(num_decodes, num_reqs): + req_id = self.input_batch.req_ids[idx] + prefill_request_ids.append(req_id) + + prompt_len = num_scheduled_tokens[idx] + prefill_prompt_lens.append(prompt_len) + + # STATIC SHAPE: prefills are padded to the next power of 2. + padded_prompt_len = _get_padded_prefill_len(prompt_len) + assert padded_prompt_len <= self.max_model_len + + # TOKEN_IDS. + token_ids = torch.from_numpy(self.input_batch.token_ids_cpu[ + idx, :padded_prompt_len].reshape(1, -1)) + token_ids[:, prompt_len:] = 0 + prefill_token_ids.append(token_ids.to(self.device)) + + # POSITIONS. 
+ positions = self.prefill_positions[:, :padded_prompt_len].clone() + positions[:, prompt_len:] = 0 + prefill_position_ids.append(positions.to(self.device)) + + # SLOT_MAPPING. + # The "slot" is the "physical index" of a token in the KV cache. + # Look up the block_idx in the block table (logical<>physical map) + # to compute this. + block_table_cpu_tensor = self.input_batch.block_table.get_cpu_tensor( + ) + block_numbers = block_table_cpu_tensor[idx, positions // + self.block_size].reshape( + 1, -1) + + block_offsets = positions % self.block_size + slot_mapping = block_numbers * self.block_size + block_offsets + # Set an out of range value for the padding tokens so that they + # are ignored when inserting into the KV cache. + slot_mapping[:, prompt_len:] = _PAD_SLOT_ID + slot_mapping = slot_mapping.long() + + # TODO: Remove prompt_len param here + prefill_attn_metadata.append( + PallasMetadata( + num_prefills=1, + num_prefill_tokens=prompt_len, # NOTE: This is not used. + num_decode_tokens=0, + slot_mapping=slot_mapping.to(self.device), + multi_modal_placeholder_index_maps=None, + block_tables=None, + context_lens=None, + effective_query_lens=None, + )) + + return PrefillInputData( + request_ids=prefill_request_ids, + prompt_lens=prefill_prompt_lens, + token_ids=prefill_token_ids, + position_ids=prefill_position_ids, + attn_metadata=prefill_attn_metadata, + ) + + def _prepare_decode_inputs(self) -> DecodeInputData: + # Decodes run as one single padded batch with shape [batch, 1] + # + # We need to set _PAD_SLOT_ID for the padding tokens in the + # slot_mapping, such that the attention KV cache insertion + # logic knows to ignore those indicies. Otherwise, the + # padding data can be dummy since we have a causal mask. + + # DECODES are the first num_decodes REQUESTS. + # PREFILLS are the next num_reqs - num_decodes REQUESTS. 
+ num_reqs = self.input_batch.num_reqs + num_decodes = num_reqs - self.num_new_reqs + + if num_decodes == 0: + return DecodeInputData(num_decodes=0) + + # PAD FOR STATIC SHAPES. + padded_batch_size = _get_padded_batch_size(num_decodes) + + # POSITIONS. [batch, 1] + # We slice at the end, since we use the positions for gathering. + positions = torch.from_numpy( + self.input_batch.num_computed_tokens_cpu.reshape(-1, 1)) + index = positions.to(torch.int64) + index[num_decodes:] = 0 + positions = positions[:padded_batch_size] + positions[num_decodes:] = 0 + + # TOKEN_IDS. [batch, 1] + token_ids = torch.gather( + input=torch.from_numpy(self.input_batch.token_ids_cpu), + dim=1, + index=index, + )[:padded_batch_size].to(torch.int32) + token_ids[num_decodes:] = 0 + + # SLOT_MAPPING [batch, 1] + # The "slot" is the "physical index" of a token in the KV cache. + # Look up the block_idx in the block table (logical<>physical map) + # to compute this. + block_table_cpu_tensor = self.input_batch.block_table.get_cpu_tensor() + block_number = torch.gather(input=block_table_cpu_tensor, + dim=1, + index=(index // self.block_size)) + block_offsets = index % self.block_size + slot_mapping = block_number * self.block_size + block_offsets + # Set an out of range value for the padding tokens so that they + # are ignored when inserting into the KV cache. + slot_mapping[num_decodes:] = _PAD_SLOT_ID + slot_mapping = slot_mapping[:padded_batch_size] + slot_mapping = slot_mapping.long() + + # BLOCK_TABLE [batch, max_num_blocks_per_req] + block_table = block_table_cpu_tensor[:padded_batch_size] + + # CONTEXT_LENS [batch_size] + context_lens = (positions.reshape(-1) + 1) + context_lens[num_decodes:] = 0 + + # CPU<>TPU sync happens here. 
+ return DecodeInputData(num_decodes=num_decodes, + token_ids=token_ids.to(self.device), + position_ids=positions.to(self.device), + attn_metadata=PallasMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=padded_batch_size, + slot_mapping=slot_mapping.to(self.device), + multi_modal_placeholder_index_maps=None, + block_tables=block_table.to(self.device), + context_lens=context_lens.to(self.device), + effective_query_lens=None, + )) + + def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = self.input_batch.num_reqs + assert num_reqs > 0 + + num_decodes = num_reqs - self.num_new_reqs + + # TODO: Resurrect + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. + # TODO: Verify this works with TPUs + # self.input_batch.block_table.commit(num_reqs) + + # Get the number of scheduled tokens for each request. + # TODO: The Python loop can be slow. Optimize. + num_scheduled_tokens = [] + max_num_scheduled_tokens = 0 + for idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + assert req_id is not None + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_scheduled_tokens.append(num_tokens) + max_num_scheduled_tokens = max(max_num_scheduled_tokens, + num_tokens) + + # NOTE: Assert that all the decodes are "decodes". 
+ if idx < num_decodes: + assert num_tokens == 1 + + assert max_num_scheduled_tokens > 0 + + return ( + self._prepare_prefill_inputs(num_scheduled_tokens), + self._prepare_decode_inputs(), + ) + + def _prepare_sampling( + self, + scheduler_output: "SchedulerOutput", + ) -> SamplingMetadata: + # TODO: Add sampler code to TPUs + raise NotImplementedError() + + def _execute_encoder(self, scheduler_output: "SchedulerOutput"): + scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs + if not scheduled_encoder_inputs: + return + + # Batch the multi-modal inputs. + mm_inputs: List[MultiModalKwargs] = [] + req_input_ids: List[Tuple[str, int]] = [] + for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): + req_state = self.requests[req_id] + for input_id in encoder_input_ids: + mm_inputs.append(req_state.mm_inputs[input_id]) + req_input_ids.append((req_id, input_id)) + batched_mm_inputs = MultiModalKwargs.batch(mm_inputs) + batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs, + device=self.device) + + # Run the encoder. + # `encoder_outputs` is either of the following: + # 1. A tensor of shape [num_images, feature_size, hidden_size] + # in case when feature_size is fixed across all images. + # 2. A list (length: num_images) of tensors, each of shape + # [feature_size, hidden_size] in case when the feature size is + # dynamic depending on input images. + encoder_outputs = self.model.get_multimodal_embeddings( + **batched_mm_inputs) + + # Cache the encoder outputs. 
+ for (req_id, input_id), output in zip(req_input_ids, encoder_outputs): + if req_id not in self.encoder_cache: + self.encoder_cache[req_id] = {} + self.encoder_cache[req_id][input_id] = output + + def _gather_encoder_outputs( + self, + scheduler_output: "SchedulerOutput", + ) -> List[torch.Tensor]: + encoder_outputs: List[torch.Tensor] = [] + num_reqs = self.input_batch.num_reqs + for req_id in self.input_batch.req_ids[:num_reqs]: + assert req_id is not None + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] + req_state = self.requests[req_id] + num_computed_tokens = req_state.num_computed_tokens + mm_positions = req_state.mm_positions + for i, pos_info in enumerate(mm_positions): + start_pos = pos_info["offset"] + num_encoder_tokens = pos_info["length"] + + # The encoder output is needed if the two ranges overlap: + # [num_computed_tokens, + # num_computed_tokens + num_scheduled_tokens) and + # [start_pos, start_pos + num_encoder_tokens) + if start_pos >= num_computed_tokens + num_scheduled_tokens: + # The encoder output is not needed in this step. + break + if start_pos + num_encoder_tokens <= num_computed_tokens: + # The encoder output is already processed and stored + # in the decoder's KV cache. + continue + + start_idx = max(num_computed_tokens - start_pos, 0) + end_idx = min( + num_computed_tokens - start_pos + num_scheduled_tokens, + num_encoder_tokens) + assert start_idx < end_idx + assert req_id in self.encoder_cache + assert i in self.encoder_cache[req_id] + encoder_output = self.encoder_cache[req_id][i] + encoder_outputs.append(encoder_output[start_idx:end_idx]) + return encoder_outputs + + @torch.no_grad() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> ModelRunnerOutput: + self._update_states(scheduler_output) + + if self.is_multimodal_model: + # Run the multimodal encoder if any. 
+ self._execute_encoder(scheduler_output) + encoder_outputs = self._gather_encoder_outputs(scheduler_output) + else: + encoder_outputs = [] + + # Prepare the decoder inputs. + prefill_data, decode_data = self._prepare_inputs(scheduler_output) + + num_reqs = self.input_batch.num_reqs + + ######################### DECODES ######################### + # Decodes run as one single batch with [padded_batch, 1] + sampled_token_ids_list = [] + if decode_data.num_decodes > 0: + # FORWARD. + selected_token_ids = self.model(decode_data.token_ids, + decode_data.position_ids, + decode_data.attn_metadata, + self.kv_caches) + + # NOTE: TPU<>CPU sync happens here. + # We need to call .cpu() first to avoid recompilation. + token_ids = selected_token_ids.cpu()[:decode_data.num_decodes] + sampled_token_ids_list.extend(token_ids.tolist()) + + # UPDATE REQUEST STATE. + for i, req_id in enumerate( + self.input_batch.req_ids[:decode_data.num_decodes]): + assert req_id is not None + req_state = self.requests[req_id] + + # TODO: ASSERT NO CHUNKED PREFILL. + assert scheduler_output.num_scheduled_tokens[req_id] == 1 + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + assert seq_len == req_state.num_tokens + + # TODO: Verify if req_id_to_index mapping is needed here! + token_id = sampled_token_ids_list[i] + self.input_batch.token_ids_cpu[i, seq_len] = token_id + self.input_batch.num_tokens[i] += 1 + req_state.output_token_ids.append(token_id) + + ######################### PREFILLS ######################### + # Prefills run separately with shape [1, padded_prefill_len], + # due to lack of variable length attention kernel so far. 
+ for (req_id, prompt_len, token_ids, position_ids, + attn_metadata) in prefill_data.zipped(): + assert req_id is not None + + if self.is_multimodal_model: + # NOTE(woosuk): To unify token ids and soft tokens (vision + # embeddings), we always use embeddings (rather than token ids) + # as input to the multimodal model, even when the input is text. + input_ids = token_ids + if encoder_outputs: + inputs_embeds = self.model.get_input_embeddings( + input_ids, encoder_outputs) + else: + inputs_embeds = self.model.get_input_embeddings(input_ids) + # TODO(woosuk): Avoid the copy. Optimize. + # self.inputs_embeds.copy_(inputs_embeds) + # inputs_embeds = self.inputs_embeds + input_ids = None + else: + # For text-only models, we use token ids as input. + # While it is possible to use embeddings as input just like the + # multimodal models, it is not desirable for performance since + # then the embedding layer is not included in the CUDA graph. + input_ids = token_ids + inputs_embeds = None + + # FORWARD. + selected_token_ids = self.model(input_ids, position_ids, + attn_metadata, self.kv_caches, inputs_embeds) + + print("PREFILL selected_token_ids.shape = {}".format( + selected_token_ids.shape)) + # NOTE: TPU<>CPU sync happens here. + # We need to call .cpu() first to avoid recompilation. + token_id = selected_token_ids.cpu()[prompt_len - 1].item() + sampled_token_ids_list.append(token_id) + req_state = self.requests[req_id] + + # TODO: ASSERT NO PREFIX CACHING. + assert req_state.num_computed_tokens == 0 + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + + # TODO: ASSERT NO CHUNKED PREFILL. + assert seq_len == req_state.num_tokens + assert prompt_len == seq_len + + # UPDATE REQUEST STATE. 
+ req_idx = self.input_batch.req_id_to_index[req_id] + self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id + self.input_batch.num_tokens[req_idx] += 1 + req_state.output_token_ids.append(token_id) + + # num_reqs entries should be non-None + assert all( + req_id is not None for req_id in + self.input_batch.req_ids[:num_reqs]), "req_ids contains None" + req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs]) + + model_runner_output = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=sampled_token_ids_list, + logprob_token_ids_cpu=None, + logprobs_cpu=None, + ) + + return model_runner_output + + def load_model(self) -> None: + self.device = self.device_config.device + + # NOTE(woosuk): While the executor assigns the TP ranks to the worker + # process, the ranks can be different from the ranks internally assigned + # by the xm runtime. Therefore, there is a mismatch in the rank + # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. + # This is not a problem in linear layers because all-reduce is + # rank-agnostic. However, it matters for all-gather as the ranks + # determine the order of concatenating the output tensors. + # As a workaround, we use the xm's rank assignment only when loading + # the embedding weights. + xm_tp_rank = xr.global_ordinal() + with patch( + "vllm.model_executor.layers.vocab_parallel_embedding." 
+ "get_tensor_model_parallel_rank", + return_value=xm_tp_rank): + model = get_model(vllm_config=self.vllm_config) + model = model.eval() + xm.wait_device_ops() + model = ModelWrapperV1(model) + self.model = torch.compile(model, + backend="openxla", + fullgraph=True, + dynamic=False) + + def dummy_run( + self, + kv_caches, + batch_size: int, + seq_len: Optional[int] = None, + exec_mode: Optional[ExecutionMode] = None, + ) -> None: + assert seq_len is not None + assert exec_mode is not None + + exec_mode = ExecutionMode(exec_mode) + if exec_mode.is_prefill(): + seq_len = (seq_len + 15) // 16 * 16 + token_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros((batch_size, seq_len), + dtype=torch.int64, + device=self.device) + if exec_mode == ExecutionMode.PREFILL: + attn_metadata = PallasMetadata( + num_prefills=batch_size, + num_prefill_tokens=batch_size * seq_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + block_tables=None, + context_lens=None, + effective_query_lens=None, + ) + + else: + context_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + + block_tables = torch.zeros( + (batch_size, self.max_num_blocks_per_req), + dtype=torch.int32, + device=self.device) + + effective_query_lens = torch.ones_like(context_lens) + + attn_metadata = PallasMetadata( + num_prefills=batch_size, + num_prefill_tokens=batch_size * seq_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + block_tables=block_tables, + context_lens=context_lens, + effective_query_lens=effective_query_lens, + ) + inputs_embeds = None + else: + assert seq_len == 1 + token_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, 
+ device=self.device) + slot_mapping = torch.zeros((batch_size, seq_len), + dtype=torch.int64, + device=self.device) + block_tables = torch.zeros( + (batch_size, self.max_num_blocks_per_req), + dtype=torch.int32, + device=self.device) + context_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + attn_metadata = PallasMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size * seq_len, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + block_tables=block_tables, + context_lens=context_lens, + ) + inputs_embeds = None + + # NOTE(woosuk): There are two stages of compilation: torch.compile and + # XLA compilation. Using `mark_dynamic` can reduce the torch.compile + # overhead by reusing the FX graph for different shapes. + # However, the XLA graph will still require static shapes and needs to + # be re-compiled for every different shapes. This overhead is inevitable + # in the first run, but can be skipped afterwards as we cache the XLA + # graphs in the disk (VLLM_XLA_CACHE_PATH). 
+        if exec_mode.is_prefill():
+            # Prefill
+            torch._dynamo.mark_dynamic(token_ids, 1)
+            torch._dynamo.mark_dynamic(position_ids, 1)
+            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
+        else:
+            # Decode
+            torch._dynamo.mark_dynamic(token_ids, 0)
+            torch._dynamo.mark_dynamic(position_ids, 0)
+            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
+            torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
+            torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
+
+        self.model(token_ids, position_ids, attn_metadata, kv_caches, inputs_embeds)
+
+    def profile_run(self) -> None:
+        raise NotImplementedError()
+
+    def capture_model(self) -> None:
+        """Compile the model."""
+
+        logger.info("Compiling the model with different input shapes.")
+
+        # Capture prefill shapes
+        start = time.perf_counter()
+        for batch_size in [1]:
+            seq_len = 16
+            while True:
+                self.dummy_run(self.kv_caches, batch_size, seq_len,
+                               ExecutionMode.PREFILL)
+                xm.wait_device_ops()
+                logger.info(" -- batch_size: %d, seq_len: %d", batch_size,
+                            seq_len)
+
+                if seq_len >= self.model_config.max_model_len:
+                    break
+
+                num_tokens = batch_size * seq_len
+                if num_tokens >= self.scheduler_config.max_num_batched_tokens:
+                    break
+
+                # Move to next seq_len
+                seq_len = seq_len * 2
+
+        end = time.perf_counter()
+        logger.info("Compilation for prefill shapes is done in %.2f [secs].",
+                    end - start)
+
+        # Capture decode shapes.
+        start = time.time()
+        seq_len = 1
+        batch_size = 8  # Must be in sync with _get_padded_batch_size()
+        while True:
+            self.dummy_run(self.kv_caches, batch_size, seq_len,
+                           ExecutionMode.DECODE)
+            xm.wait_device_ops()
+            logger.info(" -- batch_size: %d, seq_len: %d, max_num_seqs = %d",
+                        batch_size, seq_len,
+                        self.scheduler_config.max_num_seqs)
+
+            if batch_size >= self.scheduler_config.max_num_seqs:
+                break
+
+            batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
+
+        end = time.time()
+        logger.info("Compilation for decode shapes is done in %.2f [secs].",
+                    end - start)
+
+    def initialize_kv_cache(self, num_tpu_blocks: int) -> None:
+        assert len(self.kv_caches) == 0
+
+        tpu_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
+            num_tpu_blocks, self.block_size, self.num_kv_heads, self.head_size)
+
+        for _ in range(self.num_attn_layers):
+            tpu_k_cache = torch.zeros(tpu_cache_shape,
+                                      dtype=self.kv_cache_dtype,
+                                      device=self.device)
+            tpu_v_cache = torch.zeros_like(tpu_k_cache)
+            self.kv_caches.append((tpu_k_cache, tpu_v_cache))
+
+
+class ModelWrapperV1(nn.Module):
+
+    def __init__(self, model: nn.Module):
+        super().__init__()
+        self.model = model
+
+    def forward(
+        self,
+        token_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Executes the forward pass of the model and samples the next token.
+
+        Args:
+            token_ids: The input token IDs of shape [batch_size, seq_len].
+            position_ids: The input position IDs of shape [batch_size, seq_len].
+            attn_metadata: The Pallas attention metadata.
+            inputs_embeds: Optional precomputed input embeddings, used instead
+                of token_ids (e.g. for multimodal inputs).
+            kv_caches: The key and value caches.
They can be None during the + memory profiling at initialization. + """ + # Skip this in memory profiling at initialization. + if kv_caches[0][0].numel() > 0: + # index_copy_(slot_mapping) only works when the inserted dimension + # is 0. However, the KV cache in the Pallas backend has the shape + # [num_kv_heads, num_blocks, block_size, head_size]. To make it + # work, we need to flatten the first three dimensions and modify + # the slot_mapping accordingly. + num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape + slot_mapping = attn_metadata.slot_mapping + slot_mapping = slot_mapping.flatten() + head_indicies = torch.arange(0, + num_kv_heads, + device=slot_mapping.device, + dtype=slot_mapping.dtype) + head_indicies *= block_size * num_blocks + slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view( + -1, num_kv_heads) + slot_mapping = slot_mapping + head_indicies.view(1, -1) + slot_mapping = slot_mapping.flatten() + attn_metadata.slot_mapping = slot_mapping + + hidden_states = self.model( + input_ids=token_ids, + positions=position_ids, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + hidden_states = hidden_states.flatten(0, 1) + logits = self.model.compute_logits(hidden_states, None) + + # Greedy sampling. + argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) + argmax_token_ids = argmax_token_ids.squeeze(dim=-1) + return argmax_token_ids + + def get_multimodal_embeddings(self, *args, **kwargs): + return self.model.get_multimodal_embeddings(*args, **kwargs) + + def get_input_embeddings(self, *args, **kwargs): + return self.model.get_input_embeddings(*args, **kwargs) + + +def _get_padded_prefill_len(x: int) -> int: + # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence + # length to be a multiple of 16. We pad the prompt length to the nearest + # multiple of 16. This is also good for performance. 
+    if x <= 16:
+        return 16
+    return 1 << (x - 1).bit_length()
+
+
+def _get_padded_batch_size(batch_size: int) -> int:
+    # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
+    # To meet this requirement in the simplest way, we set the minimal batch
+    # size to 8.
+    if batch_size <= 8:
+        return 8
+    else:
+        return ((batch_size + 15) // 16) * 16
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
new file mode 100644
index 000000000000..8bbdf00794e6
--- /dev/null
+++ b/vllm/v1/worker/tpu_worker.py
@@ -0,0 +1,151 @@
+"""A TPU worker class."""
+import os
+from typing import Optional, Tuple
+
+import torch
+import torch.distributed
+import torch_xla.core.xla_model as xm
+import torch_xla.runtime as xr
+
+import vllm.envs as envs
+from vllm.config import ParallelConfig, VllmConfig
+from vllm.distributed import (ensure_model_parallel_initialized,
+                              init_distributed_environment)
+from vllm.logger import init_logger
+from vllm.model_executor import set_random_seed
+from vllm.utils import get_dtype_size
+
+from vllm.v1.core.scheduler import SchedulerOutput
+from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.worker.tpu_model_runner import TPUModelRunner, ExecutionMode
+from vllm.v1.worker.worker_base import WorkerBase
+
+logger = init_logger(__name__)
+
+
+class TPUWorker(WorkerBase):
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+    ):
+        super().__init__(vllm_config, local_rank, rank,
+                         distributed_init_method)
+
+    def initialize(self):
+        os.environ["PJRT_DEVICE"] = "TPU"
+        torch.set_grad_enabled(False)
+        torch.set_default_dtype(self.model_config.dtype)
+
+        # Initialize the distributed environment.
+        init_tpu_worker_distributed_environment(self.parallel_config,
+                                                self.rank,
+                                                self.distributed_init_method,
+                                                self.local_rank)
+
+        # Device initialization should happen after initializing
+        # the distributed runtime.
+ self.device = xm.xla_device() + self.device_config.device = self.device + + # Set random seed. + set_random_seed(self.model_config.seed) + xm.set_rng_state(self.model_config.seed, self.device) + + # Increase the cache size limit, which is the maximum number of + # dynamo graphs that can be compiled. + # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and + # 30-40 graphs for decode. 128 is an arbitrary safe number. + torch._dynamo.config.cache_size_limit = 128 + # Use persistent cache to avoid XLA recompilation. + # NOTE(woosuk): Set per-rank cache path since different ranks + # can have slightly different XLA graphs. + world_size = self.parallel_config.world_size + rank = xr.global_ordinal() + per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, + f"tp{world_size}_rank{rank}") + xr.initialize_cache(per_rank_path, readonly=False) + + # Init ModelRunner here, so that we have access to self.device. + self.model_runner = TPUModelRunner(self.vllm_config, self.device) + + def determine_num_available_blocks(self) -> Tuple[int, int]: + assert self.model_runner is not None + + num_layers = self.model_config.get_num_layers(self.parallel_config) + head_size = self.model_config.get_head_size() + num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) + + # use an empty tensor instead of `None`` to force Dynamo to pass + # it by reference, rather by specializing on the value ``None``. + # the `dtype` argument does not matter, and we use `float32` as + # a placeholder (it has wide hardware support). + kv_caches = [(torch.tensor([], dtype=torch.float32, + device=self.device), + torch.tensor([], dtype=torch.float32, + device=self.device)) + for _ in range(num_layers)] + self.model_runner.dummy_run( + batch_size=1, + seq_len=self.scheduler_config.max_num_batched_tokens, + exec_mode=ExecutionMode.PREFILL, + kv_caches=kv_caches, + ) + # Synchronize before measuring the memory usage. 
+ xm.wait_device_ops() + + # Get the maximum amount of memory used by the model weights and + # intermediate activations. + m = xm.get_memory_info(self.device) + total_memory_size = m["bytes_limit"] + profiled = m["peak_bytes_used"] # Weights + intermediate activations. + + # Calculate the TPU KV cache size based on profiling. + usable_memory_size = int(total_memory_size * + self.cache_config.gpu_memory_utilization) + tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) + dtype_btyes = get_dtype_size(self.cache_dtype) + block_size_bytes = (dtype_btyes * self.cache_config.block_size * + num_layers * 2 * head_size * num_kv_heads) + num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes + num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8. + + # Calculate the CPU KV cache size based on the config. + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + block_size_bytes) + num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to 8. + return num_tpu_blocks, num_cpu_blocks + + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> Optional[ModelRunnerOutput]: + assert self.model_runner is not None + output = self.model_runner.execute_model(scheduler_output) + return output if self.rank == 0 else None + + +def init_tpu_worker_distributed_environment( + parallel_config: ParallelConfig, + rank: int, + distributed_init_method: Optional[str] = None, + local_rank: int = -1, +) -> None: + """Initialize the distributed environment.""" + + # NOTE(woosuk): This is just to initialize the TP group and broadcast + # the input objects on CPU. The all-reduce and all-gather ops on TPU + # are invoked by `xm.all_reduce` and `xm.all_gather` which use their + # own context. 
+    init_distributed_environment(
+        world_size=parallel_config.world_size,
+        rank=rank,
+        local_rank=local_rank,
+        distributed_init_method=distributed_init_method,
+        backend="gloo",
+    )
+    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
+                                      parallel_config.pipeline_parallel_size)
diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py
new file mode 100644
index 000000000000..beea0bf12e00
--- /dev/null
+++ b/vllm/v1/worker/worker_base.py
@@ -0,0 +1,180 @@
+"""A worker base class."""
+from typing import TYPE_CHECKING, Optional, Tuple, Any
+
+import torch
+import torch.distributed
+
+import vllm.envs as envs
+from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
+from vllm.distributed import (ensure_model_parallel_initialized,
+                              init_distributed_environment,
+                              set_custom_all_reduce)
+from vllm.logger import init_logger
+from vllm.model_executor import set_random_seed
+from vllm.platforms import current_platform
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size
+from vllm.v1.core.scheduler import SchedulerOutput
+from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.worker.model_runner_base import ModelRunnerBase
+
+logger = init_logger(__name__)
+
+if TYPE_CHECKING:
+    from vllm.v1.core.scheduler import SchedulerOutput
+
+
+class WorkerBase:
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+    ):
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.cache_config = vllm_config.cache_config
+        self.lora_config = vllm_config.lora_config
+        self.load_config = vllm_config.load_config
+        self.parallel_config = vllm_config.parallel_config
+        self.scheduler_config = vllm_config.scheduler_config
+        self.device_config = vllm_config.device_config
+        self.speculative_config = vllm_config.speculative_config
+        self.prompt_adapter_config = vllm_config.prompt_adapter_config
+
self.observability_config = vllm_config.observability_config + + self.parallel_config.rank = rank + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + + if self.cache_config.cache_dtype == "auto": + self.cache_dtype = self.model_config.dtype + else: + self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + self.cache_config.cache_dtype] + + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + # Torch profiler. Enabled and configured through env vars: + # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + if envs.VLLM_TORCH_PROFILER_DIR: + torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + logger.info("Profiling enabled. Traces will be saved to: %s", + torch_profiler_trace_dir) + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + torch_profiler_trace_dir, use_gzip=True)) + else: + self.profiler = None + + # Initialized by the specific platform + self.model_runner: Optional[ModelRunnerBase] = None + + def initialize(self): + raise NotImplementedError() + + def load_model(self) -> None: + assert self.model_runner is not None + self.model_runner.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + raise NotImplementedError() + + def initialize_cache(self, num_gpu_blocks: int) -> None: + """Allocate GPU and CPU KV cache with the specified number of blocks.""" + assert self.model_runner is not None + + if num_gpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. 
" + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + + max_seq_len = self.cache_config.block_size * num_gpu_blocks + max_model_len = self.model_config.max_model_len + if max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine.") + + self.model_runner.initialize_kv_cache(num_gpu_blocks) + + def compile_or_warm_up_model(self) -> None: + assert self.model_runner is not None + + if not self.model_config.enforce_eager: + self.model_runner.capture_model() + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + set_random_seed(self.model_config.seed) + + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> Optional[ModelRunnerOutput]: + raise NotImplementedError() + + def profile(self, is_start: bool = True): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + if is_start: + self.profiler.start() + else: + self.profiler.stop() + + def check_health(self) -> None: + # worker will always be healthy as long as it's running. + return + + +def check_if_gpu_supports_dtype(torch_dtype: torch.dtype): + # Check if the GPU supports the dtype. + if torch_dtype == torch.bfloat16: # noqa: SIM102 + if not current_platform.has_device_capability(80): + capability = current_platform.get_device_capability() + gpu_name = current_platform.get_device_name() + + if capability is None: + compute_str = "does not have a compute capability" + else: + version_str = capability.as_version_str() + compute_str = f"has compute capability {version_str}" + + raise ValueError( + "Bfloat16 is only supported on GPUs with compute capability " + f"of at least 8.0. Your {gpu_name} GPU {compute_str}. 
" + "You can use float16 instead by explicitly setting the" + "`dtype` flag in CLI, for example: --dtype=half.") + + +def get_cache_block_size( + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, +) -> int: + head_size = model_config.get_head_size() + num_heads = model_config.get_num_kv_heads(parallel_config) + num_attention_layers = model_config.get_num_layers_by_block_type( + parallel_config, LayerBlockType.attention) + + key_cache_block = cache_config.block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_attention_layers * (key_cache_block + value_cache_block) + if cache_config.cache_dtype == "auto": + dtype = model_config.dtype + else: + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + dtype_size = get_dtype_size(dtype) + return dtype_size * total