diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py
index 92de545acd53..c6deb1c538cf 100644
--- a/tests/worker/test_model_runner.py
+++ b/tests/worker/test_model_runner.py
@@ -351,7 +351,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
         decode_metadata_list.append(seq_group_metadata)
 
     (input_tokens, input_positions, attn_metadata, _, _, _,
-     _) = model_runner.prepare_input_tensors(seq_group_metadata_list)
+     _) = model_runner.prepare_modelrunner_input(seq_group_metadata_list)
 
     prefill_meta_actual = attn_metadata.prefill_metadata
     decode_meta_actual = attn_metadata.decode_metadata
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index eaf43247d4fc..ee10f2e1ce4f 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -270,7 +270,7 @@ def _prepare_decode(
             attn_metadata,
         )
 
-    def prepare_input_tensors(
+    def prepare_modelrunner_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
@@ -334,7 +334,7 @@ def execute_model(
     ) -> Optional[SamplerOutput]:
         (input_tokens, input_positions, attn_metadata, sampling_metadata,
          multi_modal_input
-         ) = self.prepare_input_tensors(seq_group_metadata_list)
+         ) = self.prepare_modelrunner_input(seq_group_metadata_list)
 
         model_executable = self.model
         execute_model_kwargs = {
diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py
index 465130d10e2f..f6fcc07f768c 100644
--- a/vllm/worker/embedding_model_runner.py
+++ b/vllm/worker/embedding_model_runner.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple
 
 import torch
 
@@ -6,7 +6,6 @@
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VisionLanguageConfig)
-from vllm.distributed import broadcast_tensor_dict
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
@@ -18,6 +17,16 @@
 logger = init_logger(__name__)
 
 
+class ModelRunnerInput(NamedTuple):
+    input_tokens: torch.Tensor
+    input_positions: torch.Tensor
+    attn_metadata: AttentionMetadata
+    pooling_metadata: PoolingMetadata
+    lora_requests: Set[LoRARequest]
+    lora_mapping: Optional[LoRAMapping]
+    multi_modal_input: Dict[str, torch.Tensor]
+
+
 class EmbeddingModelRunner(ModelRunner):
 
     def __init__(
@@ -47,12 +56,11 @@ def __init__(
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+        modelrunner_input: ModelRunnerInput,
         kv_caches: List[torch.Tensor],
     ) -> Optional[PoolerOutput]:
         (input_tokens, input_positions, attn_metadata, pooling_metadata,
-         lora_requests, lora_mapping, multi_modal_input
-         ) = self.prepare_input_tensors(seq_group_metadata_list)
+         lora_requests, lora_mapping, multi_modal_input) = modelrunner_input
 
         if self.lora_config:
             self.set_active_loras(lora_requests, lora_mapping)
@@ -86,64 +94,66 @@ def execute_model(
         return self.model.pooler(hidden_states=hidden_states,
                                  pooling_metadata=pooling_metadata)
 
-    def prepare_input_tensors(
-        self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
-               Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]:
-        if self.is_driver_worker:
-            assert seq_group_metadata_list is not None
-            # Prepare input tensors.
-            (
-                input_tokens,
-                input_positions,
-                attn_metadata,
-                seq_lens,
-                _,
-                lora_mapping,
-                lora_requests,
-                multi_modal_kwargs,
-                slot_mapping,
-                num_prefill_tokens,
-                num_decode_tokens,
-                num_prefills,
-            ) = self._prepare_model_input(seq_group_metadata_list)
-            # Prepare PoolingMetadata
-            pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
-                                                     seq_lens)
-
-            metadata_dict = {
-                "input_tokens": input_tokens,
-                "input_positions": input_positions,
-                "lora_requests": lora_requests,
-                "lora_mapping": lora_mapping,
-                "multi_modal_kwargs": multi_modal_kwargs,
-                "num_prefill_tokens": num_prefill_tokens,
-                "num_decode_tokens": num_decode_tokens,
-                "slot_mapping": slot_mapping,
-                "num_prefills": num_prefills,
-            }
-            if attn_metadata:
-                metadata_dict.update(attn_metadata.asdict_zerocopy())
-            broadcast_tensor_dict(metadata_dict, src=0)
+    def prepare_inputs_to_broadcast(
+        self, seq_group_metadata_list: List[SequenceGroupMetadata]
+    ) -> Tuple[Dict[str, Any], Optional[List[Any]]]:
+        # Prepare input tensors.
+        (
+            input_tokens,
+            input_positions,
+            attn_metadata,
+            seq_lens,
+            _,
+            lora_mapping,
+            lora_requests,
+            multi_modal_kwargs,
+            slot_mapping,
+            num_prefill_tokens,
+            num_decode_tokens,
+            num_prefills,
+        ) = self._prepare_model_input(seq_group_metadata_list)
+        # Prepare PoolingMetadata
+        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
+                                                 seq_lens)
+
+        metadata_dict = {
+            "input_tokens": input_tokens,
+            "input_positions": input_positions,
+            "lora_requests": lora_requests,
+            "lora_mapping": lora_mapping,
+            "multi_modal_kwargs": multi_modal_kwargs,
+            "num_prefill_tokens": num_prefill_tokens,
+            "num_decode_tokens": num_decode_tokens,
+            "slot_mapping": slot_mapping,
+            "num_prefills": num_prefills,
+        }
+        if attn_metadata:
+            metadata_dict.update(attn_metadata.asdict_zerocopy())
+
+        return metadata_dict, [pooling_metadata]
+
+    def convert_broadcast_inputs_to_modelrunner_input(
+            self,
+            metadata_dict: Dict[str, Any],
+            aux: Optional[List[Any]] = None):
+        input_tokens = metadata_dict.pop("input_tokens")
+        input_positions = metadata_dict.pop("input_positions")
+        lora_mapping = metadata_dict.pop("lora_mapping")
+        lora_requests = metadata_dict.pop("lora_requests")
+        multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
+        if metadata_dict:
+            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
+        else:
+            attn_metadata = None
+        if aux is not None:
+            pooling_metadata = aux[0]
         else:
-            metadata_dict = broadcast_tensor_dict(src=0)
-            input_tokens = metadata_dict.pop("input_tokens")
-            input_positions = metadata_dict.pop("input_positions")
-            lora_mapping = metadata_dict.pop("lora_mapping")
-            lora_requests = metadata_dict.pop("lora_requests")
-            multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
-            if metadata_dict:
-                attn_metadata = self.attn_backend.make_metadata(
-                    **metadata_dict)
-            else:
-                attn_metadata = None
             pooling_metadata = PoolingMetadata(seq_groups=None,
                                                seq_data=None,
                                                prompt_lens=None)
-
-        return (input_tokens, input_positions, attn_metadata, pooling_metadata,
-                lora_requests, lora_mapping, multi_modal_kwargs)
+        return ModelRunnerInput(input_tokens, input_positions, attn_metadata,
+                                pooling_metadata, lora_requests, lora_mapping,
+                                multi_modal_kwargs)
 
     def _prepare_pooling(
         self,
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 7879a5de5b7b..0896bee9e2ee 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -2,7 +2,7 @@
 import time
 import warnings
 from collections import defaultdict
-from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
+from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -12,7 +12,6 @@
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VisionLanguageConfig)
-from vllm.distributed import broadcast_tensor_dict
 from vllm.distributed.communication_op import graph_capture
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
@@ -71,6 +70,16 @@ def empty(cls, device):
     )
 
 
+class ModelRunnerInput(NamedTuple):
+    input_tokens: torch.Tensor
+    input_positions: torch.Tensor
+    attn_metadata: AttentionMetadata
+    sampling_metadata: SamplingMetadata
+    lora_requests: Set[LoRARequest]
+    lora_mapping: Optional[LoRAMapping]
+    multi_modal_kwargs: Dict[str, torch.Tensor]
+
+
 class ModelRunner:
 
     def __init__(
@@ -646,82 +655,94 @@ def _prepare_model_input(
             num_prefills=num_prefills,
         )
 
-    def prepare_input_tensors(
-        self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
-               Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]:
-        if self.is_driver_worker:
-            assert seq_group_metadata_list is not None
-            # Prepare input tensors.
-            (
-                input_tokens,
-                input_positions,
-                attn_metadata,
-                seq_lens,
-                query_lens,
-                lora_mapping,
-                lora_requests,
-                multi_modal_kwargs,
-                slot_mapping,
-                num_prefill_tokens,
-                num_decode_tokens,
-                num_prefills,
-            ) = self._prepare_model_input(seq_group_metadata_list)
-            sampling_metadata = SamplingMetadata.prepare(
-                seq_group_metadata_list, seq_lens, query_lens, self.device,
-                self.pin_memory)
-
-            metadata_dict = {
-                "input_tokens": input_tokens,
-                "input_positions": input_positions,
-                "selected_token_indices":
-                sampling_metadata.selected_token_indices,
-                "lora_requests": lora_requests,
-                "lora_mapping": lora_mapping,
-                "multi_modal_kwargs": multi_modal_kwargs,
-                "num_prefill_tokens": num_prefill_tokens,
-                "num_decode_tokens": num_decode_tokens,
-                "slot_mapping": slot_mapping,
-                "num_prefills": num_prefills,
-            }
-            if attn_metadata:
-                metadata_dict.update(attn_metadata.asdict_zerocopy())
-            broadcast_tensor_dict(metadata_dict, src=0)
+    def prepare_inputs_to_broadcast(
+        self, seq_group_metadata_list: List[SequenceGroupMetadata]
+    ) -> Tuple[Dict, Optional[List]]:
+        """
+        This function prepares the input tensors for the model. The output
+        is a dictionary that will be broadcasted to all workers, and a list
+        of auxiliary data that will be passed to the model execution function
+        for the driver worker.
+        """
+        # Prepare input tensors.
+        (
+            input_tokens,
+            input_positions,
+            attn_metadata,
+            seq_lens,
+            query_lens,
+            lora_mapping,
+            lora_requests,
+            multi_modal_kwargs,
+            slot_mapping,
+            num_prefill_tokens,
+            num_decode_tokens,
+            num_prefills,
+        ) = self._prepare_model_input(seq_group_metadata_list)
+        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
+                                                     seq_lens, query_lens,
+                                                     self.device,
+                                                     self.pin_memory)
+
+        metadata_dict = {
+            "input_tokens": input_tokens,
+            "input_positions": input_positions,
+            "selected_token_indices": sampling_metadata.selected_token_indices,
+            "lora_requests": lora_requests,
+            "lora_mapping": lora_mapping,
+            "multi_modal_kwargs": multi_modal_kwargs,
+            "num_prefill_tokens": num_prefill_tokens,
+            "num_decode_tokens": num_decode_tokens,
+            "slot_mapping": slot_mapping,
+            "num_prefills": num_prefills,
+        }
+        if attn_metadata:
+            metadata_dict.update(attn_metadata.asdict_zerocopy())
+
+        return metadata_dict, [sampling_metadata]
+
+    def convert_broadcast_inputs_to_modelrunner_input(
+            self, metadata_dict: Dict[str, Any], aux: Optional[List[Any]]):
+        input_tokens = metadata_dict.pop("input_tokens")
+        input_positions = metadata_dict.pop("input_positions")
+        selected_token_indices = metadata_dict.pop("selected_token_indices")
+        lora_mapping = metadata_dict.pop("lora_mapping")
+        lora_requests = metadata_dict.pop("lora_requests")
+        multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
+        if metadata_dict:
+            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
+        else:
+            attn_metadata = None
+
+        if aux is not None:
+            sampling_metadata = aux[0]
         else:
-            metadata_dict = broadcast_tensor_dict(src=0)
-            input_tokens = metadata_dict.pop("input_tokens")
-            input_positions = metadata_dict.pop("input_positions")
-            selected_token_indices = metadata_dict.pop(
-                "selected_token_indices")
-            lora_mapping = metadata_dict.pop("lora_mapping")
-            lora_requests = metadata_dict.pop("lora_requests")
-            multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
-            if metadata_dict:
-                attn_metadata = self.attn_backend.make_metadata(
-                    **metadata_dict)
-            else:
-                attn_metadata = None
             sampling_metadata = SamplingMetadata(
                 seq_groups=None,
                 selected_token_indices=selected_token_indices,
                 categorized_sample_indices=None,
                 num_prompts=0,
             )
+        return ModelRunnerInput(input_tokens, input_positions, attn_metadata,
+                                sampling_metadata, lora_requests, lora_mapping,
+                                multi_modal_kwargs)
 
-        return (input_tokens, input_positions, attn_metadata,
-                sampling_metadata, lora_requests, lora_mapping,
-                multi_modal_kwargs)
+    def prepare_modelrunner_input(
+        self,
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+    ):
+        assert seq_group_metadata_list is not None
+        return self.convert_broadcast_inputs_to_modelrunner_input(
+            *self.prepare_inputs_to_broadcast(seq_group_metadata_list))
 
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+        modelrunner_input: ModelRunnerInput,
         kv_caches: List[torch.Tensor],
     ) -> Optional[SamplerOutput]:
         (input_tokens, input_positions, attn_metadata, sampling_metadata,
-         lora_requests, lora_mapping, multi_modal_kwargs
-         ) = self.prepare_input_tensors(seq_group_metadata_list)
+         lora_requests, lora_mapping, multi_modal_kwargs) = modelrunner_input
 
         if self.lora_config:
             self.set_active_loras(lora_requests, lora_mapping)
@@ -830,7 +851,7 @@ def profile_run(self) -> None:
         # Run the model with the dummy inputs.
         num_layers = self.model_config.get_num_layers(self.parallel_config)
         kv_caches = [None] * num_layers
-        self.execute_model(seqs, kv_caches)
+        self.execute_model(self.prepare_modelrunner_input(seqs), kv_caches)
         torch.cuda.synchronize()
         return
diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py
index a336be04e124..237882a18f8f 100644
--- a/vllm/worker/neuron_model_runner.py
+++ b/vllm/worker/neuron_model_runner.py
@@ -139,7 +139,7 @@ def _prepare_decode(
 
         return input_tokens, input_positions, input_block_ids
 
-    def prepare_input_tensors(
+    def prepare_modelrunner_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]:
@@ -173,7 +173,7 @@ def execute_model(
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Optional[SamplerOutput]:
         (input_tokens, input_positions, input_block_ids, sampling_metadata
-         ) = self.prepare_input_tensors(seq_group_metadata_list)
+         ) = self.prepare_modelrunner_input(seq_group_metadata_list)
 
         hidden_states = self.model(
             input_ids=input_tokens,
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 10411a2bf7a1..3d621a723156 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -261,6 +261,11 @@ def execute_model(
                 "blocks_to_swap_out": blocks_to_swap_out,
                 "blocks_to_copy": blocks_to_copy,
             }
+
+            if num_seq_groups != 0:
+                inputs_to_broadcast, aux = self.model_runner.prepare_inputs_to_broadcast(  # noqa
+                    seq_group_metadata_list)
+                data.update(inputs_to_broadcast)
             broadcast_tensor_dict(data, src=0)
 
         self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
@@ -269,7 +274,10 @@
         if num_seq_groups == 0:
             return []
 
-        output = self.model_runner.execute_model(seq_group_metadata_list,
+        modelrunner_input = (
+            self.model_runner.convert_broadcast_inputs_to_modelrunner_input(
+                inputs_to_broadcast, aux))
+        output = self.model_runner.execute_model(modelrunner_input,
                                                  self.gpu_cache)
 
         # Worker only supports single-step execution. Wrap the output in a list
@@ -296,17 +304,21 @@ def _execute_model_non_driver(self) -> bool:
         if not data:
             return False
 
-        num_seq_groups = data.get("num_seq_groups", 0)
-        blocks_to_swap_in = data.get("blocks_to_swap_in")
-        blocks_to_swap_out = data.get("blocks_to_swap_out")
-        blocks_to_copy = data.get("blocks_to_copy")
+        num_seq_groups = data.pop("num_seq_groups")
+        blocks_to_swap_in = data.pop("blocks_to_swap_in")
+        blocks_to_swap_out = data.pop("blocks_to_swap_out")
+        blocks_to_copy = data.pop("blocks_to_copy")
 
         self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
 
         # If there is no input, we don't need to execute the model.
        if num_seq_groups == 0:
             return False
 
-        self.model_runner.execute_model(None, self.gpu_cache)
+        modelrunner_input = (
+            self.model_runner.convert_broadcast_inputs_to_modelrunner_input(
+                data, None))
+
+        self.model_runner.execute_model(modelrunner_input, self.gpu_cache)
         return True
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
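For reference, the control flow this patch converges on can be summarized outside of vLLM. Below is a minimal, self-contained sketch of the same prepare/broadcast/convert split; `ToyRunner`, its toy metadata dicts, and the trimmed `ModelRunnerInput` fields are illustrative stand-ins, not the library's API. The driver packs all model inputs into one broadcastable dict plus a driver-local `aux` list, every rank converts the dict back into a typed `ModelRunnerInput`, and `execute_model` consumes only that tuple:

    # Sketch only: mirrors the pattern in the diff above, not the vLLM classes.
    from typing import Any, Dict, List, NamedTuple, Optional, Tuple


    class ModelRunnerInput(NamedTuple):
        """Typed bundle handed to execute_model (fields trimmed for the sketch)."""
        input_tokens: List[int]
        input_positions: List[int]
        sampling_metadata: Dict[str, Any]


    class ToyRunner:
        """Illustrative stand-in for ModelRunner; not the vLLM class itself."""

        def prepare_inputs_to_broadcast(
            self, seqs: List[str]
        ) -> Tuple[Dict[str, Any], Optional[List[Any]]]:
            # Driver only: everything in the dict is broadcast to all ranks;
            # `aux` holds driver-local objects (e.g. the full SamplingMetadata).
            tokens = list(range(len(seqs)))
            metadata_dict = {
                "input_tokens": tokens,
                "input_positions": list(range(len(tokens))),
                "selected_token_indices": [len(tokens) - 1],
            }
            aux = [{"seq_groups": seqs,
                    "selected_token_indices": [len(tokens) - 1]}]
            return metadata_dict, aux

        def convert_broadcast_inputs_to_modelrunner_input(
            self, metadata_dict: Dict[str, Any], aux: Optional[List[Any]]
        ) -> ModelRunnerInput:
            # Runs on every rank: rebuild the typed input from the flat dict.
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            selected = metadata_dict.pop("selected_token_indices")
            if aux is not None:
                # Driver: reuse the rich local object prepared above.
                sampling_metadata = aux[0]
            else:
                # Non-driver: minimal reconstruction from broadcast fields only.
                sampling_metadata = {"seq_groups": None,
                                     "selected_token_indices": selected}
            return ModelRunnerInput(input_tokens, input_positions,
                                    sampling_metadata)

        def execute_model(self, mr_input: ModelRunnerInput) -> List[int]:
            # execute_model no longer prepares or broadcasts anything;
            # it just consumes the tuple.
            return [t + p for t, p in zip(mr_input.input_tokens,
                                          mr_input.input_positions)]


    runner = ToyRunner()
    # Driver rank: prepare once, broadcast the dict, keep `aux` locally.
    to_broadcast, aux = runner.prepare_inputs_to_broadcast(["a", "b", "c"])
    received = dict(to_broadcast)  # stand-in for broadcast_tensor_dict(...)

    driver_input = runner.convert_broadcast_inputs_to_modelrunner_input(
        dict(to_broadcast), aux)
    worker_input = runner.convert_broadcast_inputs_to_modelrunner_input(
        received, None)
    assert runner.execute_model(driver_input) == runner.execute_model(worker_input)

The payoff, visible in the worker.py hunks above, is a single broadcast per step (model inputs piggyback on the same dict as the cache-swap bookkeeping) and symmetric driver/non-driver paths through one conversion function, so `execute_model` no longer hides a broadcast inside input preparation.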