Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions vllm/v1/worker/gpu/kv_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@ def pre_forward(self, scheduler_output: "SchedulerOutput") -> None:
self.kv_connector.start_load_kv(get_forward_context())

def post_forward(
self, scheduler_output: "SchedulerOutput", wait_for_save: bool = True
self,
scheduler_output: "SchedulerOutput",
wait_for_save: bool = True,
clear_metadata: bool = True,
) -> KVConnectorOutput | None:
if self._disabled:
return None
Expand All @@ -91,9 +94,15 @@ def post_forward(
output.invalid_block_ids = self.kv_connector.get_block_ids_with_load_errors()
output.kv_connector_stats = self.kv_connector.get_kv_connector_stats()
output.kv_cache_events = self.kv_connector.get_kv_connector_kv_cache_events()
self.kv_connector.clear_connector_metadata()
if clear_metadata:
self.kv_connector.clear_connector_metadata()
return output

def clear_metadata(self) -> None:
"""Clear the connector metadata. Call this after draft model runs."""
if not self._disabled:
self.kv_connector.clear_connector_metadata()

def no_forward(self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
if self._disabled:
return EMPTY_MODEL_RUNNER_OUTPUT
Expand Down
13 changes: 12 additions & 1 deletion vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3524,6 +3524,9 @@ def execute_model(

# Run the model.
# Use persistent buffers for CUDA graphs.
# When spec decode is enabled, delay clearing connector metadata
# until after draft model runs in sample_tokens.
clear_kv_metadata = self.speculative_config is None
with (
set_forward_context(
attn_metadata,
Expand All @@ -3537,7 +3540,9 @@ def execute_model(
skip_compiled=has_encoder_input,
),
record_function_or_nullcontext("gpu_model_runner: forward"),
self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
self.maybe_get_kv_connector_output(
scheduler_output, clear_metadata=clear_kv_metadata
) as kv_connector_output,
):
model_output = self._model_forward(
input_ids=input_ids,
Expand Down Expand Up @@ -3765,6 +3770,12 @@ def propose_draft_token_ids(sampled_token_ids):
# tokens on the CPU, so they are run after bookkeeping.
propose_draft_token_ids(valid_sampled_token_ids)

# Clear KV connector metadata after draft model runs (if spec decode).
# This was deferred from target model forward to allow draft model
# to also save its KV cache.
if self.speculative_config is not None:
self.clear_kv_connector_metadata()

with record_function_or_nullcontext("gpu_model_runner: eplb"):
self.eplb_step()

Expand Down
17 changes: 15 additions & 2 deletions vllm/v1/worker/kv_connector_model_runner_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,12 @@ def kv_connector_no_forward(
@staticmethod
def maybe_get_kv_connector_output(
scheduler_output: "SchedulerOutput",
clear_metadata: bool = True,
) -> AbstractContextManager[KVConnectorOutput | None]:
return (
KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
KVConnectorModelRunnerMixin._get_kv_connector_output(
scheduler_output, clear_metadata=clear_metadata
)
if has_kv_transfer_group()
else nullcontext()
)
Expand All @@ -79,7 +82,9 @@ def maybe_get_kv_connector_output(
@staticmethod
@contextmanager
def _get_kv_connector_output(
scheduler_output: "SchedulerOutput", wait_for_save: bool = True
scheduler_output: "SchedulerOutput",
wait_for_save: bool = True,
clear_metadata: bool = True,
) -> Generator[KVConnectorOutput, None, None]:
output = KVConnectorOutput()

Expand Down Expand Up @@ -108,6 +113,14 @@ def _get_kv_connector_output(
output.kv_connector_stats = kv_connector.get_kv_connector_stats()
output.kv_cache_events = kv_connector.get_kv_connector_kv_cache_events()

if clear_metadata:
kv_connector.clear_connector_metadata()

@staticmethod
def clear_kv_connector_metadata() -> None:
"""Clear the KV connector metadata. Call after draft model runs."""
if has_kv_transfer_group():
kv_connector = get_kv_transfer_group()
kv_connector.clear_connector_metadata()

@staticmethod
Expand Down