diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 522e7fdbf25b..bcab2ca2d4c2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5701,6 +5701,26 @@ def _dummy_sampler_run( sampler_output = self.sampler( logits=logits, sampling_metadata=dummy_metadata ) + # Also warm forward_native (taken when generators dict is non-empty), + # but skip the extra call in 'processed_logits' / 'processed_logprobs' + # modes — there TopKTopPSampler binds forward = forward_native at + # init time, so the warmup call is redundant and only inflates peak + # memory during profile_run. + # No .clone() of logits: warmup output is discarded, so any in-place + # mutation by forward_native does not affect correctness. + if self.sampler.logprobs_mode not in ( + "processed_logits", + "processed_logprobs", + ): + self.sampler( + logits=logits, + sampling_metadata=replace( + dummy_metadata, + generators={ + 0: torch.Generator(device=self.device).manual_seed(0) + }, + ), + ) except RuntimeError as e: if "out of memory" in str(e): raise RuntimeError(