[Core][Distributed] merge two broadcast_tensor_dict #5354
Changes from all commits
```diff
@@ -2,7 +2,7 @@
 import time
 import warnings
 from collections import defaultdict
-from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
+from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -12,7 +12,6 @@
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VisionLanguageConfig)
-from vllm.distributed import broadcast_tensor_dict
 from vllm.distributed.communication_op import graph_capture
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
@@ -71,6 +70,16 @@ def empty(cls, device):
         )
 
 
+class ModelRunnerInput(NamedTuple):
+    input_tokens: torch.Tensor
+    input_positions: torch.Tensor
+    attn_metadata: AttentionMetadata
+    sampling_metadata: SamplingMetadata
+    lora_requests: Set[LoRARequest]
+    lora_mapping: Optional[LoRAMapping]
+    multi_modal_kwargs: Dict[str, torch.Tensor]
+
+
 class ModelRunner:
 
     def __init__(
@@ -646,82 +655,94 @@ def _prepare_model_input(
             num_prefills=num_prefills,
         )
 
-    def prepare_input_tensors(
-        self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
-               Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]:
-        if self.is_driver_worker:
-            assert seq_group_metadata_list is not None
-            # Prepare input tensors.
-            (
-                input_tokens,
-                input_positions,
-                attn_metadata,
-                seq_lens,
-                query_lens,
-                lora_mapping,
-                lora_requests,
-                multi_modal_kwargs,
-                slot_mapping,
-                num_prefill_tokens,
-                num_decode_tokens,
-                num_prefills,
-            ) = self._prepare_model_input(seq_group_metadata_list)
-            sampling_metadata = SamplingMetadata.prepare(
-                seq_group_metadata_list, seq_lens, query_lens, self.device,
-                self.pin_memory)
-
-            metadata_dict = {
-                "input_tokens": input_tokens,
-                "input_positions": input_positions,
-                "selected_token_indices":
-                sampling_metadata.selected_token_indices,
-                "lora_requests": lora_requests,
-                "lora_mapping": lora_mapping,
-                "multi_modal_kwargs": multi_modal_kwargs,
-                "num_prefill_tokens": num_prefill_tokens,
-                "num_decode_tokens": num_decode_tokens,
-                "slot_mapping": slot_mapping,
-                "num_prefills": num_prefills,
-            }
-            if attn_metadata:
-                metadata_dict.update(attn_metadata.asdict_zerocopy())
-            broadcast_tensor_dict(metadata_dict, src=0)
+    def prepare_inputs_to_broadcast(
+        self, seq_group_metadata_list: List[SequenceGroupMetadata]
+    ) -> Tuple[Dict, Optional[List]]:
+        """
+        This function prepares the input tensors for the model. The output
+        is a dictionary that will be broadcasted to all workers, and a list
+        of auxiliary data that will be passed to the model execution function
+        for the driver worker.
+        """
+        # Prepare input tensors.
+        (
+            input_tokens,
+            input_positions,
+            attn_metadata,
+            seq_lens,
+            query_lens,
+            lora_mapping,
+            lora_requests,
+            multi_modal_kwargs,
+            slot_mapping,
+            num_prefill_tokens,
+            num_decode_tokens,
+            num_prefills,
+        ) = self._prepare_model_input(seq_group_metadata_list)
+        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
+                                                     seq_lens, query_lens,
+                                                     self.device,
+                                                     self.pin_memory)
+
+        metadata_dict = {
+            "input_tokens": input_tokens,
+            "input_positions": input_positions,
+            "selected_token_indices": sampling_metadata.selected_token_indices,
+            "lora_requests": lora_requests,
+            "lora_mapping": lora_mapping,
+            "multi_modal_kwargs": multi_modal_kwargs,
+            "num_prefill_tokens": num_prefill_tokens,
+            "num_decode_tokens": num_decode_tokens,
+            "slot_mapping": slot_mapping,
+            "num_prefills": num_prefills,
+        }
+        if attn_metadata:
+            metadata_dict.update(attn_metadata.asdict_zerocopy())
+
+        return metadata_dict, [sampling_metadata]
+
+    def convert_broadcast_inputs_to_modelrunner_input(
+            self, metadata_dict: Dict[str, Any], aux: Optional[List[Any]]):
```
A review thread is attached to the new method signature above:

**Member:** Add return typing?

**Author:** I deleted it to avoid a typing conflict. Suggestions are welcome for fixing it while making mypy happy.
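One possible shape of a fix, sketched here only as an illustration and not as code from this PR: make the conversion generic over the concrete input type, so each runner can declare a precise return type without fighting mypy's override rules. The `TModelInput` and `InputConversion` names below are hypothetical.

```python
# Hypothetical sketch only -- not code from this PR. A TypeVar lets each
# model runner annotate convert_broadcast_inputs_to_modelrunner_input with
# its own concrete NamedTuple while keeping mypy happy.
from typing import Any, Dict, Generic, List, Optional, TypeVar

TModelInput = TypeVar("TModelInput")  # e.g. the ModelRunnerInput of each runner


class InputConversion(Generic[TModelInput]):

    def convert_broadcast_inputs_to_modelrunner_input(
            self, metadata_dict: Dict[str, Any],
            aux: Optional[List[Any]]) -> TModelInput:
        raise NotImplementedError
```

A runner would then inherit `InputConversion[ModelRunnerInput]` (or its own tuple type), giving mypy an exact return type per subclass.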
The diff continues:

```diff
+        input_tokens = metadata_dict.pop("input_tokens")
+        input_positions = metadata_dict.pop("input_positions")
+        selected_token_indices = metadata_dict.pop("selected_token_indices")
+        lora_mapping = metadata_dict.pop("lora_mapping")
+        lora_requests = metadata_dict.pop("lora_requests")
+        multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
+        if metadata_dict:
+            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
+        else:
+            attn_metadata = None
+
+        if aux is not None:
+            sampling_metadata = aux[0]
         else:
-            metadata_dict = broadcast_tensor_dict(src=0)
-            input_tokens = metadata_dict.pop("input_tokens")
-            input_positions = metadata_dict.pop("input_positions")
-            selected_token_indices = metadata_dict.pop(
-                "selected_token_indices")
-            lora_mapping = metadata_dict.pop("lora_mapping")
-            lora_requests = metadata_dict.pop("lora_requests")
-            multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
-            if metadata_dict:
-                attn_metadata = self.attn_backend.make_metadata(
-                    **metadata_dict)
-            else:
-                attn_metadata = None
             sampling_metadata = SamplingMetadata(
                 seq_groups=None,
                 selected_token_indices=selected_token_indices,
                 categorized_sample_indices=None,
                 num_prompts=0,
             )
+        return ModelRunnerInput(input_tokens, input_positions, attn_metadata,
+                                sampling_metadata, lora_requests, lora_mapping,
+                                multi_modal_kwargs)
 
-        return (input_tokens, input_positions, attn_metadata,
-                sampling_metadata, lora_requests, lora_mapping,
-                multi_modal_kwargs)
+    def prepare_modelrunner_input(
+        self,
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+    ):
+        assert seq_group_metadata_list is not None
+        return self.convert_broadcast_inputs_to_modelrunner_input(
+            *self.prepare_inputs_to_broadcast(seq_group_metadata_list))
 
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+        modelrunner_input: ModelRunnerInput,
         kv_caches: List[torch.Tensor],
     ) -> Optional[SamplerOutput]:
         (input_tokens, input_positions, attn_metadata, sampling_metadata,
-         lora_requests, lora_mapping, multi_modal_kwargs
-         ) = self.prepare_input_tensors(seq_group_metadata_list)
+         lora_requests, lora_mapping, multi_modal_kwargs) = modelrunner_input
 
         if self.lora_config:
             self.set_active_loras(lora_requests, lora_mapping)
@@ -830,7 +851,7 @@ def profile_run(self) -> None:
         # Run the model with the dummy inputs.
         num_layers = self.model_config.get_num_layers(self.parallel_config)
         kv_caches = [None] * num_layers
-        self.execute_model(seqs, kv_caches)
+        self.execute_model(self.prepare_modelrunner_input(seqs), kv_caches)
         torch.cuda.synchronize()
         return
```
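To make the shape of the refactor concrete, here is a minimal sketch of how the worker side could drive the single merged broadcast using the methods added above. The worker code itself is not part of this diff, so the function below is illustrative and simplified (it ignores the worker's own swap/copy metadata, which the real worker would fold into the same dict); `broadcast_tensor_dict` is the existing vLLM collective helper.

```python
# Illustrative worker-side flow after this PR (assumed code, not part of
# the diff): the worker, not the model runner, now performs the broadcast,
# so it can merge its own metadata into one collective call per step.
from vllm.distributed import broadcast_tensor_dict


def run_step(model_runner, seq_group_metadata_list, kv_caches, is_driver):
    if is_driver:
        # The driver builds the dict once; aux (the SamplingMetadata) stays
        # local and is never sent over the wire.
        metadata_dict, aux = model_runner.prepare_inputs_to_broadcast(
            seq_group_metadata_list)
        broadcast_tensor_dict(metadata_dict, src=0)
    else:
        # Non-driver workers receive the same dict from rank 0.
        metadata_dict, aux = broadcast_tensor_dict(src=0), None
    model_input = model_runner.convert_broadcast_inputs_to_modelrunner_input(
        metadata_dict, aux)
    return model_runner.execute_model(model_input, kv_caches)
```

A driver running without tensor parallelism can skip the collective entirely by calling `prepare_modelrunner_input` directly, which is what `profile_run` does in the last hunk.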
**Reviewer:** I wonder whether we should have this return `ModelRunnerInput` and then add a method to that, `get_dict_to_broadcast()`, which you only call in the TP > 1 case. This would be more efficient for the non-TP case. Then this method could be called `prepare_modelrunner_input()`, which would make more sense to me since it's used even for non-TP where broadcasting isn't being done. That would also obviate the need for this separate `aux` arg/variable.
**Author:** The problem is who takes the responsibility to drive the broadcast process, and how it knows which function to call, given that we have inheritance in `model_runner.py` and `embedding_model_runner.py`. Previously, each model runner drove the broadcast operation itself, inside `execute_model`, which leads to a separate broadcast. In this PR, the worker drives the broadcast operation, so it needs `ModelRunner` to return `inputs_to_broadcast`, broadcast it, and then feed it back to `ModelRunner`. If we use `ModelRunnerInput.get_dict_to_broadcast`, how can the worker know which `ModelRunnerInput` to use? We have multiple `ModelRunnerInput`s, used in both `model_runner.py` and `embedding_model_runner.py`.
**Reviewer:** You kind of answered your own question :) ... both `ModelRunnerInput`s implement that and the worker just calls it (could be a protocol).
**Author:** This can be solved by adding an abstract class and letting `ModelRunnerInput` inherit from it. For reasons I don't know, inheritance is discouraged in vLLM, and several model runners just duplicate the code.
**Reviewer:** Inheritance is used in various places already. IMO it makes sense to use it judiciously, but to avoid overdoing it. In any case, that's why I suggested using a protocol here; you can do the same thing without inheritance.
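For reference, a minimal sketch of the protocol idea under discussion, assuming a hypothetical `get_dict_to_broadcast()` method (the name comes from the review comment, not from the PR): the worker would depend only on the protocol, so `model_runner.py` and `embedding_model_runner.py` could each keep their own input NamedTuple without sharing a base class.

```python
# Sketch of the reviewer's suggestion (hypothetical API, not what this PR
# implements): a typing.Protocol gives the worker a common interface to
# every runner's input type, with no inheritance required.
from typing import Any, Dict, Protocol

from vllm.distributed import broadcast_tensor_dict


class BroadcastableModelInput(Protocol):

    def get_dict_to_broadcast(self) -> Dict[str, Any]:
        """Tensors and metadata to send to the non-driver workers."""
        ...


def maybe_broadcast(model_input: BroadcastableModelInput,
                    tp_world_size: int) -> None:
    # Only pay the collective cost when tensor parallelism is enabled,
    # which is the efficiency point made in the first review comment.
    if tp_world_size > 1:
        broadcast_tensor_dict(model_input.get_dict_to_broadcast(), src=0)
```

Because `Protocol` uses structural typing, any NamedTuple that defines `get_dict_to_broadcast()` satisfies `BroadcastableModelInput` automatically, which is how this avoids the inheritance question raised above.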