diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index c26fe9d6793d..30ab27d190ae 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -21,6 +21,7 @@ import gc import time from copy import deepcopy +from typing import Any, NamedTuple import numpy as np import torch @@ -44,7 +45,7 @@ from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig -from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput +from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput from vllm.v1.worker.cp_utils import check_attention_cp_compatibility from vllm.v1.worker.gpu.async_utils import AsyncOutput, AsyncPoolingOutput from vllm.v1.worker.gpu.attn_utils import ( @@ -213,7 +214,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.pooling_runner: PoolingRunner | None = None # For transferring state from execute_model to subsequent sample_tokens call. - self.execute_model_state: tuple | None = None + self.execute_model_state: ExecuteModelState | None = None def update_max_model_len(self, max_model_len: int) -> None: self.max_model_len = max_model_len @@ -375,16 +376,12 @@ def _dummy_run( return None, None assert self.execute_model_state is not None - ( - input_batch, - model_inputs, - attn_metadata, - slot_mappings_by_layer, - hidden_states, - aux_hidden_states, - kv_connector_output, - num_tokens_across_dp, - ) = self.execute_model_state + input_batch = self.execute_model_state.input_batch + attn_metadata = self.execute_model_state.attn_metadata + slot_mappings_by_layer = self.execute_model_state.slot_mappings_by_layer + hidden_states = self.execute_model_state.hidden_states + aux_hidden_states = self.execute_model_state.aux_hidden_states + num_tokens_across_dp = self.execute_model_state.num_tokens_across_dp self.execute_model_state = None # dummy run the eagle speculator's propose to ensure DP/EP sync. @@ -989,15 +986,14 @@ def execute_model( aux_hidden_states = None kv_connector_output = self.kv_connector.post_forward(scheduler_output) - self.execute_model_state = ( - input_batch, - model_inputs, - attn_metadata, - slot_mappings_by_layer, - hidden_states, - aux_hidden_states, - kv_connector_output, - num_tokens_across_dp, + self.execute_model_state = ExecuteModelState( + input_batch=input_batch, + attn_metadata=attn_metadata, + slot_mappings_by_layer=slot_mappings_by_layer, + hidden_states=hidden_states, + aux_hidden_states=aux_hidden_states, + kv_connector_output=kv_connector_output, + num_tokens_across_dp=num_tokens_across_dp, ) if not self.is_last_pp_rank: @@ -1016,16 +1012,14 @@ def sample_tokens( if self.execute_model_state is None: # The prior execute_model call must have failed. return None - ( - input_batch, - model_inputs, - attn_metadata, - slot_mappings_by_layer, - hidden_states, - aux_hidden_states, - kv_connector_output, - num_tokens_across_dp, - ) = self.execute_model_state + + input_batch = self.execute_model_state.input_batch + attn_metadata = self.execute_model_state.attn_metadata + slot_mappings_by_layer = self.execute_model_state.slot_mappings_by_layer + hidden_states = self.execute_model_state.hidden_states + aux_hidden_states = self.execute_model_state.aux_hidden_states + kv_connector_output = self.execute_model_state.kv_connector_output + num_tokens_across_dp = self.execute_model_state.num_tokens_across_dp self.execute_model_state = None if not self.is_last_pp_rank: @@ -1116,9 +1110,9 @@ def pool(self) -> AsyncPoolingOutput | ModelRunnerOutput | None: # The prior execute_model call must have failed. return None - input_batch, _, _, _, hidden_states, _, kv_connector_output, _ = ( - self.execute_model_state - ) + input_batch = self.execute_model_state.input_batch + hidden_states = self.execute_model_state.hidden_states + kv_connector_output = self.execute_model_state.kv_connector_output self.execute_model_state = None if not self.is_last_pp_rank: @@ -1164,3 +1158,13 @@ def postprocess_pool(self, input_batch: InputBatch) -> None: np.minimum( computed_prefill, self.req_states.prefill_len.np, out=computed_prefill ) + + +class ExecuteModelState(NamedTuple): + input_batch: InputBatch + attn_metadata: dict[str, Any] | None + slot_mappings_by_layer: dict[str, torch.Tensor] | None + hidden_states: torch.Tensor | IntermediateTensors + aux_hidden_states: list[torch.Tensor] | None + kv_connector_output: KVConnectorOutput | None + num_tokens_across_dp: torch.Tensor | None