diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index 2804c95d32a4..48b007da664c 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -292,7 +292,7 @@ def post() -> None: ) # accessing non-tensor attributes should not trigger wait. - assert it.kv_connector_output is None + assert it._comm_handles is not None assert work.wait_calls == 0 assert post_calls["n"] == 0 diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index ffb4d7f474bd..56811982b91d 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -39,6 +39,7 @@ KVCacheGroupSpec, KVCacheTensor, ) +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.worker.gpu_input_batch import InputBatch @@ -279,7 +280,8 @@ def receive_prev_sampled_token_ids(): lambda: SimpleNamespace(world_size=world_size, is_last_rank=is_last_rank), ) - assert GPUModelRunner.sample_tokens(runner, None) is None + output = GPUModelRunner.sample_tokens(runner, None) + assert output in (EMPTY_MODEL_RUNNER_OUTPUT, None) assert receive_calls == expected_calls @@ -298,7 +300,8 @@ def test_sample_tokens_skips_pp_group_lookup_without_async_scheduling( pytest.fail, ) - assert GPUModelRunner.sample_tokens(runner, None) is None + output = GPUModelRunner.sample_tokens(runner, None) + assert output in (EMPTY_MODEL_RUNNER_OUTPUT, None) def test_select_common_block_size_no_valid_option(): diff --git a/tests/v1/worker/test_gpu_model_runner_v2_eplb.py b/tests/v1/worker/test_gpu_model_runner_v2_eplb.py index d68ff83c4073..c2a800bd9982 100644 --- a/tests/v1/worker/test_gpu_model_runner_v2_eplb.py +++ b/tests/v1/worker/test_gpu_model_runner_v2_eplb.py @@ -7,6 +7,7 @@ import torch +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT from vllm.v1.worker.gpu import 
eplb_utils as eplb from vllm.v1.worker.gpu import model_runner as mrv2 @@ -76,7 +77,10 @@ def _make_runner(**overrides: Any) -> Any: runner.max_num_reqs = 8 runner.max_num_tokens = 16 runner.decode_query_len = 1 - runner.kv_connector = SimpleNamespace(set_disabled=lambda *_: None) + runner.kv_connector = SimpleNamespace( + set_disabled=lambda *_: None, + post_forward=lambda *_, **__: None, + ) runner.eplb = eplb.EPLBController(runner.parallel_config, runner.device) runner.pooling_runner = None runner.execute_model_state = None @@ -183,5 +187,6 @@ def test_v2_sample_tokens_runs_eplb_on_non_last_pp_rank(monkeypatch): ), ) - assert mrv2.GPUModelRunner.sample_tokens(runner, None) is None + output = mrv2.GPUModelRunner.sample_tokens(runner, None) + assert output in (EMPTY_MODEL_RUNNER_OUTPUT, None) assert events == ["postprocess", "eplb"] diff --git a/vllm/sequence.py b/vllm/sequence.py index 17630623646e..c90531935bb9 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -3,15 +3,9 @@ """Sequence and its related classes.""" from dataclasses import dataclass -from typing import TYPE_CHECKING, Any import torch -if TYPE_CHECKING: - from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput -else: - KVConnectorOutput = Any - # cannot use msgspec.Struct here because Dynamo does not support it @dataclass @@ -19,24 +13,19 @@ class IntermediateTensors: """For all pipeline stages except the last, we need to return the hidden states and residuals to be sent to the next stage. This data structure contains the hidden states and residuals for a request. - - Each stage also needs to handle its own kv_connector_output. """ tensors: dict[str, torch.Tensor] - kv_connector_output: KVConnectorOutput | None def __init__( self, tensors: dict[str, torch.Tensor], - kv_connector_output: KVConnectorOutput | None = None, ) -> None: # manually define this function, so that # Dynamo knows `IntermediateTensors()` comes from this file. 
# Otherwise, dataclass will generate this function by evaluating # a string, and we will lose the information about the source file. self.tensors = tensors - self.kv_connector_output = kv_connector_output def __getitem__(self, key: str | slice): if isinstance(key, str): diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 9703dfa9e70b..9f13ad939fc8 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod +from copy import copy from dataclasses import dataclass, field from typing import TYPE_CHECKING, NamedTuple, TypeAlias @@ -279,6 +280,19 @@ class ModelRunnerOutput: # ``None`` when ``enable_return_routed_experts`` is off. routed_experts: RoutedExpertsLists | None = None + @staticmethod + def with_kv_conn_output_only( + kv_connector_output: KVConnectorOutput | None, + ) -> "ModelRunnerOutput": + """Return ModelRunnerOutput containing the provided KVConnectorOutput, + otherwise empty. Returns EMPTY_MODEL_RUNNER_OUTPUT if it is None or empty. + """ + if kv_connector_output is None or kv_connector_output.is_empty(): + return EMPTY_MODEL_RUNNER_OUTPUT + output = copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.kv_connector_output = kv_connector_output + return output + # ModelRunnerOutput wrapper for async scheduling. 
class AsyncModelRunnerOutput(ABC): diff --git a/vllm/v1/worker/gpu/kv_connector.py b/vllm/v1/worker/gpu/kv_connector.py index 847633109844..cdacb36e5833 100644 --- a/vllm/v1/worker/gpu/kv_connector.py +++ b/vllm/v1/worker/gpu/kv_connector.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import copy from typing import TYPE_CHECKING import torch @@ -103,11 +102,7 @@ def no_forward(self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: self.pre_forward(scheduler_output) finished_req_ids = scheduler_output.finished_req_ids kv_connector_output = self.post_forward(finished_req_ids, wait_for_save=False) - if kv_connector_output is None or kv_connector_output.is_empty(): - return EMPTY_MODEL_RUNNER_OUTPUT - output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) - output.kv_connector_output = kv_connector_output - return output + return ModelRunnerOutput.with_kv_conn_output_only(kv_connector_output) def set_disabled(self, disabled: bool) -> None: # Ensure that layer-wise connector hooks aren't called when disabled. diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index eae47bf4232c..5cba66e5c9f9 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -1218,9 +1218,6 @@ def execute_model( if not self.is_last_pp_rank: # Non-last PP rank: return IntermediateTensors for sending. - assert output_intermediate_tensors is not None - kv_connector_output = self.kv_connector.post_forward(finished_req_ids) - output_intermediate_tensors.kv_connector_output = kv_connector_output return output_intermediate_tensors return None @@ -1249,7 +1246,10 @@ def sample_tokens( input_batch.num_reqs, max_sample_len=self.num_speculative_steps + 1 ) self.postprocess(input_batch, sampled, num_sampled, num_rejected) - return None + + # Post-step KV connector related operations. 
+ kv_connector_output = self.kv_connector.post_forward(finished_req_ids) + return ModelRunnerOutput.with_kv_conn_output_only(kv_connector_output) # Last rank: sample tokens sampler_output, num_sampled, num_rejected = self.sample( @@ -1367,18 +1367,18 @@ def pool(self) -> AsyncPoolingOutput | ModelRunnerOutput | None: finished_req_ids = self.execute_model_state.finished_req_ids self.execute_model_state = None + # Post-step KV connector related operations. + kv_connector_output = self.kv_connector.post_forward(finished_req_ids) + if not self.is_last_pp_rank: self.postprocess_pool(input_batch) - return None + return ModelRunnerOutput.with_kv_conn_output_only(kv_connector_output) assert self.pooling_runner is not None pooler_output, is_valid = self.pooling_runner.pool( hidden_states, input_batch, self.req_states ) - # Post-step KV connector related operations. - kv_connector_output = self.kv_connector.post_forward(finished_req_ids) - # Build the model runner output. model_runner_output = ModelRunnerOutput( req_ids=input_batch.req_ids, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 890abd53face..779d73921e83 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4250,7 +4250,6 @@ def execute_model( if not get_pp_group().is_last_rank: # Return the intermediate tensors. assert isinstance(hidden_states, IntermediateTensors) - hidden_states.kv_connector_output = kv_connector_output self.kv_connector_output = kv_connector_output return hidden_states @@ -4326,17 +4325,9 @@ def sample_tokens( # receive sampled token ids from the last PP rank. 
if self.use_async_scheduling and not get_pp_group().is_last_rank: self._pp_receive_prev_sampled_token_ids_to_input_batch() - if not kv_connector_output: - return None # type: ignore[return-value] - # In case of PP with kv transfer, we need to pass through the # kv_connector_output - if kv_connector_output.is_empty(): - return EMPTY_MODEL_RUNNER_OUTPUT - - output = copy(EMPTY_MODEL_RUNNER_OUTPUT) - output.kv_connector_output = kv_connector_output - return output + return ModelRunnerOutput.with_kv_conn_output_only(kv_connector_output) # Unpack ephemeral state. ( diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index 4fc1aff94fed..797e59c02909 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -4,7 +4,6 @@ Define KV connector functionality mixin for model runners. """ -import copy from collections.abc import Generator from contextlib import AbstractContextManager, contextmanager, nullcontext from typing import TYPE_CHECKING @@ -20,7 +19,6 @@ from vllm.v1.attention.backend import AttentionBackend from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig from vllm.v1.outputs import ( - EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput, ModelRunnerOutput, ) @@ -47,12 +45,7 @@ def kv_connector_no_forward( ): pass - if kv_connector_output.is_empty(): - return EMPTY_MODEL_RUNNER_OUTPUT - - output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) - output.kv_connector_output = kv_connector_output - return output + return ModelRunnerOutput.with_kv_conn_output_only(kv_connector_output) @staticmethod def maybe_get_kv_connector_output(