diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py index 2b94362a808f..d31e2853e32b 100644 --- a/vllm/v1/worker/gpu/cudagraph_utils.py +++ b/vllm/v1/worker/gpu/cudagraph_utils.py @@ -11,11 +11,16 @@ from vllm.config import VllmConfig from vllm.config.compilation import CUDAGraphMode -from vllm.distributed.parallel_state import graph_capture, is_global_first_rank +from vllm.distributed.parallel_state import ( + get_pp_group, + graph_capture, + is_global_first_rank, +) from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.logger import init_logger from vllm.model_executor.offloader.base import get_offloader from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer from vllm.v1.worker.gpu.block_table import BlockTables @@ -87,7 +92,15 @@ def __init__( assert self.compilation_config is not None self.cudagraph_mode = cudagraph_mode self.decode_query_len = decode_query_len + self.dp_size = vllm_config.parallel_config.data_parallel_size + self.pp_size = vllm_config.parallel_config.pipeline_parallel_size + if self.pp_size > 1: + self.is_first_pp_rank = get_pp_group().is_first_rank + self.is_last_pp_rank = get_pp_group().is_last_rank + else: + self.is_first_pp_rank = True + self.is_last_pp_rank = True self.graphs: dict[BatchExecutionDescriptor, torch.cuda.CUDAGraph] = {} self.pool = current_platform.get_global_graph_pool() if cudagraph_mode else None @@ -263,15 +276,18 @@ def __init__( decode_query_len: int, ): super().__init__(vllm_config, device, cudagraph_mode, decode_query_len) + # Used for FULL CUDA graphs. PW CUDA graphs do not use these. self.hidden_states: torch.Tensor | None = None self.aux_hidden_states: list[torch.Tensor] = [] self.use_aux_hidden_state_outputs = False + self.intermediate_tensors: IntermediateTensors | None = None def capture( self, model: nn.Module, model_state: ModelState, input_buffers: InputBuffers, + intermediate_tensors: IntermediateTensors | None, block_tables: BlockTables, attn_groups: list[list[AttentionGroup]], kv_cache_config: KVCacheConfig, @@ -292,6 +308,12 @@ def create_forward_fn( if self.dp_size > 1 else None ) + if not self.is_first_pp_rank: + assert intermediate_tensors is not None + intermediate_tensors_sliced = intermediate_tensors[:num_tokens] + else: + intermediate_tensors_sliced = None + attn_metadata, slot_mappings = prepare_inputs_to_capture( num_reqs, num_tokens, @@ -320,26 +342,48 @@ def forward_fn(cg_mode: CUDAGraphMode) -> None: model_inputs = { "input_ids": input_buffers.input_ids[:num_tokens], "positions": input_buffers.positions[:num_tokens], - # TODO: Pass intermediate_tensors for PP CUDA graph - # support (https://github.com/vllm-project/vllm/pull/35162). - "intermediate_tensors": None, + "intermediate_tensors": intermediate_tensors_sliced, **model_state.prepare_dummy_inputs(num_reqs, num_tokens), } model_output = model(**model_inputs) - if self.use_aux_hidden_state_outputs: - hidden_states, aux_hidden_states = model_output + + if cg_mode == CUDAGraphMode.PIECEWISE: + # PW CUDA graph internally handles the model outputs. + # No need to keep track of the hidden states. + return None + + if self.is_last_pp_rank: + # Last PP rank (common case). + if self.use_aux_hidden_state_outputs: + hidden_states, aux_hidden_states = model_output + else: + hidden_states = model_output + aux_hidden_states = [] + if self.hidden_states is None: + self.hidden_states = torch.empty_like(hidden_states) + self.hidden_states[:num_tokens] = hidden_states + if ( + self.use_aux_hidden_state_outputs + and not self.aux_hidden_states + ): + self.aux_hidden_states = [ + torch.empty_like(x) for x in aux_hidden_states + ] + for i, aux in enumerate(aux_hidden_states): + self.aux_hidden_states[i][:num_tokens] = aux else: - hidden_states = model_output - aux_hidden_states = [] - if self.hidden_states is None: - self.hidden_states = torch.empty_like(hidden_states) - if self.use_aux_hidden_state_outputs and not self.aux_hidden_states: - self.aux_hidden_states = [ - torch.empty_like(x) for x in aux_hidden_states - ] - self.hidden_states[:num_tokens] = hidden_states - for i, aux in enumerate(aux_hidden_states): - self.aux_hidden_states[i][:num_tokens] = aux + # Non-last PP rank. + intermediate_tensors = model_output + assert isinstance(intermediate_tensors, IntermediateTensors) + if self.intermediate_tensors is None: + self.intermediate_tensors = IntermediateTensors( + { + k: torch.empty_like(v) + for k, v in intermediate_tensors.tensors.items() + } + ) + for k, v in intermediate_tensors.tensors.items(): + self.intermediate_tensors[k][:num_tokens] = v return forward_fn @@ -347,9 +391,13 @@ def forward_fn(cg_mode: CUDAGraphMode) -> None: def run_fullgraph( self, desc: BatchExecutionDescriptor - ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]] | IntermediateTensors: """Replay a captured FULL cudagraph and return hidden states.""" super().run_fullgraph(desc) + if not self.is_last_pp_rank: + assert self.intermediate_tensors is not None + return self.intermediate_tensors[: desc.num_tokens] + assert self.hidden_states is not None hidden_states = self.hidden_states[: desc.num_tokens] if not self.use_aux_hidden_state_outputs: diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index d10530c95884..313535a63734 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -140,6 +140,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): else: self.is_first_pp_rank = True self.is_last_pp_rank = True + # Persistent buffer for intermediate tensors (non-first PP ranks). + self.intermediate_tensors: IntermediateTensors | None = None # Data parallelism. self.dp_size = self.parallel_config.data_parallel_size @@ -301,6 +303,17 @@ def load_model(self, *args, **kwargs) -> None: if self.is_pooling_model and self.is_last_pp_rank: self.pooling_runner = PoolingRunner(self.model) + if not self.is_first_pp_rank: + # For non-first PP ranks, create intermediate tensors sized + # for the max capture size so they can be sliced per batch. + # Save as persistent member so runtime can copy received data + # into the same addresses that the CUDA graphs captured. + self.intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=self.max_num_tokens, + dtype=self.model_config.dtype, + device=self.device, + ) + def get_model(self) -> nn.Module: return self.model @@ -396,14 +409,11 @@ def _dummy_run( # Disable any use of KVConnector for dummy runs. self.kv_connector.set_disabled(True) - # For non-first PP ranks, create dummy intermediate_tensors. + # Get the intermediate tensors for the dummy run. intermediate_tensors = None if not self.is_first_pp_rank: - intermediate_tensors = self.model.make_empty_intermediate_tensors( - batch_size=num_tokens, - dtype=self.model_config.dtype, - device=self.device, - ) + assert self.intermediate_tensors is not None + intermediate_tensors = self.intermediate_tensors[:num_tokens] # Execute the model. self.execute_model( @@ -528,14 +538,6 @@ def capture_model(self) -> int: ) return 0 - # TODO (zhanqiu): support CUDA graph for PP. - if self.use_pp: - logger.warning_once( - "Skipping CUDA graph capture because pipeline parallel is " - "enabled. Pipeline parallel is currently eager-only.", - ) - return 0 - start_time = time.perf_counter() gc.collect() torch.accelerator.empty_cache() @@ -546,6 +548,7 @@ def capture_model(self) -> int: self.model, self.model_state, self.input_buffers, + self.intermediate_tensors, self.block_tables, self.attn_groups, self.kv_cache_config, @@ -1010,7 +1013,6 @@ def execute_model( "input_ids": input_batch.input_ids, "positions": input_batch.positions, "inputs_embeds": inputs_embeds, - "intermediate_tensors": intermediate_tensors, # NOTE: Values returned by `prepare_inputs` will override the default # values above. **self.model_state.prepare_inputs(input_batch, self.req_states), @@ -1019,7 +1021,19 @@ def execute_model( # Update for non-first PP ranks. model_inputs["input_ids"] = None model_inputs["inputs_embeds"] = None + + # Prepare the intermediate tensors. assert intermediate_tensors is not None + assert self.intermediate_tensors is not None + n = input_batch.num_tokens_after_padding + intermediate_tensors = IntermediateTensors( + { + k: v[:n].copy_(intermediate_tensors.tensors[k][:n]) + for k, v in self.intermediate_tensors.tensors.items() + }, + intermediate_tensors.kv_connector_output, + ) + model_inputs["intermediate_tensors"] = intermediate_tensors # Run model. if batch_desc.cg_mode == CUDAGraphMode.FULL: @@ -1028,11 +1042,6 @@ def execute_model( # because they are already copied to the CUDA graph input buffers. self.kv_connector.pre_forward(scheduler_output) model_output = self.cudagraph_manager.run_fullgraph(batch_desc) - if self.use_aux_hidden_state_outputs: - hidden_states, aux_hidden_states = model_output - else: - hidden_states = model_output - aux_hidden_states = None else: # For piecewise and eager mode, just call model(). batch_descriptor = BatchDescriptor( @@ -1052,11 +1061,21 @@ def execute_model( ): self.kv_connector.pre_forward(scheduler_output) model_output = self.model(**model_inputs) - if self.use_aux_hidden_state_outputs: - hidden_states, aux_hidden_states = model_output - else: - hidden_states = model_output - aux_hidden_states = None + + if self.is_last_pp_rank: + if self.use_aux_hidden_state_outputs: + assert isinstance(model_output, tuple) + hidden_states, aux_hidden_states = model_output + else: + assert isinstance(model_output, torch.Tensor) + hidden_states = model_output + aux_hidden_states = None + output_intermediate_tensors = None + else: + assert isinstance(model_output, IntermediateTensors) + hidden_states = None + aux_hidden_states = None + output_intermediate_tensors = model_output kv_connector_output = self.kv_connector.post_forward(scheduler_output) self.execute_model_state = ExecuteModelState( @@ -1071,11 +1090,9 @@ def execute_model( if not self.is_last_pp_rank: # Non-last PP rank: return IntermediateTensors for sending. - assert isinstance(hidden_states, IntermediateTensors) - hidden_states.kv_connector_output = kv_connector_output - return hidden_states - # Last rank (or no PP): hidden_states is a tensor for sampling. - assert isinstance(hidden_states, torch.Tensor) + assert output_intermediate_tensors is not None + output_intermediate_tensors.kv_connector_output = kv_connector_output + return output_intermediate_tensors return None @torch.inference_mode() @@ -1259,7 +1276,7 @@ class ExecuteModelState(NamedTuple): input_batch: InputBatch attn_metadata: dict[str, Any] | None slot_mappings_by_layer: dict[str, torch.Tensor] | None - hidden_states: torch.Tensor | IntermediateTensors + hidden_states: torch.Tensor | None aux_hidden_states: list[torch.Tensor] | None kv_connector_output: KVConnectorOutput | None num_tokens_across_dp: torch.Tensor | None