Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/distributed/test_comm_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def post() -> None:
)

# accessing non-tensor attributes should not trigger wait.
assert it.kv_connector_output is None
assert it._comm_handles is not None
assert work.wait_calls == 0
assert post_calls["n"] == 0

Expand Down
7 changes: 5 additions & 2 deletions tests/v1/worker/test_gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
KVCacheGroupSpec,
KVCacheTensor,
)
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.gpu_input_batch import InputBatch
Expand Down Expand Up @@ -279,7 +280,8 @@ def receive_prev_sampled_token_ids():
lambda: SimpleNamespace(world_size=world_size, is_last_rank=is_last_rank),
)

assert GPUModelRunner.sample_tokens(runner, None) is None
output = GPUModelRunner.sample_tokens(runner, None)
assert output in (EMPTY_MODEL_RUNNER_OUTPUT, None)
assert receive_calls == expected_calls


Expand All @@ -298,7 +300,8 @@ def test_sample_tokens_skips_pp_group_lookup_without_async_scheduling(
pytest.fail,
)

assert GPUModelRunner.sample_tokens(runner, None) is None
output = GPUModelRunner.sample_tokens(runner, None)
assert output in (EMPTY_MODEL_RUNNER_OUTPUT, None)


def test_select_common_block_size_no_valid_option():
Expand Down
9 changes: 7 additions & 2 deletions tests/v1/worker/test_gpu_model_runner_v2_eplb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import torch

from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
from vllm.v1.worker.gpu import eplb_utils as eplb
from vllm.v1.worker.gpu import model_runner as mrv2

Expand Down Expand Up @@ -76,7 +77,10 @@ def _make_runner(**overrides: Any) -> Any:
runner.max_num_reqs = 8
runner.max_num_tokens = 16
runner.decode_query_len = 1
runner.kv_connector = SimpleNamespace(set_disabled=lambda *_: None)
runner.kv_connector = SimpleNamespace(
set_disabled=lambda *_: None,
post_forward=lambda *_, **__: None,
)
runner.eplb = eplb.EPLBController(runner.parallel_config, runner.device)
runner.pooling_runner = None
runner.execute_model_state = None
Expand Down Expand Up @@ -183,5 +187,6 @@ def test_v2_sample_tokens_runs_eplb_on_non_last_pp_rank(monkeypatch):
),
)

assert mrv2.GPUModelRunner.sample_tokens(runner, None) is None
output = mrv2.GPUModelRunner.sample_tokens(runner, None)
assert output in (EMPTY_MODEL_RUNNER_OUTPUT, None)
assert events == ["postprocess", "eplb"]
11 changes: 0 additions & 11 deletions vllm/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,29 @@
"""Sequence and its related classes."""

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import torch

if TYPE_CHECKING:
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
else:
KVConnectorOutput = Any


# cannot use msgspec.Struct here because Dynamo does not support it
@dataclass
class IntermediateTensors:
"""For all pipeline stages except the last, we need to return the hidden
states and residuals to be sent to the next stage. This data structure
contains the hidden states and residuals for a request.

Each stage also needs to handle its own kv_connector_output.
"""

tensors: dict[str, torch.Tensor]
kv_connector_output: KVConnectorOutput | None

def __init__(
self,
tensors: dict[str, torch.Tensor],
kv_connector_output: KVConnectorOutput | None = None,
) -> None:
# manually define this function, so that
# Dynamo knows `IntermediateTensors()` comes from this file.
# Otherwise, dataclass will generate this function by evaluating
# a string, and we will lose the information about the source file.
self.tensors = tensors
self.kv_connector_output = kv_connector_output

def __getitem__(self, key: str | slice):
if isinstance(key, str):
Expand Down
14 changes: 14 additions & 0 deletions vllm/v1/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from copy import copy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, NamedTuple, TypeAlias

Expand Down Expand Up @@ -279,6 +280,19 @@ class ModelRunnerOutput:
# ``None`` when ``enable_return_routed_experts`` is off.
routed_experts: RoutedExpertsLists | None = None

@staticmethod
def with_kv_conn_output_only(
kv_connector_output: KVConnectorOutput | None,
) -> "ModelRunnerOutput":
"""Return ModelRunnerOutput containing the provided KVConnectorOutput,
otherwise empty. Returns None if kv_connector_output is passed as None.
"""
if kv_connector_output is None or kv_connector_output.is_empty():
return EMPTY_MODEL_RUNNER_OUTPUT
output = copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.kv_connector_output = kv_connector_output
return output


# ModelRunnerOutput wrapper for async scheduling.
class AsyncModelRunnerOutput(ABC):
Expand Down
7 changes: 1 addition & 6 deletions vllm/v1/worker/gpu/kv_connector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from typing import TYPE_CHECKING

import torch
Expand Down Expand Up @@ -103,11 +102,7 @@ def no_forward(self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
self.pre_forward(scheduler_output)
finished_req_ids = scheduler_output.finished_req_ids
kv_connector_output = self.post_forward(finished_req_ids, wait_for_save=False)
if kv_connector_output is None or kv_connector_output.is_empty():
return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.kv_connector_output = kv_connector_output
return output
return ModelRunnerOutput.with_kv_conn_output_only(kv_connector_output)

def set_disabled(self, disabled: bool) -> None:
# Ensure that layer-wise connector hooks aren't called when disabled.
Expand Down
16 changes: 8 additions & 8 deletions vllm/v1/worker/gpu/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1218,9 +1218,6 @@ def execute_model(

if not self.is_last_pp_rank:
# Non-last PP rank: return IntermediateTensors for sending.
assert output_intermediate_tensors is not None
kv_connector_output = self.kv_connector.post_forward(finished_req_ids)
output_intermediate_tensors.kv_connector_output = kv_connector_output
return output_intermediate_tensors
return None

Expand Down Expand Up @@ -1249,7 +1246,10 @@ def sample_tokens(
input_batch.num_reqs, max_sample_len=self.num_speculative_steps + 1
)
self.postprocess(input_batch, sampled, num_sampled, num_rejected)
return None

# Post-step KV connector related operations.
kv_connector_output = self.kv_connector.post_forward(finished_req_ids)
return ModelRunnerOutput.with_kv_conn_output_only(kv_connector_output)

# Last rank: sample tokens
sampler_output, num_sampled, num_rejected = self.sample(
Expand Down Expand Up @@ -1367,18 +1367,18 @@ def pool(self) -> AsyncPoolingOutput | ModelRunnerOutput | None:
finished_req_ids = self.execute_model_state.finished_req_ids
self.execute_model_state = None

# Post-step KV connector related operations.
kv_connector_output = self.kv_connector.post_forward(finished_req_ids)

if not self.is_last_pp_rank:
self.postprocess_pool(input_batch)
return None
return ModelRunnerOutput.with_kv_conn_output_only(kv_connector_output)

assert self.pooling_runner is not None
pooler_output, is_valid = self.pooling_runner.pool(
hidden_states, input_batch, self.req_states
)

# Post-step KV connector related operations.
kv_connector_output = self.kv_connector.post_forward(finished_req_ids)

# Build the model runner output.
model_runner_output = ModelRunnerOutput(
req_ids=input_batch.req_ids,
Expand Down
11 changes: 1 addition & 10 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4250,7 +4250,6 @@ def execute_model(
if not get_pp_group().is_last_rank:
# Return the intermediate tensors.
assert isinstance(hidden_states, IntermediateTensors)
hidden_states.kv_connector_output = kv_connector_output
self.kv_connector_output = kv_connector_output
return hidden_states

Expand Down Expand Up @@ -4326,17 +4325,9 @@ def sample_tokens(
# receive sampled token ids from the last PP rank.
if self.use_async_scheduling and not get_pp_group().is_last_rank:
self._pp_receive_prev_sampled_token_ids_to_input_batch()
if not kv_connector_output:
return None # type: ignore[return-value]

# In case of PP with kv transfer, we need to pass through the
# kv_connector_output
if kv_connector_output.is_empty():
return EMPTY_MODEL_RUNNER_OUTPUT

output = copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.kv_connector_output = kv_connector_output
return output
return ModelRunnerOutput.with_kv_conn_output_only(kv_connector_output)

# Unpack ephemeral state.
(
Expand Down
9 changes: 1 addition & 8 deletions vllm/v1/worker/kv_connector_model_runner_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
Define KV connector functionality mixin for model runners.
"""

import copy
from collections.abc import Generator
from contextlib import AbstractContextManager, contextmanager, nullcontext
from typing import TYPE_CHECKING
Expand All @@ -20,7 +19,6 @@
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig
from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
KVConnectorOutput,
ModelRunnerOutput,
)
Expand All @@ -47,12 +45,7 @@ def kv_connector_no_forward(
):
pass

if kv_connector_output.is_empty():
return EMPTY_MODEL_RUNNER_OUTPUT

output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.kv_connector_output = kv_connector_output
return output
return ModelRunnerOutput.with_kv_conn_output_only(kv_connector_output)

@staticmethod
def maybe_get_kv_connector_output(
Expand Down
Loading