diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index d6b87bd710f2..5b28fa431c42 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -3,7 +3,6 @@
 import gc
 import time
 from copy import deepcopy
-from typing import Any
 
 import numpy as np
 import torch
@@ -11,11 +10,15 @@
 from vllm.config import VllmConfig
 from vllm.config.compilation import CUDAGraphMode
-from vllm.distributed.parallel_state import prepare_communication_buffer_for_model
+from vllm.distributed.parallel_state import (
+    get_pp_group,
+    prepare_communication_buffer_for_model,
+)
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.sequence import IntermediateTensors
 from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
@@ -54,6 +57,7 @@
 from vllm.v1.worker.gpu.lora_utils import LoraState
 from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
 from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
+from vllm.v1.worker.gpu.pp_handler import PPHandler, get_pp_handler
 from vllm.v1.worker.gpu.sample.output import SamplerOutput
 from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
 from vllm.v1.worker.gpu.sample.sampler import Sampler
@@ -175,6 +179,12 @@ def __init__(
         # KV Connector if configured.
         self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
 
+        # Pipeline parallelism.
+        self.use_pp = self.parallel_config.pipeline_parallel_size > 1
+        self.pp_handler: PPHandler | None = (
+            get_pp_handler(self.parallel_config) if self.use_pp else None
+        )
+
     def update_max_model_len(self, max_model_len: int) -> None:
         self.max_model_len = max_model_len
         self.req_states.max_model_len = max_model_len
@@ -287,7 +297,7 @@ def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None:
     @torch.inference_mode()
     def _dummy_run(
         self, num_tokens: int, *args, skip_attn: bool = True, **kwargs
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
         # Create a dummy scheduler output.
         num_reqs = min(num_tokens, self.max_num_reqs)
         num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
@@ -303,13 +313,31 @@
         # Disable any use of KVConnector for dummy runs.
         self.kv_connector.set_disabled(True)
 
+        # For non-first PP ranks, create dummy intermediate_tensors.
+        intermediate_tensors = None
+        if self.use_pp and not get_pp_group().is_first_rank:
+            intermediate_tensors = self.model.make_empty_intermediate_tensors(
+                batch_size=num_tokens,
+                dtype=self.model_config.dtype,
+                device=self.device,
+            )
+
         # Execute the model.
         self.execute_model(
-            dummy_scheduler_output, dummy_run=True, skip_attn_for_dummy_run=skip_attn
+            dummy_scheduler_output,
+            intermediate_tensors=intermediate_tensors,
+            dummy_run=True,
+            skip_attn_for_dummy_run=skip_attn,
         )
         self.kv_connector.set_disabled(False)
+
+        # Non-last PP ranks don't produce output for sampling.
+        if self.use_pp and not get_pp_group().is_last_rank:
+            return None, None
+
         assert self.execute_model_state is not None
         hidden_states, input_batch, _ = self.execute_model_state
+        assert hidden_states is not None  # Last PP rank always has hidden_states
         sample_hidden_states = hidden_states[input_batch.logits_indices]
         return hidden_states, sample_hidden_states
 
@@ -342,7 +370,10 @@ def profile_run(self) -> None:
         hidden_states, sample_hidden_states = self._dummy_run(
             self.max_num_tokens, skip_attn=True
         )
-        self._dummy_sampler_run(sample_hidden_states)
+        # Only run sampler on last PP rank (non-last ranks return None).
+        if not self.use_pp or get_pp_group().is_last_rank:
+            assert sample_hidden_states is not None
+            self._dummy_sampler_run(sample_hidden_states)
         if self.do_spec_decode:
             num_tokens_across_dp = make_num_tokens_across_dp(
                 self.parallel_config.data_parallel_size, self.max_num_tokens
@@ -378,6 +409,14 @@ def capture_model(self) -> int:
             )
             return 0
 
+        # TODO (zhanqiu): support CUDA graph for PP.
+        if self.use_pp:
+            logger.warning_once(
+                "Skipping CUDA graph capture because pipeline parallel is "
+                "enabled. Pipeline parallel is currently eager-only.",
+            )
+            return 0
+
         start_time = time.perf_counter()
         gc.collect()
         torch.cuda.empty_cache()
@@ -796,11 +835,10 @@ def propose_draft(
     def execute_model(
         self,
         scheduler_output: SchedulerOutput,
-        intermediate_tensors: Any | None = None,
+        intermediate_tensors: IntermediateTensors | None = None,
         dummy_run: bool = False,
         skip_attn_for_dummy_run: bool = False,
-    ) -> ModelRunnerOutput | None:
-        assert intermediate_tensors is None
+    ) -> ModelRunnerOutput | IntermediateTensors | None:
         if not dummy_run:
             # Update the request states.
             self.finish_requests(scheduler_output)
@@ -846,8 +884,10 @@ def execute_model(
             )
             self._set_active_loras(*lora_inputs)
 
-        if self.supports_mm_inputs:
-            # Execute the multimodal encoder.
+        # Only first PP rank prepares multimodal embeddings.
+        if self.supports_mm_inputs and (
+            not self.use_pp or get_pp_group().is_first_rank
+        ):
             mm_embeds, is_mm_embed = self.get_mm_embeddings(
                 scheduler_output.scheduled_encoder_inputs, input_batch
             )
@@ -889,6 +929,7 @@ def execute_model(
         if self.uses_mrope:
             assert input_batch.mrope_positions is not None
             positions = input_batch.mrope_positions
+
         with set_forward_context(
             input_batch.attn_metadata,
             self.vllm_config,
@@ -899,27 +940,71 @@ def execute_model(
             slot_mapping=input_batch.slot_mappings,
         ):
             self.kv_connector.pre_forward(scheduler_output)
-            hidden_states = self.model(
-                input_ids=input_batch.input_ids,
-                positions=positions,
-                inputs_embeds=input_batch.inputs_embeds,
-            )
+            if self.use_pp and not get_pp_group().is_first_rank:
+                # Non-first PP rank: forward with intermediate tensors.
+                assert intermediate_tensors is not None
+                hidden_states = self.model(
+                    input_ids=None,
+                    positions=positions,
+                    inputs_embeds=None,
+                    intermediate_tensors=intermediate_tensors,
+                )
+            else:
+                hidden_states = self.model(
+                    input_ids=input_batch.input_ids,
+                    positions=positions,
+                    inputs_embeds=input_batch.inputs_embeds,
+                )
             kv_connector_output = self.kv_connector.post_forward(scheduler_output)
-        self.execute_model_state = hidden_states, input_batch, kv_connector_output
+
+        if self.use_pp and not get_pp_group().is_last_rank:
+            # Non-last PP rank: return IntermediateTensors for sending.
+            assert isinstance(hidden_states, IntermediateTensors)
+            hidden_states.kv_connector_output = kv_connector_output
+            self.execute_model_state = (None, input_batch, kv_connector_output)
+            return hidden_states
+
+        assert isinstance(hidden_states, torch.Tensor)
+        # Last rank (or no PP): hidden_states is a tensor for sampling.
+        self.execute_model_state = (hidden_states, input_batch, kv_connector_output)
         return None
 
     @torch.inference_mode()
     def sample_tokens(
         self, grammar_output: GrammarOutput | None
-    ) -> AsyncOutput | ModelRunnerOutput:
+    ) -> AsyncOutput | ModelRunnerOutput | None:
        assert self.execute_model_state is not None
         hidden_states, input_batch, kv_connector_output = self.execute_model_state
         self.execute_model_state = None  # type: ignore
 
+        # Non-last PP rank: hidden_states is None because this rank produced
+        # IntermediateTensors instead of final hidden states. Receive the
+        # sampled tokens broadcast by the last rank and update local state.
+        if self.use_pp and not get_pp_group().is_last_rank:
+            assert self.pp_handler is not None
+            received = self.pp_handler.maybe_receive_sampled_tokens(
+                input_batch.num_reqs,
+                self.device,
+                max_sample_len=self.num_speculative_steps + 1,
+            )
+            if received is not None:
+                sampled, num_sampled, num_rejected = received
+                self.postprocess(input_batch, sampled, num_sampled, num_rejected)
+            return None
+
+        # Last rank: sample tokens.
         sampler_output, num_sampled, num_rejected = self.sample(
             hidden_states, input_batch, grammar_output
         )
+
+        # Broadcast to non-last PP ranks (handles spec decode multi-token).
+        if self.use_pp:
+            assert self.pp_handler is not None
+            self.pp_handler.maybe_broadcast_sampled_tokens(
+                sampler_output, num_sampled, num_rejected
+            )
+
         prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs(
             self.model.compute_logits,
             hidden_states,
diff --git a/vllm/v1/worker/gpu/pp_handler.py b/vllm/v1/worker/gpu/pp_handler.py
new file mode 100644
index 000000000000..a254f577f6ea
--- /dev/null
+++ b/vllm/v1/worker/gpu/pp_handler.py
@@ -0,0 +1,119 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Pipeline Parallelism handler for V2 Model Runner."""
+
+import torch
+
+from vllm.distributed.parallel_state import get_pp_group
+from vllm.v1.worker.gpu.sample.output import SamplerOutput
+
+
+class PPHandler:
+    """Pipeline parallelism handler for Model Runner V2.
+
+    Manages sampled token synchronization between PP ranks.
+    Only instantiated when PP is enabled (pp_size > 1).
+    """
+
+    def maybe_broadcast_sampled_tokens(
+        self,
+        sampler_output: SamplerOutput,
+        num_sampled: torch.Tensor,
+        num_rejected: torch.Tensor,
+    ) -> None:
+        """Broadcast sampled tokens from the last PP rank to all other ranks.
+
+        No-ops if this is not the last rank.
+
+        Broadcasts sampled_token_ids [num_reqs, max_sample_len], num_sampled
+        [num_reqs], and num_rejected [num_reqs] to support both regular decode
+        and speculative decoding.
+
+        Args:
+            sampler_output: SamplerOutput from sampling.
+            num_sampled: Number of accepted tokens per request.
+            num_rejected: Number of rejected tokens per request.
+        """
+        pp = get_pp_group()
+        if not pp.is_last_rank:
+            return
+
+        torch.distributed.broadcast(
+            sampler_output.sampled_token_ids.contiguous(),
+            src=pp.last_rank,
+            group=pp.device_group,
+        )
+        # NOTE: num_sampled/num_rejected are only needed
+        # for speculative decoding.
+        torch.distributed.broadcast(
+            num_sampled.contiguous(),
+            src=pp.last_rank,
+            group=pp.device_group,
+        )
+        torch.distributed.broadcast(
+            num_rejected.contiguous(),
+            src=pp.last_rank,
+            group=pp.device_group,
+        )
+
+    def maybe_receive_sampled_tokens(
+        self,
+        num_reqs: int,
+        device: torch.device,
+        max_sample_len: int = 1,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
+        """Receive sampled tokens broadcast by the last PP rank.
+
+        Returns None if this is the last rank (which samples, not receives).
+
+        Args:
+            num_reqs: Number of requests in the batch.
+            device: Device to create tensors on.
+            max_sample_len: Maximum number of tokens sampled per request
+                (1 for regular decode, >1 for speculative decoding).
+
+        Returns:
+            None if called on last rank.
+            Otherwise, tuple of (sampled_tokens, num_sampled, num_rejected):
+            - sampled_tokens: shape [num_reqs, max_sample_len]
+            - num_sampled: shape [num_reqs]
+            - num_rejected: shape [num_reqs]
+        """
+        pp = get_pp_group()
+        if pp.is_last_rank:
+            return None
+
+        sampled_tokens = torch.empty(
+            num_reqs, max_sample_len, dtype=torch.int64, device=device
+        )
+        torch.distributed.broadcast(
+            sampled_tokens,
+            src=pp.last_rank,
+            group=pp.device_group,
+        )
+        # NOTE: num_sampled/num_rejected are only needed
+        # for speculative decoding.
+        num_sampled = torch.empty(num_reqs, dtype=torch.int32, device=device)
+        torch.distributed.broadcast(
+            num_sampled,
+            src=pp.last_rank,
+            group=pp.device_group,
+        )
+        num_rejected = torch.empty(num_reqs, dtype=torch.int32, device=device)
+        torch.distributed.broadcast(
+            num_rejected,
+            src=pp.last_rank,
+            group=pp.device_group,
+        )
+        return sampled_tokens, num_sampled, num_rejected
+
+
+def get_pp_handler(parallel_config) -> PPHandler:
+    """Factory function to create PPHandler.
+
+    Must only be called when PP is enabled (pp_size > 1).
+    """
+    assert parallel_config.pipeline_parallel_size > 1, (
+        "PPHandler should not be created when pipeline parallelism is disabled."
+    )
+    return PPHandler()
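
Reviewer note: the sketch below is a standalone illustration of the wire protocol PPHandler relies on, not vLLM code. The last rank issues three broadcasts in a fixed order (token ids, then num_sampled, then num_rejected), and every non-last rank must post matching receives with the same shapes and dtypes. It runs on CPU with the gloo backend and two processes; all names here (run_rank, WORLD_SIZE, the port) are made up for the example and do not exist in the repository.

# Standalone illustration of the broadcast/receive ordering (not vLLM code).
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

WORLD_SIZE = 2      # hypothetical 2-stage pipeline
NUM_REQS = 4
MAX_SAMPLE_LEN = 3  # e.g. spec decode with 2 speculative steps


def run_rank(rank: int) -> None:
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29511",
        rank=rank,
        world_size=WORLD_SIZE,
    )
    last_rank = WORLD_SIZE - 1
    if rank == last_rank:
        # "Last PP rank": owns the real sampled values.
        sampled = torch.randint(0, 32000, (NUM_REQS, MAX_SAMPLE_LEN))
        num_sampled = torch.full((NUM_REQS,), MAX_SAMPLE_LEN, dtype=torch.int32)
        num_rejected = torch.zeros(NUM_REQS, dtype=torch.int32)
    else:
        # "Non-last PP rank": allocates empty buffers with matching shapes/dtypes.
        sampled = torch.empty(NUM_REQS, MAX_SAMPLE_LEN, dtype=torch.int64)
        num_sampled = torch.empty(NUM_REQS, dtype=torch.int32)
        num_rejected = torch.empty(NUM_REQS, dtype=torch.int32)

    # The three broadcasts must be issued in the same order on every rank.
    dist.broadcast(sampled, src=last_rank)
    dist.broadcast(num_sampled, src=last_rank)
    dist.broadcast(num_rejected, src=last_rank)
    print(f"rank {rank}: sampled[0] = {sampled[0].tolist()}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run_rank, nprocs=WORLD_SIZE)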