119 changes: 102 additions & 17 deletions vllm/v1/worker/gpu/model_runner.py
@@ -3,19 +3,22 @@
 import gc
 import time
 from copy import deepcopy
-from typing import Any
 
 import numpy as np
 import torch
 import torch.nn as nn
 
 from vllm.config import VllmConfig
 from vllm.config.compilation import CUDAGraphMode
-from vllm.distributed.parallel_state import prepare_communication_buffer_for_model
+from vllm.distributed.parallel_state import (
+    get_pp_group,
+    prepare_communication_buffer_for_model,
+)
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.sequence import IntermediateTensors
 from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
@@ -54,6 +57,7 @@
 from vllm.v1.worker.gpu.lora_utils import LoraState
 from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
 from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
+from vllm.v1.worker.gpu.pp_handler import PPHandler, get_pp_handler
 from vllm.v1.worker.gpu.sample.output import SamplerOutput
 from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
 from vllm.v1.worker.gpu.sample.sampler import Sampler
@@ -175,6 +179,12 @@ def __init__(
         # KV Connector if configured.
         self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
 
+        # Pipeline parallelism.
+        self.use_pp = self.parallel_config.pipeline_parallel_size > 1
+        self.pp_handler: PPHandler | None = (
+            get_pp_handler(self.parallel_config) if self.use_pp else None
+        )
+
     def update_max_model_len(self, max_model_len: int) -> None:
         self.max_model_len = max_model_len
         self.req_states.max_model_len = max_model_len
@@ -287,7 +297,7 @@ def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None:
     @torch.inference_mode()
     def _dummy_run(
         self, num_tokens: int, *args, skip_attn: bool = True, **kwargs
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
         # Create a dummy scheduler output.
         num_reqs = min(num_tokens, self.max_num_reqs)
         num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
@@ -303,13 +313,31 @@ def _dummy_run(
         # Disable any use of KVConnector for dummy runs.
         self.kv_connector.set_disabled(True)
 
+        # For non-first PP ranks, create dummy intermediate_tensors.
+        intermediate_tensors = None
+        if self.use_pp and not get_pp_group().is_first_rank:
+            intermediate_tensors = self.model.make_empty_intermediate_tensors(
+                batch_size=num_tokens,
+                dtype=self.model_config.dtype,
+                device=self.device,
+            )
+
         # Execute the model.
         self.execute_model(
-            dummy_scheduler_output, dummy_run=True, skip_attn_for_dummy_run=skip_attn
+            dummy_scheduler_output,
+            intermediate_tensors=intermediate_tensors,
+            dummy_run=True,
+            skip_attn_for_dummy_run=skip_attn,
         )
         self.kv_connector.set_disabled(False)
 
+        # Non-last PP ranks don't produce output for sampling.
+        if self.use_pp and not get_pp_group().is_last_rank:
+            return None, None
+
         assert self.execute_model_state is not None
         hidden_states, input_batch, _ = self.execute_model_state
+        assert hidden_states is not None  # Last PP rank always has hidden_states
         sample_hidden_states = hidden_states[input_batch.logits_indices]
         return hidden_states, sample_hidden_states

@@ -342,7 +370,10 @@ def profile_run(self) -> None:
         hidden_states, sample_hidden_states = self._dummy_run(
             self.max_num_tokens, skip_attn=True
         )
-        self._dummy_sampler_run(sample_hidden_states)
+        # Only run sampler on last PP rank (non-last ranks return None).
+        if not self.use_pp or get_pp_group().is_last_rank:
+            assert sample_hidden_states is not None
+            self._dummy_sampler_run(sample_hidden_states)
         if self.do_spec_decode:
             num_tokens_across_dp = make_num_tokens_across_dp(
                 self.parallel_config.data_parallel_size, self.max_num_tokens
@@ -378,6 +409,14 @@ def capture_model(self) -> int:
             )
             return 0
 
+        # TODO (zhanqiu): support CUDA graph for PP.
+        if self.use_pp:
+            logger.warning_once(
+                "Skipping CUDA graph capture because pipeline parallel is "
+                "enabled. Pipeline parallel is currently eager-only.",
+            )
+            return 0
+
         start_time = time.perf_counter()
         gc.collect()
         torch.cuda.empty_cache()
@@ -796,11 +835,10 @@ def propose_draft(
     def execute_model(
         self,
         scheduler_output: SchedulerOutput,
-        intermediate_tensors: Any | None = None,
+        intermediate_tensors: IntermediateTensors | None = None,
         dummy_run: bool = False,
         skip_attn_for_dummy_run: bool = False,
-    ) -> ModelRunnerOutput | None:
-        assert intermediate_tensors is None
+    ) -> ModelRunnerOutput | IntermediateTensors | None:
         if not dummy_run:
             # Update the request states.
             self.finish_requests(scheduler_output)
@@ -846,8 +884,10 @@ def execute_model(
             )
             self._set_active_loras(*lora_inputs)
 
-        if self.supports_mm_inputs:
-            # Execute the multimodal encoder.
+        # Only first PP rank prepares multimodal embeddings.
+        if self.supports_mm_inputs and (
+            not self.use_pp or get_pp_group().is_first_rank
+        ):
             mm_embeds, is_mm_embed = self.get_mm_embeddings(
                 scheduler_output.scheduled_encoder_inputs, input_batch
             )
@@ -889,6 +929,7 @@ def execute_model(
         if self.uses_mrope:
            assert input_batch.mrope_positions is not None
            positions = input_batch.mrope_positions
+
        with set_forward_context(
            input_batch.attn_metadata,
            self.vllm_config,
@@ -899,27 +940,71 @@ def execute_model(
             slot_mapping=input_batch.slot_mappings,
         ):
             self.kv_connector.pre_forward(scheduler_output)
-            hidden_states = self.model(
-                input_ids=input_batch.input_ids,
-                positions=positions,
-                inputs_embeds=input_batch.inputs_embeds,
-            )
+            if self.use_pp and not get_pp_group().is_first_rank:
+                # Non-first PP rank: forward with intermediate tensors.
+                assert intermediate_tensors is not None
+                hidden_states = self.model(
+                    input_ids=None,
+                    positions=positions,
+                    inputs_embeds=None,
+                    intermediate_tensors=intermediate_tensors,
+                )
+            else:
+                hidden_states = self.model(
+                    input_ids=input_batch.input_ids,
+                    positions=positions,
+                    inputs_embeds=input_batch.inputs_embeds,
+                )
Comment on lines +943 to +957 (Member) suggested:

Suggested change
-            if self.use_pp and not get_pp_group().is_first_rank:
-                # Non-first PP rank: forward with intermediate tensors.
-                assert intermediate_tensors is not None
-                hidden_states = self.model(
-                    input_ids=None,
-                    positions=positions,
-                    inputs_embeds=None,
-                    intermediate_tensors=intermediate_tensors,
-                )
-            else:
-                hidden_states = self.model(
-                    input_ids=input_batch.input_ids,
-                    positions=positions,
-                    inputs_embeds=input_batch.inputs_embeds,
-                )
+            if get_pp_group().is_first_rank:
+                input_ids = input_batch.input_ids
+                inputs_embeds = input_batch.inputs_embeds
+            else:
+                # Non-first PP rank: forward with intermediate tensors.
+                input_ids, inputs_embeds = None, None
+                assert intermediate_tensors is not None
+            hidden_states = self.model(
+                input_ids=input_ids,
+                positions=positions,
+                inputs_embeds=inputs_embeds,
+                intermediate_tensors=intermediate_tensors,
+            )


         kv_connector_output = self.kv_connector.post_forward(scheduler_output)
-        self.execute_model_state = hidden_states, input_batch, kv_connector_output
+
+        if self.use_pp and not get_pp_group().is_last_rank:
+            # Non-last PP rank: return IntermediateTensors for sending.
+            assert isinstance(hidden_states, IntermediateTensors)
+            hidden_states.kv_connector_output = kv_connector_output
+            self.execute_model_state = (None, input_batch, kv_connector_output)
+            return hidden_states
+
+        assert isinstance(hidden_states, torch.Tensor)
+        # Last rank (or no PP): hidden_states is a tensor for sampling.
+        self.execute_model_state = (hidden_states, input_batch, kv_connector_output)
         return None

     @torch.inference_mode()
     def sample_tokens(
         self, grammar_output: GrammarOutput | None
-    ) -> AsyncOutput | ModelRunnerOutput:
+    ) -> AsyncOutput | ModelRunnerOutput | None:
         assert self.execute_model_state is not None
         hidden_states, input_batch, kv_connector_output = self.execute_model_state
         self.execute_model_state = None  # type: ignore
 
+        # Non-last PP rank: hidden_states is None because this rank produced
+        # IntermediateTensors instead of final hidden states. Receive the
+        # sampled tokens broadcast by the last rank and update local state.
+        if self.use_pp and not get_pp_group().is_last_rank:
+            assert self.pp_handler is not None
+            received = self.pp_handler.maybe_receive_sampled_tokens(
+                input_batch.num_reqs,
+                self.device,
+                max_sample_len=self.num_speculative_steps + 1,
+            )
+            if received is not None:
+                sampled, num_sampled, num_rejected = received
+                self.postprocess(input_batch, sampled, num_sampled, num_rejected)
+            return None
+
+        # Last rank: sample tokens
         sampler_output, num_sampled, num_rejected = self.sample(
             hidden_states, input_batch, grammar_output
         )
 
+        # Broadcast to non-last PP ranks (handles spec decode multi-token).
+        if self.use_pp:
+            assert self.pp_handler is not None
+            self.pp_handler.maybe_broadcast_sampled_tokens(
+                sampler_output, num_sampled, num_rejected
+            )
+
         prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs(
             self.model.compute_logits,
             hidden_states,
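Taken together, these changes split one engine step across pipeline stages: execute_model returns IntermediateTensors on non-last ranks (for the worker to ship downstream) and None on the last rank, while sample_tokens samples on the last rank and receives the broadcast everywhere else. The sketch below is illustrative only and is not part of this PR; pp.recv_intermediate_tensors / pp.send_intermediate_tensors are hypothetical stand-ins for whatever worker-level plumbing actually moves activations between ranks.

# Hedged sketch of the per-rank control flow implied by this diff.
# Assumption: `runner` is the model runner above; the send/recv helpers on
# `pp` are hypothetical stand-ins for the worker's PP communication.
def pp_step(runner, pp, scheduler_output, grammar_output):
    intermediate_tensors = None
    if not pp.is_first_rank:
        # Receive activations produced by the previous stage (hypothetical API).
        intermediate_tensors = pp.recv_intermediate_tensors()

    out = runner.execute_model(scheduler_output, intermediate_tensors)

    if not pp.is_last_rank:
        # Non-last ranks get IntermediateTensors back; forward them downstream.
        pp.send_intermediate_tensors(out)

    # Every rank calls sample_tokens: the last rank samples and broadcasts via
    # PPHandler; earlier ranks block in maybe_receive_sampled_tokens, apply the
    # tokens via postprocess, and return None.
    return runner.sample_tokens(grammar_output)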
119 changes: 119 additions & 0 deletions vllm/v1/worker/gpu/pp_handler.py
@@ -0,0 +1,119 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Pipeline Parallelism handler for V2 Model Runner."""

import torch

from vllm.distributed.parallel_state import get_pp_group
from vllm.v1.worker.gpu.sample.output import SamplerOutput


class PPHandler:
    """Pipeline parallelism handler for Model Runner V2.

    Manages sampled token synchronization between PP ranks.
    Only instantiated when PP is enabled (pp_size > 1).
    """

    def maybe_broadcast_sampled_tokens(
        self,
        sampler_output: SamplerOutput,
        num_sampled: torch.Tensor,
        num_rejected: torch.Tensor,
    ) -> None:
        """Broadcast sampled tokens from the last PP rank to all other ranks.

        No-ops if this is not the last rank.

        Broadcasts sampled_token_ids [num_reqs, max_sample_len], num_sampled
        [num_reqs], and num_rejected [num_reqs] to support both regular decode
        and speculative decoding.

        Args:
            sampler_output: SamplerOutput from sampling.
            num_sampled: Number of accepted tokens per request.
            num_rejected: Number of rejected tokens per request.
        """
        pp = get_pp_group()
        if not pp.is_last_rank:
            return

        torch.distributed.broadcast(
            sampler_output.sampled_token_ids.contiguous(),
            src=pp.last_rank,
            group=pp.device_group,
        )
        # NOTE: num_sampled/num_rejected are only needed
        # for speculative decoding.
        torch.distributed.broadcast(
            num_sampled.contiguous(),
            src=pp.last_rank,
            group=pp.device_group,
        )
        torch.distributed.broadcast(
            num_rejected.contiguous(),
            src=pp.last_rank,
            group=pp.device_group,
        )

    def maybe_receive_sampled_tokens(
        self,
        num_reqs: int,
        device: torch.device,
        max_sample_len: int = 1,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
        """Receive sampled tokens broadcast by the last PP rank.

        Returns None if this is the last rank (which samples, not receives).

        Args:
            num_reqs: Number of requests in the batch.
            device: Device to create tensors on.
            max_sample_len: Maximum number of tokens sampled per request
                (1 for regular decode, >1 for speculative decoding).

        Returns:
            None if called on the last rank.
            Otherwise, a tuple of (sampled_tokens, num_sampled, num_rejected):
            - sampled_tokens: shape [num_reqs, max_sample_len]
            - num_sampled: shape [num_reqs]
            - num_rejected: shape [num_reqs]
        """
        pp = get_pp_group()
        if pp.is_last_rank:
            return None

        sampled_tokens = torch.empty(
            num_reqs, max_sample_len, dtype=torch.int64, device=device
        )
        torch.distributed.broadcast(
            sampled_tokens,
            src=pp.last_rank,
            group=pp.device_group,
        )
        # NOTE: num_sampled/num_rejected are only needed
        # for speculative decoding.
        num_sampled = torch.empty(num_reqs, dtype=torch.int32, device=device)
        torch.distributed.broadcast(
            num_sampled,
            src=pp.last_rank,
            group=pp.device_group,
        )
        num_rejected = torch.empty(num_reqs, dtype=torch.int32, device=device)
        torch.distributed.broadcast(
            num_rejected,
            src=pp.last_rank,
            group=pp.device_group,
        )
        return sampled_tokens, num_sampled, num_rejected


def get_pp_handler(parallel_config) -> PPHandler:
    """Factory function to create PPHandler.

    Must only be called when PP is enabled (pp_size > 1).
    """
    assert parallel_config.pipeline_parallel_size > 1, (
        "PPHandler should not be created when pipeline parallelism is disabled."
    )
    return PPHandler()
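The two methods are matching sides of the same torch.distributed.broadcast collectives, so every rank in the PP group must enter them with identical tensor shapes: max_sample_len on the receiving ranks has to equal sampled_token_ids.shape[1] on the last rank (the model runner passes num_speculative_steps + 1 on the receive side). A minimal pairing sketch, assuming a PP group is already initialized and that num_reqs, device, and the sampler outputs come from the surrounding step:

# Hedged sketch: how the broadcast/receive pair lines up within one step.
handler = get_pp_handler(parallel_config)  # asserts pipeline_parallel_size > 1
pp = get_pp_group()
max_sample_len = num_speculative_steps + 1  # must agree across all ranks

if pp.is_last_rank:
    # Producer side: sampled_token_ids has shape [num_reqs, max_sample_len].
    handler.maybe_broadcast_sampled_tokens(sampler_output, num_sampled, num_rejected)
else:
    # Consumer side: allocates matching buffers and joins the same broadcasts.
    received = handler.maybe_receive_sampled_tokens(
        num_reqs, device, max_sample_len=max_sample_len
    )
    sampled_tokens, num_sampled, num_rejected = received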