diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py index 2b94362a808f..f18d8f320aaa 100644 --- a/vllm/v1/worker/gpu/cudagraph_utils.py +++ b/vllm/v1/worker/gpu/cudagraph_utils.py @@ -263,6 +263,7 @@ def __init__( decode_query_len: int, ): super().__init__(vllm_config, device, cudagraph_mode, decode_query_len) + # Used for FULL CUDA graphs. PW CUDA graphs do not use these. self.hidden_states: torch.Tensor | None = None self.aux_hidden_states: list[torch.Tensor] = [] self.use_aux_hidden_state_outputs = False @@ -326,6 +327,12 @@ def forward_fn(cg_mode: CUDAGraphMode) -> None: **model_state.prepare_dummy_inputs(num_reqs, num_tokens), } model_output = model(**model_inputs) + + if cg_mode == CUDAGraphMode.PIECEWISE: + # PW CUDA graph internally handles the model outputs. + # No need to keep track of the hidden states. + return None + if self.use_aux_hidden_state_outputs: hidden_states, aux_hidden_states = model_output else: