Skip to content
Merged
69 changes: 54 additions & 15 deletions vllm/v1/worker/gpu/cudagraph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@

from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.distributed.parallel_state import (
get_pp_group,
graph_capture,
is_global_first_rank,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.offloader.base import get_offloader
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
from vllm.v1.worker.gpu.block_table import BlockTables
Expand Down Expand Up @@ -87,7 +92,15 @@ def __init__(
assert self.compilation_config is not None
self.cudagraph_mode = cudagraph_mode
self.decode_query_len = decode_query_len

self.dp_size = vllm_config.parallel_config.data_parallel_size
self.pp_size = vllm_config.parallel_config.pipeline_parallel_size
if self.pp_size > 1:
self.is_first_pp_rank = get_pp_group().is_first_rank
self.is_last_pp_rank = get_pp_group().is_last_rank
else:
self.is_first_pp_rank = True
self.is_last_pp_rank = True

self.graphs: dict[BatchExecutionDescriptor, torch.cuda.CUDAGraph] = {}
self.pool = current_platform.get_global_graph_pool() if cudagraph_mode else None
Expand Down Expand Up @@ -267,12 +280,14 @@ def __init__(
self.hidden_states: torch.Tensor | None = None
self.aux_hidden_states: list[torch.Tensor] = []
self.use_aux_hidden_state_outputs = False
self.intermediate_tensors: IntermediateTensors | None = None

def capture(
self,
model: nn.Module,
model_state: ModelState,
input_buffers: InputBuffers,
intermediate_tensors: IntermediateTensors | None,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
Expand All @@ -293,6 +308,19 @@ def create_forward_fn(
if self.dp_size > 1
else None
)

model_inputs = {
"input_ids": input_buffers.input_ids[:num_tokens],
"positions": input_buffers.positions[:num_tokens],
**model_state.prepare_dummy_inputs(num_reqs, num_tokens),
}
if not self.is_first_pp_rank:
# Update for non-first PP ranks.
model_inputs["input_ids"] = None
model_inputs["inputs_embeds"] = None
assert intermediate_tensors is not None
model_inputs["intermediate_tensors"] = intermediate_tensors[:num_tokens]

attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs,
num_tokens,
Expand All @@ -318,45 +346,56 @@ def forward_fn(cg_mode: CUDAGraphMode) -> None:
slot_mapping=slot_mappings,
batch_descriptor=batch_descriptor,
):
model_inputs = {
"input_ids": input_buffers.input_ids[:num_tokens],
"positions": input_buffers.positions[:num_tokens],
# TODO: Pass intermediate_tensors for PP CUDA graph
# support (https://github.com/vllm-project/vllm/pull/35162).
"intermediate_tensors": None,
**model_state.prepare_dummy_inputs(num_reqs, num_tokens),
}
model_output = model(**model_inputs)

if cg_mode == CUDAGraphMode.PIECEWISE:
# PW CUDA graph internally handles the model outputs.
# No need to keep track of the hidden states.
return None
if cg_mode == CUDAGraphMode.PIECEWISE:
# PW CUDA graph internally handles the model outputs.
# No need to keep track of the hidden states.
return None

if self.is_last_pp_rank:
# Last PP rank (common case).
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = []
if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states)
self.hidden_states[:num_tokens] = hidden_states
if self.use_aux_hidden_state_outputs and not self.aux_hidden_states:
self.aux_hidden_states = [
torch.empty_like(x) for x in aux_hidden_states
]
self.hidden_states[:num_tokens] = hidden_states
for i, aux in enumerate(aux_hidden_states):
self.aux_hidden_states[i][:num_tokens] = aux
else:
# Non-last PP rank.
intermediate_tensors = model_output
assert isinstance(intermediate_tensors, IntermediateTensors)
if self.intermediate_tensors is None:
self.intermediate_tensors = IntermediateTensors(
{
k: torch.empty_like(v)
for k, v in intermediate_tensors.tensors.items()
}
)
for k, v in intermediate_tensors.tensors.items():
self.intermediate_tensors[k][:num_tokens] = v

return forward_fn

super().capture(create_forward_fn, progress_bar_desc)

def run_fullgraph(
self, desc: BatchExecutionDescriptor
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]] | IntermediateTensors:
"""Replay a captured FULL cudagraph and return hidden states."""
super().run_fullgraph(desc)
if not self.is_last_pp_rank:
assert self.intermediate_tensors is not None
return self.intermediate_tensors[: desc.num_tokens]

assert self.hidden_states is not None
hidden_states = self.hidden_states[: desc.num_tokens]
if not self.use_aux_hidden_state_outputs:
Expand Down
79 changes: 48 additions & 31 deletions vllm/v1/worker/gpu/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
else:
self.is_first_pp_rank = True
self.is_last_pp_rank = True
# Persistent buffer for intermediate tensors (non-first PP ranks).
self.intermediate_tensors: IntermediateTensors | None = None

# Data parallelism.
self.dp_size = self.parallel_config.data_parallel_size
Expand Down Expand Up @@ -301,6 +303,17 @@ def load_model(self, *args, **kwargs) -> None:
if self.is_pooling_model and self.is_last_pp_rank:
self.pooling_runner = PoolingRunner(self.model)

if not self.is_first_pp_rank:
# For non-first PP ranks, create intermediate tensors sized
# for the max capture size so they can be sliced per batch.
# Save as persistent member so runtime can copy received data
# into the same addresses that the CUDA graphs captured.
self.intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=self.max_num_tokens,
dtype=self.model_config.dtype,
device=self.device,
)

def get_model(self) -> nn.Module:
return self.model

Expand Down Expand Up @@ -396,14 +409,11 @@ def _dummy_run(
# Disable any use of KVConnector for dummy runs.
self.kv_connector.set_disabled(True)

# For non-first PP ranks, create dummy intermediate_tensors.
# Get the intermediate tensors for the dummy run.
intermediate_tensors = None
if not self.is_first_pp_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=num_tokens,
dtype=self.model_config.dtype,
device=self.device,
)
assert self.intermediate_tensors is not None
intermediate_tensors = self.intermediate_tensors[:num_tokens]

# Execute the model.
self.execute_model(
Expand Down Expand Up @@ -528,14 +538,6 @@ def capture_model(self) -> int:
)
return 0

# TODO (zhanqiu): support CUDA graph for PP.
if self.use_pp:
logger.warning_once(
"Skipping CUDA graph capture because pipeline parallel is "
"enabled. Pipeline parallel is currently eager-only.",
)
return 0

start_time = time.perf_counter()
gc.collect()
torch.accelerator.empty_cache()
Expand All @@ -546,6 +548,7 @@ def capture_model(self) -> int:
self.model,
self.model_state,
self.input_buffers,
self.intermediate_tensors,
self.block_tables,
self.attn_groups,
self.kv_cache_config,
Expand Down Expand Up @@ -1010,7 +1013,6 @@ def execute_model(
"input_ids": input_batch.input_ids,
"positions": input_batch.positions,
"inputs_embeds": inputs_embeds,
"intermediate_tensors": intermediate_tensors,
# NOTE: Values returned by `prepare_inputs` will override the default
# values above.
**self.model_state.prepare_inputs(input_batch, self.req_states),
Expand All @@ -1019,7 +1021,19 @@ def execute_model(
# Update for non-first PP ranks.
model_inputs["input_ids"] = None
model_inputs["inputs_embeds"] = None

# Prepare the intermediate tensors.
assert intermediate_tensors is not None
assert self.intermediate_tensors is not None
n = input_batch.num_tokens_after_padding
intermediate_tensors = IntermediateTensors(
{
k: v[:n].copy_(intermediate_tensors.tensors[k][:n])
for k, v in self.intermediate_tensors.tensors.items()
},
intermediate_tensors.kv_connector_output,
)
model_inputs["intermediate_tensors"] = intermediate_tensors

# Run model.
if batch_desc.cg_mode == CUDAGraphMode.FULL:
Expand All @@ -1028,11 +1042,6 @@ def execute_model(
# because they are already copied to the CUDA graph input buffers.
self.kv_connector.pre_forward(scheduler_output)
model_output = self.cudagraph_manager.run_fullgraph(batch_desc)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None
else:
# For piecewise and eager mode, just call model().
batch_descriptor = BatchDescriptor(
Expand All @@ -1052,11 +1061,21 @@ def execute_model(
):
self.kv_connector.pre_forward(scheduler_output)
model_output = self.model(**model_inputs)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None

if self.is_last_pp_rank:
if self.use_aux_hidden_state_outputs:
assert isinstance(model_output, tuple)
hidden_states, aux_hidden_states = model_output
else:
assert isinstance(model_output, torch.Tensor)
hidden_states = model_output
aux_hidden_states = None
output_intermediate_tensors = None
else:
assert isinstance(model_output, IntermediateTensors)
hidden_states = None
aux_hidden_states = None
output_intermediate_tensors = model_output

kv_connector_output = self.kv_connector.post_forward(scheduler_output)
self.execute_model_state = ExecuteModelState(
Expand All @@ -1071,11 +1090,9 @@ def execute_model(

if not self.is_last_pp_rank:
# Non-last PP rank: return IntermediateTensors for sending.
assert isinstance(hidden_states, IntermediateTensors)
hidden_states.kv_connector_output = kv_connector_output
return hidden_states
# Last rank (or no PP): hidden_states is a tensor for sampling.
assert isinstance(hidden_states, torch.Tensor)
assert output_intermediate_tensors is not None
output_intermediate_tensors.kv_connector_output = kv_connector_output
return output_intermediate_tensors
return None

@torch.inference_mode()
Expand Down Expand Up @@ -1259,7 +1276,7 @@ class ExecuteModelState(NamedTuple):
input_batch: InputBatch
attn_metadata: dict[str, Any] | None
slot_mappings_by_layer: dict[str, torch.Tensor] | None
hidden_states: torch.Tensor | IntermediateTensors
hidden_states: torch.Tensor | None
aux_hidden_states: list[torch.Tensor] | None
kv_connector_output: KVConnectorOutput | None
num_tokens_across_dp: torch.Tensor | None
Loading