-
-
Notifications
You must be signed in to change notification settings - Fork 15k
[Core] Pipeline Parallel support for Model Runner V2 #33960
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
WoosukKwon
merged 2 commits into
vllm-project:main
from
ZhanqiuHu:feature/model-runner-v2-pp
Feb 17, 2026
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -3,19 +3,22 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import gc | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import time | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from copy import deepcopy | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Any | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch.nn as nn | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.config import VllmConfig | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.config.compilation import CUDAGraphMode | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.distributed.parallel_state import prepare_communication_buffer_for_model | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.distributed.parallel_state import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| get_pp_group, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prepare_communication_buffer_for_model, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.forward_context import set_forward_context | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.logger import init_logger | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.model_executor.model_loader import get_model_loader | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.multimodal import MULTIMODAL_REGISTRY | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.sequence import IntermediateTensors | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -54,6 +57,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.v1.worker.gpu.lora_utils import LoraState | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.v1.worker.gpu.pp_handler import PPHandler, get_pp_handler | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.v1.worker.gpu.sample.output import SamplerOutput | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.v1.worker.gpu.sample.sampler import Sampler | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -175,6 +179,12 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # KV Connector if configured. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Pipeline parallelism. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.use_pp = self.parallel_config.pipeline_parallel_size > 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.pp_handler: PPHandler | None = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| get_pp_handler(self.parallel_config) if self.use_pp else None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def update_max_model_len(self, max_model_len: int) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.max_model_len = max_model_len | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.req_states.max_model_len = max_model_len | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -287,7 +297,7 @@ def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @torch.inference_mode() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _dummy_run( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, num_tokens: int, *args, skip_attn: bool = True, **kwargs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> tuple[torch.Tensor | None, torch.Tensor | None]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Create a dummy scheduler output. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_reqs = min(num_tokens, self.max_num_reqs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_tokens_per_request = [num_tokens // num_reqs] * num_reqs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -303,13 +313,31 @@ def _dummy_run( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Disable any use of KVConnector for dummy runs. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.kv_connector.set_disabled(True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # For non-first PP ranks, create dummy intermediate_tensors. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| intermediate_tensors = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.use_pp and not get_pp_group().is_first_rank: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| intermediate_tensors = self.model.make_empty_intermediate_tensors( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| batch_size=num_tokens, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dtype=self.model_config.dtype, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| device=self.device, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Execute the model. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.execute_model( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dummy_scheduler_output, dummy_run=True, skip_attn_for_dummy_run=skip_attn | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dummy_scheduler_output, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| intermediate_tensors=intermediate_tensors, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dummy_run=True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| skip_attn_for_dummy_run=skip_attn, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.kv_connector.set_disabled(False) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Non-last PP ranks don't produce output for sampling. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.use_pp and not get_pp_group().is_last_rank: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return None, None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
WoosukKwon marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert self.execute_model_state is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states, input_batch, _ = self.execute_model_state | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert hidden_states is not None # Last PP rank always has hidden_states | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sample_hidden_states = hidden_states[input_batch.logits_indices] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return hidden_states, sample_hidden_states | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -342,7 +370,10 @@ def profile_run(self) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states, sample_hidden_states = self._dummy_run( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.max_num_tokens, skip_attn=True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._dummy_sampler_run(sample_hidden_states) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Only run sampler on last PP rank (non-last ranks return None). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not self.use_pp or get_pp_group().is_last_rank: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert sample_hidden_states is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._dummy_sampler_run(sample_hidden_states) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.do_spec_decode: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_tokens_across_dp = make_num_tokens_across_dp( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.parallel_config.data_parallel_size, self.max_num_tokens | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -378,6 +409,14 @@ def capture_model(self) -> int: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # TODO (zhanqiu): support CUDA graph for PP. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.use_pp: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.warning_once( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "Skipping CUDA graph capture because pipeline parallel is " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "enabled. Pipeline parallel is currently eager-only.", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| start_time = time.perf_counter() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gc.collect() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.cuda.empty_cache() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -796,11 +835,10 @@ def propose_draft( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def execute_model( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scheduler_output: SchedulerOutput, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| intermediate_tensors: Any | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| intermediate_tensors: IntermediateTensors | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dummy_run: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| skip_attn_for_dummy_run: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> ModelRunnerOutput | None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert intermediate_tensors is None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> ModelRunnerOutput | IntermediateTensors | None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not dummy_run: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Update the request states. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.finish_requests(scheduler_output) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -846,8 +884,10 @@ def execute_model( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._set_active_loras(*lora_inputs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.supports_mm_inputs: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Execute the multimodal encoder. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Only first PP rank prepares multimodal embeddings. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.supports_mm_inputs and ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| not self.use_pp or get_pp_group().is_first_rank | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mm_embeds, is_mm_embed = self.get_mm_embeddings( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scheduler_output.scheduled_encoder_inputs, input_batch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -889,6 +929,7 @@ def execute_model( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.uses_mrope: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert input_batch.mrope_positions is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| positions = input_batch.mrope_positions | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with set_forward_context( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_batch.attn_metadata, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.vllm_config, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -899,27 +940,71 @@ def execute_model( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| slot_mapping=input_batch.slot_mappings, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.kv_connector.pre_forward(scheduler_output) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states = self.model( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_ids=input_batch.input_ids, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| positions=positions, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inputs_embeds=input_batch.inputs_embeds, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.use_pp and not get_pp_group().is_first_rank: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Non-first PP rank: forward with intermediate tensors. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert intermediate_tensors is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states = self.model( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_ids=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| positions=positions, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inputs_embeds=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| intermediate_tensors=intermediate_tensors, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states = self.model( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_ids=input_batch.input_ids, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| positions=positions, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inputs_embeds=input_batch.inputs_embeds, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+943
to
+957
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggested:
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kv_connector_output = self.kv_connector.post_forward(scheduler_output) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.execute_model_state = hidden_states, input_batch, kv_connector_output | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.use_pp and not get_pp_group().is_last_rank: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Non-last PP rank: return IntermediateTensors for sending. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert isinstance(hidden_states, IntermediateTensors) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states.kv_connector_output = kv_connector_output | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.execute_model_state = (None, input_batch, kv_connector_output) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return hidden_states | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert isinstance(hidden_states, torch.Tensor) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Last rank (or no PP): hidden_states is a tensor for sampling. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.execute_model_state = (hidden_states, input_batch, kv_connector_output) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @torch.inference_mode() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def sample_tokens( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, grammar_output: GrammarOutput | None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> AsyncOutput | ModelRunnerOutput: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> AsyncOutput | ModelRunnerOutput | None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert self.execute_model_state is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states, input_batch, kv_connector_output = self.execute_model_state | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.execute_model_state = None # type: ignore | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Non-last PP rank: hidden_states is None because this rank produced | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # IntermediateTensors instead of final hidden states. Receive the | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # sampled tokens broadcast by the last rank and update local state. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.use_pp and not get_pp_group().is_last_rank: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert self.pp_handler is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| received = self.pp_handler.maybe_receive_sampled_tokens( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_batch.num_reqs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.device, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_sample_len=self.num_speculative_steps + 1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if received is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
WoosukKwon marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sampled, num_sampled, num_rejected = received | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.postprocess(input_batch, sampled, num_sampled, num_rejected) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Last rank: sample tokens | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sampler_output, num_sampled, num_rejected = self.sample( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states, input_batch, grammar_output | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Broadcast to non-last PP ranks (handles spec decode multi-token). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.use_pp: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert self.pp_handler is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.pp_handler.maybe_broadcast_sampled_tokens( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sampler_output, num_sampled, num_rejected | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.model.compute_logits, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,119 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| """Pipeline Parallelism handler for V2 Model Runner.""" | ||
|
|
||
| import torch | ||
|
|
||
| from vllm.distributed.parallel_state import get_pp_group | ||
| from vllm.v1.worker.gpu.sample.output import SamplerOutput | ||
|
|
||
|
|
||
| class PPHandler: | ||
| """Pipeline parallelism handler for Model Runner V2. | ||
|
|
||
| Manages sampled token synchronization between PP ranks. | ||
| Only instantiated when PP is enabled (pp_size > 1). | ||
| """ | ||
|
|
||
| def maybe_broadcast_sampled_tokens( | ||
| self, | ||
| sampler_output: SamplerOutput, | ||
| num_sampled: torch.Tensor, | ||
| num_rejected: torch.Tensor, | ||
| ) -> None: | ||
| """Broadcast sampled tokens from the last PP rank to all other ranks. | ||
|
|
||
| No-ops if this is not the last rank. | ||
|
|
||
| Broadcasts sampled_token_ids [num_reqs, max_sample_len], num_sampled | ||
| [num_reqs], and num_rejected [num_reqs] to support both regular decode | ||
| and speculative decoding. | ||
|
|
||
| Args: | ||
| sampler_output: SamplerOutput from sampling. | ||
| num_sampled: Number of accepted tokens per request. | ||
| num_rejected: Number of rejected tokens per request. | ||
| """ | ||
| pp = get_pp_group() | ||
| if not pp.is_last_rank: | ||
| return | ||
|
|
||
| torch.distributed.broadcast( | ||
| sampler_output.sampled_token_ids.contiguous(), | ||
| src=pp.last_rank, | ||
| group=pp.device_group, | ||
| ) | ||
| # NOTE: num_sampled/num_rejected are only needed | ||
| # for speculative decoding. | ||
| torch.distributed.broadcast( | ||
| num_sampled.contiguous(), | ||
| src=pp.last_rank, | ||
| group=pp.device_group, | ||
| ) | ||
| torch.distributed.broadcast( | ||
| num_rejected.contiguous(), | ||
| src=pp.last_rank, | ||
| group=pp.device_group, | ||
| ) | ||
|
|
||
| def maybe_receive_sampled_tokens( | ||
| self, | ||
| num_reqs: int, | ||
| device: torch.device, | ||
| max_sample_len: int = 1, | ||
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None: | ||
| """Receive sampled tokens broadcast by the last PP rank. | ||
|
|
||
| Returns None if this is the last rank (which samples, not receives). | ||
|
|
||
| Args: | ||
| num_reqs: Number of requests in the batch. | ||
| device: Device to create tensors on. | ||
| max_sample_len: Maximum number of tokens sampled per request | ||
| (1 for regular decode, >1 for speculative decoding). | ||
|
|
||
| Returns: | ||
| None if called on last rank. | ||
| Otherwise, tuple of (sampled_tokens, num_sampled, num_rejected): | ||
| - sampled_tokens: shape [num_reqs, max_sample_len] | ||
| - num_sampled: shape [num_reqs] | ||
| - num_rejected: shape [num_reqs] | ||
| """ | ||
| pp = get_pp_group() | ||
| if pp.is_last_rank: | ||
| return None | ||
|
|
||
| sampled_tokens = torch.empty( | ||
| num_reqs, max_sample_len, dtype=torch.int64, device=device | ||
| ) | ||
| torch.distributed.broadcast( | ||
| sampled_tokens, | ||
| src=pp.last_rank, | ||
| group=pp.device_group, | ||
| ) | ||
| # NOTE: num_sampled/num_rejected are only needed | ||
| # for speculative decoding. | ||
| num_sampled = torch.empty(num_reqs, dtype=torch.int32, device=device) | ||
| torch.distributed.broadcast( | ||
| num_sampled, | ||
| src=pp.last_rank, | ||
| group=pp.device_group, | ||
| ) | ||
| num_rejected = torch.empty(num_reqs, dtype=torch.int32, device=device) | ||
| torch.distributed.broadcast( | ||
| num_rejected, | ||
| src=pp.last_rank, | ||
| group=pp.device_group, | ||
| ) | ||
| return sampled_tokens, num_sampled, num_rejected | ||
|
|
||
|
|
||
| def get_pp_handler(parallel_config) -> PPHandler: | ||
| """Factory function to create PPHandler. | ||
|
|
||
| Must only be called when PP is enabled (pp_size > 1). | ||
| """ | ||
| assert parallel_config.pipeline_parallel_size > 1, ( | ||
| "PPHandler should not be created when pipeline parallelism is disabled." | ||
| ) | ||
| return PPHandler() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.