72 changes: 38 additions & 34 deletions vllm/v1/worker/gpu/model_runner.py
@@ -21,6 +21,7 @@
import gc
import time
from copy import deepcopy
from typing import Any, NamedTuple

import numpy as np
import torch
@@ -44,7 +45,7 @@
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
from vllm.v1.worker.gpu.async_utils import AsyncOutput, AsyncPoolingOutput
from vllm.v1.worker.gpu.attn_utils import (
@@ -213,7 +214,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.pooling_runner: PoolingRunner | None = None

        # For transferring state from execute_model to the subsequent sample_tokens call.
self.execute_model_state: tuple | None = None
self.execute_model_state: ExecuteModelState | None = None

def update_max_model_len(self, max_model_len: int) -> None:
self.max_model_len = max_model_len
@@ -375,16 +376,12 @@ def _dummy_run(
return None, None

assert self.execute_model_state is not None
(
input_batch,
model_inputs,
attn_metadata,
slot_mappings_by_layer,
hidden_states,
aux_hidden_states,
kv_connector_output,
num_tokens_across_dp,
) = self.execute_model_state
input_batch = self.execute_model_state.input_batch
attn_metadata = self.execute_model_state.attn_metadata
slot_mappings_by_layer = self.execute_model_state.slot_mappings_by_layer
hidden_states = self.execute_model_state.hidden_states
aux_hidden_states = self.execute_model_state.aux_hidden_states
num_tokens_across_dp = self.execute_model_state.num_tokens_across_dp
self.execute_model_state = None

# dummy run the eagle speculator's propose to ensure DP/EP sync.
@@ -989,15 +986,14 @@ def execute_model(
aux_hidden_states = None

kv_connector_output = self.kv_connector.post_forward(scheduler_output)
self.execute_model_state = (
input_batch,
model_inputs,
attn_metadata,
slot_mappings_by_layer,
hidden_states,
aux_hidden_states,
kv_connector_output,
num_tokens_across_dp,
self.execute_model_state = ExecuteModelState(
input_batch=input_batch,
attn_metadata=attn_metadata,
slot_mappings_by_layer=slot_mappings_by_layer,
hidden_states=hidden_states,
aux_hidden_states=aux_hidden_states,
kv_connector_output=kv_connector_output,
num_tokens_across_dp=num_tokens_across_dp,
)

if not self.is_last_pp_rank:
@@ -1016,16 +1012,14 @@ def sample_tokens(
if self.execute_model_state is None:
# The prior execute_model call must have failed.
return None
(
input_batch,
model_inputs,
attn_metadata,
slot_mappings_by_layer,
hidden_states,
aux_hidden_states,
kv_connector_output,
num_tokens_across_dp,
) = self.execute_model_state

input_batch = self.execute_model_state.input_batch
attn_metadata = self.execute_model_state.attn_metadata
slot_mappings_by_layer = self.execute_model_state.slot_mappings_by_layer
hidden_states = self.execute_model_state.hidden_states
aux_hidden_states = self.execute_model_state.aux_hidden_states
kv_connector_output = self.execute_model_state.kv_connector_output
num_tokens_across_dp = self.execute_model_state.num_tokens_across_dp
self.execute_model_state = None

if not self.is_last_pp_rank:
@@ -1116,9 +1110,9 @@ def pool(self) -> AsyncPoolingOutput | ModelRunnerOutput | None:
# The prior execute_model call must have failed.
return None

input_batch, _, _, _, hidden_states, _, kv_connector_output, _ = (
self.execute_model_state
)
input_batch = self.execute_model_state.input_batch
hidden_states = self.execute_model_state.hidden_states
kv_connector_output = self.execute_model_state.kv_connector_output
self.execute_model_state = None

if not self.is_last_pp_rank:
@@ -1164,3 +1158,13 @@ def postprocess_pool(self, input_batch: InputBatch) -> None:
np.minimum(
computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
)


class ExecuteModelState(NamedTuple):
input_batch: InputBatch
attn_metadata: dict[str, Any] | None
slot_mappings_by_layer: dict[str, torch.Tensor] | None
hidden_states: torch.Tensor | IntermediateTensors
aux_hidden_states: list[torch.Tensor] | None
kv_connector_output: KVConnectorOutput | None
num_tokens_across_dp: torch.Tensor | None
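
The net effect of the diff is that the anonymous 8-tuple previously stashed in self.execute_model_state is replaced by the ExecuteModelState NamedTuple above, so execute_model, sample_tokens, and pool read only the fields they need by name instead of positionally unpacking (and discarding) every element. Below is a minimal standalone sketch of that hand-off pattern, not code from the PR; StateSketch and its fields are hypothetical stand-ins for the real InputBatch, attention metadata, and KVConnectorOutput types.

from typing import Any, NamedTuple


class StateSketch(NamedTuple):
    # Hypothetical stand-ins for the fields carried by ExecuteModelState.
    input_batch: Any
    hidden_states: Any
    kv_connector_output: Any | None = None


def run_model() -> StateSketch:
    # execute_model-style producer: build the state once with named fields.
    return StateSketch(input_batch="batch-0", hidden_states=[0.1, 0.2])


def sample(state: StateSketch) -> None:
    # sample_tokens-style consumer: read only the fields it needs by name,
    # rather than unpacking eight positional slots and ignoring most of them.
    print(state.input_batch, state.hidden_states, state.kv_connector_output)


if __name__ == "__main__":
    sample(run_model())

Because a NamedTuple instance is still a tuple, this kind of change stays drop-in compatible with any remaining positional consumers, while new call sites get named, type-annotated access.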