diff --git a/vllm/v1/worker/gpu/kv_connector.py b/vllm/v1/worker/gpu/kv_connector.py index 91f4d34296bb..7e4e27e1f234 100644 --- a/vllm/v1/worker/gpu/kv_connector.py +++ b/vllm/v1/worker/gpu/kv_connector.py @@ -77,7 +77,10 @@ def pre_forward(self, scheduler_output: "SchedulerOutput") -> None: self.kv_connector.start_load_kv(get_forward_context()) def post_forward( - self, scheduler_output: "SchedulerOutput", wait_for_save: bool = True + self, + scheduler_output: "SchedulerOutput", + wait_for_save: bool = True, + clear_metadata: bool = True, ) -> KVConnectorOutput | None: if self._disabled: return None @@ -91,9 +94,15 @@ def post_forward( output.invalid_block_ids = self.kv_connector.get_block_ids_with_load_errors() output.kv_connector_stats = self.kv_connector.get_kv_connector_stats() output.kv_cache_events = self.kv_connector.get_kv_connector_kv_cache_events() - self.kv_connector.clear_connector_metadata() + if clear_metadata: + self.kv_connector.clear_connector_metadata() return output + def clear_metadata(self) -> None: + """Clear the connector metadata. Call this after draft model runs.""" + if not self._disabled: + self.kv_connector.clear_connector_metadata() + def no_forward(self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: if self._disabled: return EMPTY_MODEL_RUNNER_OUTPUT diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9ef8584c7f86..3a354b81864f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3524,6 +3524,9 @@ def execute_model( # Run the model. # Use persistent buffers for CUDA graphs. + # When spec decode is enabled, delay clearing connector metadata + # until after draft model runs in sample_tokens. + clear_kv_metadata = self.speculative_config is None with ( set_forward_context( attn_metadata, @@ -3537,7 +3540,9 @@ def execute_model( skip_compiled=has_encoder_input, ), record_function_or_nullcontext("gpu_model_runner: forward"), - self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, + self.maybe_get_kv_connector_output( + scheduler_output, clear_metadata=clear_kv_metadata + ) as kv_connector_output, ): model_output = self._model_forward( input_ids=input_ids, @@ -3765,6 +3770,12 @@ def propose_draft_token_ids(sampled_token_ids): # tokens on the CPU, so they are run after bookkeeping. propose_draft_token_ids(valid_sampled_token_ids) + # Clear KV connector metadata after draft model runs (if spec decode). + # This was deferred from target model forward to allow draft model + # to also save its KV cache. + if self.speculative_config is not None: + self.clear_kv_connector_metadata() + with record_function_or_nullcontext("gpu_model_runner: eplb"): self.eplb_step() diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index 0556c3e6e41c..2e2f64b2584c 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -67,9 +67,12 @@ def kv_connector_no_forward( @staticmethod def maybe_get_kv_connector_output( scheduler_output: "SchedulerOutput", + clear_metadata: bool = True, ) -> AbstractContextManager[KVConnectorOutput | None]: return ( - KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output) + KVConnectorModelRunnerMixin._get_kv_connector_output( + scheduler_output, clear_metadata=clear_metadata + ) if has_kv_transfer_group() else nullcontext() ) @@ -79,7 +82,9 @@ def maybe_get_kv_connector_output( @staticmethod @contextmanager def _get_kv_connector_output( - scheduler_output: "SchedulerOutput", wait_for_save: bool = True + scheduler_output: "SchedulerOutput", + wait_for_save: bool = True, + clear_metadata: bool = True, ) -> Generator[KVConnectorOutput, None, None]: output = KVConnectorOutput() @@ -108,6 +113,14 @@ def _get_kv_connector_output( output.kv_connector_stats = kv_connector.get_kv_connector_stats() output.kv_cache_events = kv_connector.get_kv_connector_kv_cache_events() + if clear_metadata: + kv_connector.clear_connector_metadata() + + @staticmethod + def clear_kv_connector_metadata() -> None: + """Clear the KV connector metadata. Call after draft model runs.""" + if has_kv_transfer_group(): + kv_connector = get_kv_transfer_group() kv_connector.clear_connector_metadata() @staticmethod