2 changes: 1 addition & 1 deletion tests/worker/test_model_runner.py
@@ -351,7 +351,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
         decode_metadata_list.append(seq_group_metadata)

     (input_tokens, input_positions, attn_metadata, _, _, _,
-     _) = model_runner.prepare_input_tensors(seq_group_metadata_list)
+     _) = model_runner.prepare_modelrunner_input(seq_group_metadata_list)

     prefill_meta_actual = attn_metadata.prefill_metadata
     decode_meta_actual = attn_metadata.decode_metadata
4 changes: 2 additions & 2 deletions vllm/worker/cpu_model_runner.py
@@ -270,7 +270,7 @@ def _prepare_decode(
             attn_metadata,
         )

-    def prepare_input_tensors(
+    def prepare_modelrunner_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
@@ -334,7 +334,7 @@ def execute_model(
     ) -> Optional[SamplerOutput]:
         (input_tokens, input_positions, attn_metadata, sampling_metadata,
          multi_modal_input
-         ) = self.prepare_input_tensors(seq_group_metadata_list)
+         ) = self.prepare_modelrunner_input(seq_group_metadata_list)

         model_executable = self.model
         execute_model_kwargs = {
128 changes: 69 additions & 59 deletions vllm/worker/embedding_model_runner.py
@@ -1,12 +1,11 @@
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple

 import torch

 from vllm.attention import AttentionMetadata
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VisionLanguageConfig)
-from vllm.distributed import broadcast_tensor_dict
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
@@ -18,6 +17,16 @@
 logger = init_logger(__name__)


+class ModelRunnerInput(NamedTuple):
+    input_tokens: torch.Tensor
+    input_positions: torch.Tensor
+    attn_metadata: AttentionMetadata
+    pooling_metadata: PoolingMetadata
+    lora_requests: Set[LoRARequest]
+    lora_mapping: Optional[LoRAMapping]
+    multi_modal_input: Dict[str, torch.Tensor]
+
+
 class EmbeddingModelRunner(ModelRunner):

     def __init__(
@@ -47,12 +56,11 @@ def __init__(
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+        modelrunner_input: ModelRunnerInput,
         kv_caches: List[torch.Tensor],
     ) -> Optional[PoolerOutput]:
         (input_tokens, input_positions, attn_metadata, pooling_metadata,
-         lora_requests, lora_mapping, multi_modal_input
-         ) = self.prepare_input_tensors(seq_group_metadata_list)
+         lora_requests, lora_mapping, multi_modal_input) = modelrunner_input

         if self.lora_config:
             self.set_active_loras(lora_requests, lora_mapping)
@@ -86,64 +94,66 @@ def execute_model(
         return self.model.pooler(hidden_states=hidden_states,
                                  pooling_metadata=pooling_metadata)

-    def prepare_input_tensors(
-        self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
-               Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]:
-        if self.is_driver_worker:
-            assert seq_group_metadata_list is not None
-            # Prepare input tensors.
-            (
-                input_tokens,
-                input_positions,
-                attn_metadata,
-                seq_lens,
-                _,
-                lora_mapping,
-                lora_requests,
-                multi_modal_kwargs,
-                slot_mapping,
-                num_prefill_tokens,
-                num_decode_tokens,
-                num_prefills,
-            ) = self._prepare_model_input(seq_group_metadata_list)
-            # Prepare PoolingMetadata
-            pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
-                                                     seq_lens)
-
-            metadata_dict = {
-                "input_tokens": input_tokens,
-                "input_positions": input_positions,
-                "lora_requests": lora_requests,
-                "lora_mapping": lora_mapping,
-                "multi_modal_kwargs": multi_modal_kwargs,
-                "num_prefill_tokens": num_prefill_tokens,
-                "num_decode_tokens": num_decode_tokens,
-                "slot_mapping": slot_mapping,
-                "num_prefills": num_prefills,
-            }
-            if attn_metadata:
-                metadata_dict.update(attn_metadata.asdict_zerocopy())
-            broadcast_tensor_dict(metadata_dict, src=0)
+    def prepare_inputs_to_broadcast(
+        self, seq_group_metadata_list: List[SequenceGroupMetadata]
+    ) -> Tuple[Dict[str, Any], Optional[List[Any]]]:
+        # Prepare input tensors.
+        (
+            input_tokens,
+            input_positions,
+            attn_metadata,
+            seq_lens,
+            _,
+            lora_mapping,
+            lora_requests,
+            multi_modal_kwargs,
+            slot_mapping,
+            num_prefill_tokens,
+            num_decode_tokens,
+            num_prefills,
+        ) = self._prepare_model_input(seq_group_metadata_list)
+        # Prepare PoolingMetadata
+        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
+                                                 seq_lens)
+
+        metadata_dict = {
+            "input_tokens": input_tokens,
+            "input_positions": input_positions,
+            "lora_requests": lora_requests,
+            "lora_mapping": lora_mapping,
+            "multi_modal_kwargs": multi_modal_kwargs,
+            "num_prefill_tokens": num_prefill_tokens,
+            "num_decode_tokens": num_decode_tokens,
+            "slot_mapping": slot_mapping,
+            "num_prefills": num_prefills,
+        }
+        if attn_metadata:
+            metadata_dict.update(attn_metadata.asdict_zerocopy())
+
+        return metadata_dict, [pooling_metadata]
+
+    def convert_broadcast_inputs_to_modelrunner_input(
+            self,
+            metadata_dict: Dict[str, Any],
+            aux: Optional[List[Any]] = None):
+        input_tokens = metadata_dict.pop("input_tokens")
+        input_positions = metadata_dict.pop("input_positions")
+        lora_mapping = metadata_dict.pop("lora_mapping")
+        lora_requests = metadata_dict.pop("lora_requests")
+        multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
+        if metadata_dict:
+            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
+        else:
+            attn_metadata = None
+        if aux is not None:
+            pooling_metadata = aux[0]
         else:
-            metadata_dict = broadcast_tensor_dict(src=0)
-            input_tokens = metadata_dict.pop("input_tokens")
-            input_positions = metadata_dict.pop("input_positions")
-            lora_mapping = metadata_dict.pop("lora_mapping")
-            lora_requests = metadata_dict.pop("lora_requests")
-            multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
-            if metadata_dict:
-                attn_metadata = self.attn_backend.make_metadata(
-                    **metadata_dict)
-            else:
-                attn_metadata = None
             pooling_metadata = PoolingMetadata(seq_groups=None,
                                                seq_data=None,
                                                prompt_lens=None)

-        return (input_tokens, input_positions, attn_metadata, pooling_metadata,
-                lora_requests, lora_mapping, multi_modal_kwargs)
+        return ModelRunnerInput(input_tokens, input_positions, attn_metadata,
+                                pooling_metadata, lora_requests, lora_mapping,
+                                multi_modal_kwargs)

     def _prepare_pooling(
         self,
149 changes: 85 additions & 64 deletions vllm/worker/model_runner.py
@@ -2,7 +2,7 @@
 import time
 import warnings
 from collections import defaultdict
-from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
+from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, Union

 import numpy as np
 import torch
@@ -12,7 +12,6 @@
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VisionLanguageConfig)
-from vllm.distributed import broadcast_tensor_dict
 from vllm.distributed.communication_op import graph_capture
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
@@ -71,6 +70,16 @@ def empty(cls, device):
         )


+class ModelRunnerInput(NamedTuple):
+    input_tokens: torch.Tensor
+    input_positions: torch.Tensor
+    attn_metadata: AttentionMetadata
+    sampling_metadata: SamplingMetadata
+    lora_requests: Set[LoRARequest]
+    lora_mapping: Optional[LoRAMapping]
+    multi_modal_kwargs: Dict[str, torch.Tensor]
+
+
 class ModelRunner:

     def __init__(
@@ -646,82 +655,94 @@ def _prepare_model_input(
             num_prefills=num_prefills,
         )

-    def prepare_input_tensors(
-        self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
-               Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]:
-        if self.is_driver_worker:
-            assert seq_group_metadata_list is not None
-            # Prepare input tensors.
-            (
-                input_tokens,
-                input_positions,
-                attn_metadata,
-                seq_lens,
-                query_lens,
-                lora_mapping,
-                lora_requests,
-                multi_modal_kwargs,
-                slot_mapping,
-                num_prefill_tokens,
-                num_decode_tokens,
-                num_prefills,
-            ) = self._prepare_model_input(seq_group_metadata_list)
-            sampling_metadata = SamplingMetadata.prepare(
-                seq_group_metadata_list, seq_lens, query_lens, self.device,
-                self.pin_memory)
-
-            metadata_dict = {
-                "input_tokens": input_tokens,
-                "input_positions": input_positions,
-                "selected_token_indices":
-                sampling_metadata.selected_token_indices,
-                "lora_requests": lora_requests,
-                "lora_mapping": lora_mapping,
-                "multi_modal_kwargs": multi_modal_kwargs,
-                "num_prefill_tokens": num_prefill_tokens,
-                "num_decode_tokens": num_decode_tokens,
-                "slot_mapping": slot_mapping,
-                "num_prefills": num_prefills,
-            }
-            if attn_metadata:
-                metadata_dict.update(attn_metadata.asdict_zerocopy())
-            broadcast_tensor_dict(metadata_dict, src=0)
Review comments on prepare_inputs_to_broadcast:

njhill (Member): I wonder whether we should have this return ModelRunnerInput and then add a method to it, get_dict_to_broadcast(), which you only call in the TP > 1 case. This would be more efficient for the non-TP case.

Then this method could be called prepare_modelrunner_input(), which would make more sense to me since it's used even for non-TP, where broadcasting isn't being done.

That would also obviate the need for the separate aux arg/variable.

Author: The problem is who takes the responsibility to drive the broadcast process, and how it knows which function to call, given that we have inheritance between model_runner.py and embedding_model_runner.py.

Previously, each model runner drove the broadcast itself, inside execute_model, which led to a separate broadcast.

In this PR, the worker drives the broadcast, so it needs the model runner to return the inputs to broadcast, broadcast them, and then feed them back to the model runner.

If we use ModelRunnerInput.get_dict_to_broadcast(), how can the worker know which ModelRunnerInput to use? We have multiple ModelRunnerInput classes, one in model_runner.py and one in embedding_model_runner.py.

njhill (Member), quoting the question above: You kind of answered your own question :) ... both ModelRunnerInputs implement that method and the worker just calls it (it could be a protocol).

Author: This can be solved by adding an abstract class and letting ModelRunnerInput inherit from it. For reasons I don't know, inheritance is discouraged in vLLM, and several model runners just duplicate the code.

njhill (Member, Jun 10, 2024): Inheritance is used in various places already. IMO it makes sense to use it judiciously but to avoid overdoing it.

In any case, that's why I suggested using a protocol here: you can do the same thing without inheritance.
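A minimal sketch of the protocol approach suggested in this thread; the protocol name, its method, and the worker helper are illustrative assumptions, not code from this PR:

from typing import Any, Dict, Protocol


class BroadcastableModelInput(Protocol):
    # Hypothetical protocol: each runner's ModelRunnerInput NamedTuple could
    # implement this method, so the worker can call it without knowing the
    # concrete class and without any shared base class.
    def get_dict_to_broadcast(self) -> Dict[str, Any]:
        ...


def maybe_broadcast_inputs(model_input: BroadcastableModelInput,
                           tp_size: int) -> None:
    # Hypothetical worker-side helper: the flattened dict is built and
    # broadcast only when tensor parallelism is in use, so the non-TP
    # path never pays for it.
    if tp_size > 1:
        metadata_dict = model_input.get_dict_to_broadcast()
        # vLLM's broadcast_tensor_dict(metadata_dict, src=0) would run here.

Because typing.Protocol is structural, both ModelRunnerInput NamedTuples would satisfy it without sharing a base class, which would also sidestep the mypy naming conflict discussed further below.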
+    def prepare_inputs_to_broadcast(
+        self, seq_group_metadata_list: List[SequenceGroupMetadata]
+    ) -> Tuple[Dict, Optional[List]]:
+        """
+        This function prepares the input tensors for the model. The output
+        is a dictionary that will be broadcast to all workers, and a list
+        of auxiliary data that will be passed to the model execution
+        function for the driver worker.
+        """
+        # Prepare input tensors.
+        (
+            input_tokens,
+            input_positions,
+            attn_metadata,
+            seq_lens,
+            query_lens,
+            lora_mapping,
+            lora_requests,
+            multi_modal_kwargs,
+            slot_mapping,
+            num_prefill_tokens,
+            num_decode_tokens,
+            num_prefills,
+        ) = self._prepare_model_input(seq_group_metadata_list)
+        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
+                                                     seq_lens, query_lens,
+                                                     self.device,
+                                                     self.pin_memory)
+
+        metadata_dict = {
+            "input_tokens": input_tokens,
+            "input_positions": input_positions,
+            "selected_token_indices": sampling_metadata.selected_token_indices,
+            "lora_requests": lora_requests,
+            "lora_mapping": lora_mapping,
+            "multi_modal_kwargs": multi_modal_kwargs,
+            "num_prefill_tokens": num_prefill_tokens,
+            "num_decode_tokens": num_decode_tokens,
+            "slot_mapping": slot_mapping,
+            "num_prefills": num_prefills,
+        }
+        if attn_metadata:
+            metadata_dict.update(attn_metadata.asdict_zerocopy())
+
+        return metadata_dict, [sampling_metadata]
Review comments on convert_broadcast_inputs_to_modelrunner_input:

njhill (Member): Add return typing?

Author: I deleted it to avoid a typing conflict. This function is implemented in embedding_model_runner.py as well, but the return types differ (although both are named ModelRunnerInput, they live in different modules, so mypy complains about it).

Suggestions are welcome for fixing it while making mypy happy.

+    def convert_broadcast_inputs_to_modelrunner_input(
+            self, metadata_dict: Dict[str, Any], aux: Optional[List[Any]]):
+        input_tokens = metadata_dict.pop("input_tokens")
+        input_positions = metadata_dict.pop("input_positions")
+        selected_token_indices = metadata_dict.pop("selected_token_indices")
+        lora_mapping = metadata_dict.pop("lora_mapping")
+        lora_requests = metadata_dict.pop("lora_requests")
+        multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
+        if metadata_dict:
+            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
+        else:
+            attn_metadata = None
+
+        if aux is not None:
+            sampling_metadata = aux[0]
         else:
-            metadata_dict = broadcast_tensor_dict(src=0)
-            input_tokens = metadata_dict.pop("input_tokens")
-            input_positions = metadata_dict.pop("input_positions")
-            selected_token_indices = metadata_dict.pop(
-                "selected_token_indices")
-            lora_mapping = metadata_dict.pop("lora_mapping")
-            lora_requests = metadata_dict.pop("lora_requests")
-            multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
-            if metadata_dict:
-                attn_metadata = self.attn_backend.make_metadata(
-                    **metadata_dict)
-            else:
-                attn_metadata = None
             sampling_metadata = SamplingMetadata(
                 seq_groups=None,
                 selected_token_indices=selected_token_indices,
                 categorized_sample_indices=None,
                 num_prompts=0,
             )
+        return ModelRunnerInput(input_tokens, input_positions, attn_metadata,
+                                sampling_metadata, lora_requests, lora_mapping,
+                                multi_modal_kwargs)

-        return (input_tokens, input_positions, attn_metadata,
-                sampling_metadata, lora_requests, lora_mapping,
-                multi_modal_kwargs)
+    def prepare_modelrunner_input(
+        self,
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+    ):
+        assert seq_group_metadata_list is not None
+        return self.convert_broadcast_inputs_to_modelrunner_input(
+            *self.prepare_inputs_to_broadcast(seq_group_metadata_list))

     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+        modelrunner_input: ModelRunnerInput,
         kv_caches: List[torch.Tensor],
     ) -> Optional[SamplerOutput]:
         (input_tokens, input_positions, attn_metadata, sampling_metadata,
-         lora_requests, lora_mapping, multi_modal_kwargs
-         ) = self.prepare_input_tensors(seq_group_metadata_list)
+         lora_requests, lora_mapping, multi_modal_kwargs) = modelrunner_input

         if self.lora_config:
             self.set_active_loras(lora_requests, lora_mapping)
@@ -830,7 +851,7 @@ def profile_run(self) -> None:
         # Run the model with the dummy inputs.
         num_layers = self.model_config.get_num_layers(self.parallel_config)
         kv_caches = [None] * num_layers
-        self.execute_model(seqs, kv_caches)
+        self.execute_model(self.prepare_modelrunner_input(seqs), kv_caches)
         torch.cuda.synchronize()
         return

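As a closing sketch of how these pieces fit together: under this PR the worker, not each model runner, drives the broadcast. The flow below is an assumed reconstruction from the review discussion and the removed code; the helper name and worker attributes are not the PR's actual worker implementation:

from typing import List, Optional

import torch

from vllm.distributed import broadcast_tensor_dict
from vllm.sequence import SequenceGroupMetadata


def run_model_on_worker(worker,
                        seq_group_metadata_list: Optional[
                            List[SequenceGroupMetadata]],
                        kv_caches: List[torch.Tensor]):
    # Hypothetical worker-side driver. Only the driver rank receives
    # seq_group_metadata_list; the other ranks reconstruct everything
    # they need from the broadcast dict.
    if worker.is_driver_worker:
        metadata_dict, aux = worker.model_runner.prepare_inputs_to_broadcast(
            seq_group_metadata_list)
        broadcast_tensor_dict(metadata_dict, src=0)
    else:
        metadata_dict = broadcast_tensor_dict(src=0)
        aux = None
    # Both driver and non-driver ranks convert the dict (plus driver-only
    # aux data) back into a ModelRunnerInput before model execution.
    modelrunner_input = (
        worker.model_runner.convert_broadcast_inputs_to_modelrunner_input(
            metadata_dict, aux))
    return worker.model_runner.execute_model(modelrunner_input, kv_caches)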