Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def main():
out_str = f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}"
print(out_str)
print(sep_str)
f.write(out_str)
f.write(sep_str)
print(out_str, file=f)
print(sep_str, file=f)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ def clear_connector_metadata(self) -> None:
def start_load_kv(self, forward_context: ForwardContext, **kwargs) -> None:
    """Start loading KV data for this forward pass.

    When async loading is enabled and no attention metadata is present
    (e.g. a dummy/warm-up run), temporarily substitute an empty dict so
    the sanity check in super().start_load_kv() passes, then restore
    None afterwards.
    """
    if self._async_load and forward_context.attn_metadata is None:
        # Bypass sanity check in super().start_load_kv with a non-None
        # (but empty) placeholder; restored to None right after.
        forward_context.attn_metadata = {}
        super().start_load_kv(forward_context, **kwargs)
        forward_context.attn_metadata = None
        return

    super().start_load_kv(forward_context, **kwargs)

Expand Down
34 changes: 30 additions & 4 deletions tests/v1/kv_connector/unit/test_multi_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,23 @@ def test_multi_shared_storage_connector_consistency():
"update_state_after_alloc num_blocks=[0] 0",
"build_connector_meta",
]
assert events["storage1-WORKER"][:5] == [
"register_kv_caches",
register_kv_caches_event_idx = events["storage1-WORKER"].index("register_kv_caches")
last_dummy_run_start_event_idx = find_last_index(
events["storage1-WORKER"], "build_connector_meta"
)
assert last_dummy_run_start_event_idx != -1
last_dummy_run_end_event_idx = last_dummy_run_start_event_idx + events[
"storage1-WORKER"
][last_dummy_run_start_event_idx:].index("clear_connector_metadata")
first_run_start_event_idx = last_dummy_run_end_event_idx + events[
"storage1-WORKER"
][last_dummy_run_end_event_idx:].index("bind_connector_metadata")
assert (
events["storage1-WORKER"][register_kv_caches_event_idx] == "register_kv_caches"
)
assert events["storage1-WORKER"][
first_run_start_event_idx : first_run_start_event_idx + 4
] == [
"bind_connector_metadata",
"start_load_kv",
"wait_for_layer_load",
Expand All @@ -174,8 +189,12 @@ def test_multi_shared_storage_connector_consistency():
"update_state_after_alloc num_blocks=[0] 0",
"build_connector_meta",
]
assert events["storage2-WORKER"][:5] == [
"register_kv_caches",
assert (
events["storage2-WORKER"][register_kv_caches_event_idx] == "register_kv_caches"
)
assert events["storage2-WORKER"][
first_run_start_event_idx : first_run_start_event_idx + 4
] == [
"bind_connector_metadata",
"start_load_kv",
"wait_for_layer_load",
Expand Down Expand Up @@ -254,6 +273,13 @@ def get_connector_events() -> dict[str, list[str]]:
return connector_events


def find_last_index(lst, query):
    """Return the index of the last occurrence of *query* in *lst*, or -1."""
    for offset, item in enumerate(reversed(lst)):
        if item == query:
            return len(lst) - 1 - offset
    return -1


def test_engine_id_conflict():
configs = [KVTransferConfig() for _ in range(2)]
ids = [config.engine_id for config in configs]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,11 @@ def update_state_after_alloc(
def build_connector_meta(
    self, scheduler_output: "SchedulerOutput"
) -> KVConnectorMetadata:
    """Build connector metadata for the given scheduler output.

    Returns empty metadata (no requests to fill) when this connector has
    no scheduler-side component (self.connector_scheduler is None).
    """
    return (
        self.connector_scheduler.build_connector_meta(scheduler_output)
        if self.connector_scheduler is not None
        else DecodeBenchConnectorMetadata(reqs_to_fill={})
    )

def request_finished(
self,
Expand Down
Original file line number Diff line number Diff line change
def build_connector_meta(
    self,
    scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
    """Build connector metadata for the given scheduler output.

    Returns empty NixlConnectorMetadata when this connector has no
    scheduler-side component (self.connector_scheduler is None).
    """
    return (
        self.connector_scheduler.build_connector_meta(scheduler_output)
        if self.connector_scheduler is not None
        else NixlConnectorMetadata()
    )

def request_finished(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,11 @@ def update_state_after_alloc(
def build_connector_meta(
    self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
    """Build connector metadata for the given scheduler output.

    Returns empty metadata (nothing to load or store) when this
    connector has no scheduler-side component
    (self.connector_scheduler is None).
    """
    return (
        self.connector_scheduler.build_connector_meta(scheduler_output)
        if self.connector_scheduler is not None
        else OffloadingConnectorMetadata(reqs_to_load={}, reqs_to_store={})
    )

def update_connector_output(self, connector_output: KVConnectorOutput):
assert self.connector_scheduler is not None
Expand Down
1 change: 1 addition & 0 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3471,6 +3471,7 @@ def _dummy_run(
batch_descriptor=batch_descriptor,
ubatch_slices=ubatch_slices,
),
self.maybe_get_kv_connector_dummy_run_output(),
):
outputs = self.model(
input_ids=input_ids,
Expand Down
41 changes: 33 additions & 8 deletions vllm/v1/worker/kv_connector_model_runner_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
import copy
from collections.abc import Generator
from contextlib import AbstractContextManager, contextmanager, nullcontext
from typing import (
TYPE_CHECKING, # noqa: UP035
)
from typing import Final

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer import (
Expand All @@ -21,17 +19,29 @@
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
KVConnectorOutput,
ModelRunnerOutput,
)

if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput

logger = init_logger(__name__)

# Shared empty SchedulerOutput used for KV-connector dummy runs; its
# kv_connector_metadata field is (re)assigned before each dummy run, so it
# is deliberately mutable despite the Final binding (Final only prevents
# rebinding the name, not attribute mutation).
_EMPTY_SCHEDULER_OUTPUT: Final[SchedulerOutput] = SchedulerOutput(
    scheduled_new_reqs=[],
    scheduled_cached_reqs=CachedRequestData.make_empty(),
    num_scheduled_tokens={},
    total_num_scheduled_tokens=0,
    scheduled_spec_decode_tokens={},
    scheduled_encoder_inputs={},
    num_common_prefix_blocks=[],
    finished_req_ids=set(),
    free_encoder_mm_hashes=[],
    structured_output_request_ids=[],
    grammar_bitmask=None,
)


# Defined as a kv connector functionality mixin for ModelRunner (GPU, TPU)
class KVConnectorModelRunnerMixin:
Expand Down Expand Up @@ -73,7 +83,7 @@ def get_finished_kv_transfers(

@staticmethod
def kv_connector_no_forward(
scheduler_output: "SchedulerOutput", vllm_config: VllmConfig
scheduler_output: SchedulerOutput, vllm_config: VllmConfig
) -> ModelRunnerOutput:
# KV send/recv even if no work to do.
with (
Expand All @@ -93,14 +103,29 @@ def kv_connector_no_forward(

@staticmethod
def maybe_get_kv_connector_output(
    scheduler_output: SchedulerOutput,
) -> AbstractContextManager[KVConnectorOutput | None]:
    """Return a context manager yielding the KV connector output for this
    step, or a nullcontext() when no KV transfer group is configured."""
    return (
        KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
        if has_kv_transfer_group()
        else nullcontext()
    )

@staticmethod
def maybe_get_kv_connector_dummy_run_output() -> AbstractContextManager[
    KVConnectorOutput | None
]:
    """KV-connector output context for dummy runs (no real scheduler step).

    Builds connector metadata from the shared empty SchedulerOutput so
    the connector lifecycle still runs during dummy runs, and returns a
    nullcontext() when no KV transfer group is configured.
    """
    # NOTE: no `global` statement needed — _EMPTY_SCHEDULER_OUTPUT is only
    # attribute-mutated, never rebound.
    if has_kv_transfer_group():
        kv_connector = get_kv_transfer_group()
        # Attach freshly-built metadata to the shared empty scheduler
        # output before entering the connector output context.
        meta = kv_connector.build_connector_meta(_EMPTY_SCHEDULER_OUTPUT)
        _EMPTY_SCHEDULER_OUTPUT.kv_connector_metadata = meta
        return KVConnectorModelRunnerMixin._get_kv_connector_output(
            _EMPTY_SCHEDULER_OUTPUT
        )

    return nullcontext()

# This context manager must be used within an active forward context.
# It encapsulates the entire KV connector lifecycle within execute_model
@staticmethod
Expand Down