diff --git a/examples/offline_inference/kv_load_failure_recovery/decode_example.py b/examples/offline_inference/kv_load_failure_recovery/decode_example.py index 69523f56eace..805502275197 100644 --- a/examples/offline_inference/kv_load_failure_recovery/decode_example.py +++ b/examples/offline_inference/kv_load_failure_recovery/decode_example.py @@ -77,8 +77,8 @@ def main(): out_str = f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}" print(out_str) print(sep_str) - f.write(out_str) - f.write(sep_str) + print(out_str, file=f) + print(sep_str, file=f) if __name__ == "__main__": diff --git a/examples/offline_inference/kv_load_failure_recovery/rogue_shared_storage_connector.py b/examples/offline_inference/kv_load_failure_recovery/rogue_shared_storage_connector.py index 5b2acea4c945..417cac26cfb2 100644 --- a/examples/offline_inference/kv_load_failure_recovery/rogue_shared_storage_connector.py +++ b/examples/offline_inference/kv_load_failure_recovery/rogue_shared_storage_connector.py @@ -75,7 +75,10 @@ def clear_connector_metadata(self) -> None: def start_load_kv(self, forward_context: ForwardContext, **kwargs) -> None: if self._async_load and forward_context.attn_metadata is None: # Bypass sanity check in super().start_load_kv - forward_context.attn_metadata = "None" + forward_context.attn_metadata = {} + super().start_load_kv(forward_context, **kwargs) + forward_context.attn_metadata = None + return super().start_load_kv(forward_context, **kwargs) diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py index 6748532afd97..63dcb5feb197 100644 --- a/tests/v1/kv_connector/unit/test_multi_connector.py +++ b/tests/v1/kv_connector/unit/test_multi_connector.py @@ -162,8 +162,23 @@ def test_multi_shared_storage_connector_consistency(): "update_state_after_alloc num_blocks=[0] 0", "build_connector_meta", ] - assert events["storage1-WORKER"][:5] == [ - "register_kv_caches", + register_kv_caches_event_idx = events["storage1-WORKER"].index("register_kv_caches") + last_dummy_run_start_event_idx = find_last_index( + events["storage1-WORKER"], "build_connector_meta" + ) + assert last_dummy_run_start_event_idx != -1 + last_dummy_run_end_event_idx = last_dummy_run_start_event_idx + events[ + "storage1-WORKER" + ][last_dummy_run_start_event_idx:].index("clear_connector_metadata") + first_run_start_event_idx = last_dummy_run_end_event_idx + events[ + "storage1-WORKER" + ][last_dummy_run_end_event_idx:].index("bind_connector_metadata") + assert ( + events["storage1-WORKER"][register_kv_caches_event_idx] == "register_kv_caches" + ) + assert events["storage1-WORKER"][ + first_run_start_event_idx : first_run_start_event_idx + 4 + ] == [ "bind_connector_metadata", "start_load_kv", "wait_for_layer_load", @@ -174,8 +189,12 @@ def test_multi_shared_storage_connector_consistency(): "update_state_after_alloc num_blocks=[0] 0", "build_connector_meta", ] - assert events["storage2-WORKER"][:5] == [ - "register_kv_caches", + assert ( + events["storage2-WORKER"][register_kv_caches_event_idx] == "register_kv_caches" + ) + assert events["storage2-WORKER"][ + first_run_start_event_idx : first_run_start_event_idx + 4 + ] == [ "bind_connector_metadata", "start_load_kv", "wait_for_layer_load", @@ -254,6 +273,13 @@ def get_connector_events() -> dict[str, list[str]]: return connector_events +def find_last_index(lst, query): + for i in range(len(lst) - 1, -1, -1): + if lst[i] == query: + return i + return -1 + + def test_engine_id_conflict(): configs = [KVTransferConfig() for _ in range(2)] ids = [config.engine_id for config in configs] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py index ca251cd0c6eb..81d75254ad24 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py @@ -146,8 +146,11 @@ def update_state_after_alloc( def build_connector_meta( self, scheduler_output: "SchedulerOutput" ) -> KVConnectorMetadata: - assert self.connector_scheduler is not None - return self.connector_scheduler.build_connector_meta(scheduler_output) + return ( + self.connector_scheduler.build_connector_meta(scheduler_output) + if self.connector_scheduler is not None + else DecodeBenchConnectorMetadata(reqs_to_fill={}) + ) def request_finished( self, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 72fcb5cd5bb7..93da1b2b60ed 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -207,8 +207,11 @@ def build_connector_meta( self, scheduler_output: SchedulerOutput, ) -> KVConnectorMetadata: - assert self.connector_scheduler is not None - return self.connector_scheduler.build_connector_meta(scheduler_output) + return ( + self.connector_scheduler.build_connector_meta(scheduler_output) + if self.connector_scheduler is not None + else NixlConnectorMetadata() + ) def request_finished( self, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py index 6d4ffc152de9..2a3b998c7ced 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py @@ -102,8 +102,11 @@ def update_state_after_alloc( def build_connector_meta( self, scheduler_output: SchedulerOutput ) -> KVConnectorMetadata: - assert self.connector_scheduler is not None - return self.connector_scheduler.build_connector_meta(scheduler_output) + return ( + self.connector_scheduler.build_connector_meta(scheduler_output) + if self.connector_scheduler is not None + else OffloadingConnectorMetadata(reqs_to_load={}, reqs_to_store={}) + ) def update_connector_output(self, connector_output: KVConnectorOutput): assert self.connector_scheduler is not None diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e350988456f1..295872de127b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3471,6 +3471,7 @@ def _dummy_run( batch_descriptor=batch_descriptor, ubatch_slices=ubatch_slices, ), + self.maybe_get_kv_connector_dummy_run_output(), ): outputs = self.model( input_ids=input_ids, diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index db037a9fccd5..1b8ea240f5ac 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -7,9 +7,7 @@ import copy from collections.abc import Generator from contextlib import AbstractContextManager, contextmanager, nullcontext -from typing import ( - TYPE_CHECKING, # noqa: UP035 -) +from typing import Final from vllm.config import VllmConfig from vllm.distributed.kv_transfer import ( @@ -21,17 +19,29 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.forward_context import get_forward_context, set_forward_context from vllm.logger import init_logger +from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.outputs import ( EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput, ModelRunnerOutput, ) -if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput - logger = init_logger(__name__) +_EMPTY_SCHEDULER_OUTPUT: Final[SchedulerOutput] = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), + num_scheduled_tokens={}, + total_num_scheduled_tokens=0, + scheduled_spec_decode_tokens={}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[], + finished_req_ids=set(), + free_encoder_mm_hashes=[], + structured_output_request_ids=[], + grammar_bitmask=None, +) + # Defined as a kv connector functionality mixin for ModelRunner (GPU, TPU) class KVConnectorModelRunnerMixin: @@ -73,7 +83,7 @@ def get_finished_kv_transfers( @staticmethod def kv_connector_no_forward( - scheduler_output: "SchedulerOutput", vllm_config: VllmConfig + scheduler_output: SchedulerOutput, vllm_config: VllmConfig ) -> ModelRunnerOutput: # KV send/recv even if no work to do. with ( @@ -93,7 +103,7 @@ def kv_connector_no_forward( @staticmethod def maybe_get_kv_connector_output( - scheduler_output: "SchedulerOutput", + scheduler_output: SchedulerOutput, ) -> AbstractContextManager[KVConnectorOutput | None]: return ( KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output) @@ -101,6 +111,21 @@ def maybe_get_kv_connector_output( else nullcontext() ) + @staticmethod + def maybe_get_kv_connector_dummy_run_output() -> AbstractContextManager[ + KVConnectorOutput | None + ]: + global _EMPTY_SCHEDULER_OUTPUT + if has_kv_transfer_group(): + kv_connector = get_kv_transfer_group() + meta = kv_connector.build_connector_meta(_EMPTY_SCHEDULER_OUTPUT) + _EMPTY_SCHEDULER_OUTPUT.kv_connector_metadata = meta + return KVConnectorModelRunnerMixin._get_kv_connector_output( + _EMPTY_SCHEDULER_OUTPUT + ) + + return nullcontext() + # This context manager must be used within an active forward context. # It encapsulates the entire KV connector lifecycle within execute_model @staticmethod