Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 66 additions & 18 deletions vllm/v1/worker/gpu/cudagraph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@

from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.distributed.parallel_state import (
get_pp_group,
graph_capture,
is_global_first_rank,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.offloader.base import get_offloader
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
from vllm.v1.worker.gpu.block_table import BlockTables
Expand Down Expand Up @@ -87,7 +92,15 @@ def __init__(
assert self.compilation_config is not None
self.cudagraph_mode = cudagraph_mode
self.decode_query_len = decode_query_len

self.dp_size = vllm_config.parallel_config.data_parallel_size
self.pp_size = vllm_config.parallel_config.pipeline_parallel_size
if self.pp_size > 1:
self.is_first_pp_rank = get_pp_group().is_first_rank
self.is_last_pp_rank = get_pp_group().is_last_rank
else:
self.is_first_pp_rank = True
self.is_last_pp_rank = True

self.graphs: dict[BatchExecutionDescriptor, torch.cuda.CUDAGraph] = {}
self.pool = current_platform.get_global_graph_pool() if cudagraph_mode else None
Expand Down Expand Up @@ -263,15 +276,18 @@ def __init__(
decode_query_len: int,
):
super().__init__(vllm_config, device, cudagraph_mode, decode_query_len)
# Used for FULL CUDA graphs. PW CUDA graphs do not use these.
self.hidden_states: torch.Tensor | None = None
self.aux_hidden_states: list[torch.Tensor] = []
self.use_aux_hidden_state_outputs = False
self.intermediate_tensors: IntermediateTensors | None = None

def capture(
self,
model: nn.Module,
model_state: ModelState,
input_buffers: InputBuffers,
intermediate_tensors: IntermediateTensors | None,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
Expand All @@ -292,6 +308,12 @@ def create_forward_fn(
if self.dp_size > 1
else None
)
if not self.is_first_pp_rank:
assert intermediate_tensors is not None
intermediate_tensors_sliced = intermediate_tensors[:num_tokens]
else:
intermediate_tensors_sliced = None
Comment on lines +311 to +315
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The intermediate_tensors_sliced variable is assigned None when self.is_first_pp_rank is true. However, it is used unconditionally in the model_inputs dictionary on line 345. This could lead to a NameError if the model expects intermediate_tensors to always be present.

                intermediate_tensors_sliced = None
            else:
                intermediate_tensors_sliced = None

            model_inputs = {
                "input_ids": input_buffers.input_ids[:num_tokens],
                "positions": input_buffers.positions[:num_tokens],
                "intermediate_tensors": intermediate_tensors_sliced if intermediate_tensors_sliced is not None else None,
                **model_state.prepare_dummy_inputs(num_reqs, num_tokens),
            }


attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs,
num_tokens,
Expand Down Expand Up @@ -320,36 +342,62 @@ def forward_fn(cg_mode: CUDAGraphMode) -> None:
model_inputs = {
"input_ids": input_buffers.input_ids[:num_tokens],
"positions": input_buffers.positions[:num_tokens],
# TODO: Pass intermediate_tensors for PP CUDA graph
# support (https://github.com/vllm-project/vllm/pull/35162).
"intermediate_tensors": None,
"intermediate_tensors": intermediate_tensors_sliced,
**model_state.prepare_dummy_inputs(num_reqs, num_tokens),
}
model_output = model(**model_inputs)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output

if cg_mode == CUDAGraphMode.PIECEWISE:
# PW CUDA graph internally handles the model outputs.
# No need to keep track of the hidden states.
return None

if self.is_last_pp_rank:
# Last PP rank (common case).
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = []
if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states)
self.hidden_states[:num_tokens] = hidden_states
if (
self.use_aux_hidden_state_outputs
and not self.aux_hidden_states
):
self.aux_hidden_states = [
torch.empty_like(x) for x in aux_hidden_states
]
for i, aux in enumerate(aux_hidden_states):
self.aux_hidden_states[i][:num_tokens] = aux
else:
hidden_states = model_output
aux_hidden_states = []
if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states)
if self.use_aux_hidden_state_outputs and not self.aux_hidden_states:
self.aux_hidden_states = [
torch.empty_like(x) for x in aux_hidden_states
]
self.hidden_states[:num_tokens] = hidden_states
for i, aux in enumerate(aux_hidden_states):
self.aux_hidden_states[i][:num_tokens] = aux
# Non-last PP rank.
intermediate_tensors = model_output
assert isinstance(intermediate_tensors, IntermediateTensors)
if self.intermediate_tensors is None:
self.intermediate_tensors = IntermediateTensors(
{
k: torch.empty_like(v)
for k, v in intermediate_tensors.tensors.items()
}
)
for k, v in intermediate_tensors.tensors.items():
self.intermediate_tensors[k][:num_tokens] = v

return forward_fn

super().capture(create_forward_fn, progress_bar_desc)

def run_fullgraph(
self, desc: BatchExecutionDescriptor
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]] | IntermediateTensors:
"""Replay a captured FULL cudagraph and return hidden states."""
super().run_fullgraph(desc)
if not self.is_last_pp_rank:
assert self.intermediate_tensors is not None
return self.intermediate_tensors[: desc.num_tokens]

assert self.hidden_states is not None
hidden_states = self.hidden_states[: desc.num_tokens]
if not self.use_aux_hidden_state_outputs:
Expand Down
79 changes: 48 additions & 31 deletions vllm/v1/worker/gpu/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
else:
self.is_first_pp_rank = True
self.is_last_pp_rank = True
# Persistent buffer for intermediate tensors (non-first PP ranks).
self.intermediate_tensors: IntermediateTensors | None = None

# Data parallelism.
self.dp_size = self.parallel_config.data_parallel_size
Expand Down Expand Up @@ -301,6 +303,17 @@ def load_model(self, *args, **kwargs) -> None:
if self.is_pooling_model and self.is_last_pp_rank:
self.pooling_runner = PoolingRunner(self.model)

if not self.is_first_pp_rank:
# For non-first PP ranks, create intermediate tensors sized
# for the max capture size so they can be sliced per batch.
# Save as persistent member so runtime can copy received data
# into the same addresses that the CUDA graphs captured.
self.intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=self.max_num_tokens,
dtype=self.model_config.dtype,
device=self.device,
)

def get_model(self) -> nn.Module:
return self.model

Expand Down Expand Up @@ -396,14 +409,11 @@ def _dummy_run(
# Disable any use of KVConnector for dummy runs.
self.kv_connector.set_disabled(True)

# For non-first PP ranks, create dummy intermediate_tensors.
# Get the intermediate tensors for the dummy run.
intermediate_tensors = None
if not self.is_first_pp_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=num_tokens,
dtype=self.model_config.dtype,
device=self.device,
)
assert self.intermediate_tensors is not None
intermediate_tensors = self.intermediate_tensors[:num_tokens]

# Execute the model.
self.execute_model(
Expand Down Expand Up @@ -528,14 +538,6 @@ def capture_model(self) -> int:
)
return 0

# TODO (zhanqiu): support CUDA graph for PP.
if self.use_pp:
logger.warning_once(
"Skipping CUDA graph capture because pipeline parallel is "
"enabled. Pipeline parallel is currently eager-only.",
)
return 0

start_time = time.perf_counter()
gc.collect()
torch.accelerator.empty_cache()
Expand All @@ -546,6 +548,7 @@ def capture_model(self) -> int:
self.model,
self.model_state,
self.input_buffers,
self.intermediate_tensors,
self.block_tables,
self.attn_groups,
self.kv_cache_config,
Expand Down Expand Up @@ -1010,7 +1013,6 @@ def execute_model(
"input_ids": input_batch.input_ids,
"positions": input_batch.positions,
"inputs_embeds": inputs_embeds,
"intermediate_tensors": intermediate_tensors,
# NOTE: Values returned by `prepare_inputs` will override the default
# values above.
**self.model_state.prepare_inputs(input_batch, self.req_states),
Expand All @@ -1019,7 +1021,19 @@ def execute_model(
# Update for non-first PP ranks.
model_inputs["input_ids"] = None
model_inputs["inputs_embeds"] = None

# Prepare the intermediate tensors.
assert intermediate_tensors is not None
assert self.intermediate_tensors is not None
n = input_batch.num_tokens_after_padding
intermediate_tensors = IntermediateTensors(
{
k: v[:n].copy_(intermediate_tensors.tensors[k][:n])
for k, v in self.intermediate_tensors.tensors.items()
},
intermediate_tensors.kv_connector_output,
)
model_inputs["intermediate_tensors"] = intermediate_tensors
Comment on lines 1026 to +1036
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The code copies data from intermediate_tensors to self.intermediate_tensors using copy_. This operation is in-place and modifies self.intermediate_tensors. If the CUDA graph is replayed multiple times with different intermediate_tensors, the initial data in self.intermediate_tensors will be overwritten, potentially leading to incorrect results in subsequent iterations. This is especially problematic if the graph is captured with one set of intermediate tensors and replayed with another.

                    k: v[:n].clone() # Create a copy to avoid modifying the original tensor
                    for k, v in intermediate_tensors.tensors.items()
                },
                intermediate_tensors.kv_connector_output,


# Run model.
if batch_desc.cg_mode == CUDAGraphMode.FULL:
Expand All @@ -1028,11 +1042,6 @@ def execute_model(
# because they are already copied to the CUDA graph input buffers.
self.kv_connector.pre_forward(scheduler_output)
model_output = self.cudagraph_manager.run_fullgraph(batch_desc)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None
else:
# For piecewise and eager mode, just call model().
batch_descriptor = BatchDescriptor(
Expand All @@ -1052,11 +1061,21 @@ def execute_model(
):
self.kv_connector.pre_forward(scheduler_output)
model_output = self.model(**model_inputs)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None

if self.is_last_pp_rank:
if self.use_aux_hidden_state_outputs:
assert isinstance(model_output, tuple)
hidden_states, aux_hidden_states = model_output
else:
assert isinstance(model_output, torch.Tensor)
hidden_states = model_output
aux_hidden_states = None
output_intermediate_tensors = None
else:
assert isinstance(model_output, IntermediateTensors)
hidden_states = None
aux_hidden_states = None
output_intermediate_tensors = model_output

kv_connector_output = self.kv_connector.post_forward(scheduler_output)
self.execute_model_state = ExecuteModelState(
Expand All @@ -1071,11 +1090,9 @@ def execute_model(

if not self.is_last_pp_rank:
# Non-last PP rank: return IntermediateTensors for sending.
assert isinstance(hidden_states, IntermediateTensors)
hidden_states.kv_connector_output = kv_connector_output
return hidden_states
# Last rank (or no PP): hidden_states is a tensor for sampling.
assert isinstance(hidden_states, torch.Tensor)
assert output_intermediate_tensors is not None
output_intermediate_tensors.kv_connector_output = kv_connector_output
return output_intermediate_tensors
return None

@torch.inference_mode()
Expand Down Expand Up @@ -1259,7 +1276,7 @@ class ExecuteModelState(NamedTuple):
input_batch: InputBatch
attn_metadata: dict[str, Any] | None
slot_mappings_by_layer: dict[str, torch.Tensor] | None
hidden_states: torch.Tensor | IntermediateTensors
hidden_states: torch.Tensor | None
aux_hidden_states: list[torch.Tensor] | None
kv_connector_output: KVConnectorOutput | None
num_tokens_across_dp: torch.Tensor | None
Loading