163 changes: 78 additions & 85 deletions tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -1,6 +1,6 @@
import bisect
import contextlib
import weakref
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple

import torch
@@ -16,12 +16,35 @@
from .scheduler import ScheduledRequests

if TYPE_CHECKING:
from .model_engine import PyTorchModelEngine
from ..distributed import MPIDist
from ..mapping import Mapping
from ..speculative import DecodingBaseConfig

# A reserved sentinel value (2**64 - 1) used for dummy request IDs to avoid collisions with real request IDs
CUDA_GRAPH_DUMMY_REQUEST_ID = (1 << 64) - 1


@dataclass
class CUDAGraphRunnerConfig:
"""Configuration for the CUDAGraphRunner, passed from the ModelEngine."""
use_cuda_graph: bool
cuda_graph_padding_enabled: bool
cuda_graph_batch_sizes: list[int]
max_cuda_graph_batch_size: int
max_beam_width: int
max_num_tokens: int
spec_config: Optional["DecodingBaseConfig"]
cuda_graph_mem_pool: Any
use_mrope: bool
original_max_draft_len: int
is_draft_model: bool
enable_attention_dp: bool
batch_size: int
mapping: Optional["Mapping"]
dist: Optional["MPIDist"]
kv_cache_manager_key: Any
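
The config object replaces the weakref back-reference to PyTorchModelEngine that the runner used to hold. The engine-side call site is not part of this diff; a minimal construction sketch, with each field sourced from the engine attributes the removed code below used to read (the exact wiring at the call site is an assumption):

config = CUDAGraphRunnerConfig(
    use_cuda_graph=engine.pytorch_backend_config.use_cuda_graph,
    cuda_graph_padding_enabled=(
        engine.pytorch_backend_config.cuda_graph_padding_enabled),
    cuda_graph_batch_sizes=engine._cuda_graph_batch_sizes,
    max_cuda_graph_batch_size=engine._max_cuda_graph_batch_size,
    max_beam_width=engine.max_beam_width,
    max_num_tokens=engine.max_num_tokens,
    spec_config=engine.spec_config,
    cuda_graph_mem_pool=engine._cuda_graph_mem_pool,
    use_mrope=engine.use_mrope,
    original_max_draft_len=engine.original_max_draft_len,
    is_draft_model=engine.is_draft_model,
    enable_attention_dp=engine.enable_attention_dp,
    batch_size=engine.batch_size,
    mapping=engine.mapping,
    dist=engine.dist,
    kv_cache_manager_key=engine.kv_cache_manager_key,
)
runner = CUDAGraphRunner(config)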


class CUDAGraphRunner:
"""
Manages the lifecycle and execution of CUDA graphs for the model engine.
@@ -32,23 +55,22 @@ class CUDAGraphRunner:
"""
WARMUP_STEPS = 2

def __init__(self, engine: "PyTorchModelEngine"):
self.engine_ref = weakref.ref(engine)
def __init__(self, config: CUDAGraphRunnerConfig):
self.config = config

# High-level configuration
config = engine.pytorch_backend_config
# High-level configuration from the config object
self.enabled = config.use_cuda_graph
self.padding_enabled = config.cuda_graph_padding_enabled
self.supported_batch_sizes = engine._cuda_graph_batch_sizes
self.max_supported_batch_size = engine._max_cuda_graph_batch_size
self.max_beam_width = engine.max_beam_width
self.spec_config = engine.spec_config
self.supported_batch_sizes = config.cuda_graph_batch_sizes
self.max_supported_batch_size = config.max_cuda_graph_batch_size
self.max_beam_width = config.max_beam_width
self.spec_config = config.spec_config

self.graphs: Dict[Tuple[int, int, int], torch.cuda.CUDAGraph] = {}
self.graph_outputs: Dict[Tuple[int, int, int],
Callable[[], Optional[torch.Tensor]]] = {}
self.graph_metadata: Dict[Tuple[int, int, int], Dict[str, Any]] = {}
self.memory_pool = engine._cuda_graph_mem_pool
self.memory_pool = config.cuda_graph_mem_pool
self.padding_dummy_request: Optional["Request"] = None

self.shared_static_tensors: Dict[str, torch.Tensor] = {}
@@ -58,12 +80,11 @@ def __init__(self, engine: "PyTorchModelEngine"):

def _create_shared_static_tensors(self):
"""Allocates static tensors sized for the largest possible batch."""
engine = self._get_engine()

token_per_request = self.max_possible_draft_len + 1
max_draft_len = self.config.original_max_draft_len if self.config.spec_config is not None else 0
token_per_request = max_draft_len + 1
max_total_tokens = (self.max_supported_batch_size *
self.max_beam_width * token_per_request)
max_total_tokens = min(max_total_tokens, engine.max_num_tokens)
max_total_tokens = min(max_total_tokens, self.config.max_num_tokens)

self.shared_static_tensors = {
"input_ids":
@@ -72,7 +93,7 @@ def _create_shared_static_tensors(self):
torch.zeros((1, max_total_tokens), device="cuda",
dtype=torch.int32),
}
if engine.use_mrope:
if self.config.use_mrope:
self.shared_static_tensors["position_ids"] = torch.zeros(
(3, 1, max_total_tokens), device="cuda", dtype=torch.int32)
self.shared_static_tensors["multimodal_params"] = [
@@ -86,55 +107,31 @@ def _create_shared_static_tensors(self):
}) for _ in range(max_total_tokens)
]
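
As a worked sizing example (numbers assumed for illustration): with max_cuda_graph_batch_size=128, max_beam_width=1, and original_max_draft_len=3, token_per_request is 4, so the static buffers cover 128 * 1 * 4 = 512 tokens, unless max_num_tokens is smaller, in which case the min() clamp applies.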

@property
def enable_spec_decode(self):
return self._get_engine().enable_spec_decode

@property
def max_possible_draft_len(self):
engine = self._get_engine()
return (engine.original_max_draft_len if self.enable_spec_decode else 0)

def get_graph_key(
self,
batch_size,
enable_spec_decode: bool,
spec_resource_manager: Optional[BaseResourceManager] = None):
engine = self._get_engine()
if engine.is_draft_model and spec_resource_manager is not None and isinstance(
if self.config.is_draft_model and spec_resource_manager is not None and isinstance(
spec_resource_manager, Eagle3ResourceManager):
draft_len = engine.original_max_draft_len if spec_resource_manager.is_first_draft else 0
draft_len = self.config.original_max_draft_len if spec_resource_manager.is_first_draft else 0
key = (batch_size, draft_len, spec_resource_manager.is_first_draft)
else:
draft_len = self.spec_config.max_draft_len if self.enable_spec_decode else 0
draft_len = self.spec_config.max_draft_len if enable_spec_decode else 0
key = (batch_size, draft_len, False)
return key
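
For quick reference, the key shapes the branches above produce, with batch size 8 as an illustrative value:

# Eagle3 draft model, first draft pass:   (8, original_max_draft_len, True)
# Eagle3 draft model, later passes:       (8, 0, False)
# Target model, spec decode enabled:      (8, spec_config.max_draft_len, False)
# Spec decode disabled:                   (8, 0, False)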

@property
def spec_metadata(self):
return self._get_engine().spec_metadata

@property
def draft_tokens_cuda(self):
return self._get_engine().draft_tokens_cuda

@property
def attn_metadata(self):
return self._get_engine().attn_metadata

def __del__(self):
self.clear()

def _get_engine(self) -> "PyTorchModelEngine":
"""Safely dereferences the weak reference to the engine."""
engine = self.engine_ref()
if engine is None:
raise RuntimeError(
"The parent PyTorchModelEngine has been garbage collected.")
return engine

def maybe_get_cuda_graph(
self,
batch: ScheduledRequests,
iter_counter: int,
enable_spec_decode: bool,
attn_metadata: Any,
spec_metadata: Optional[Any],
draft_tokens_cuda: torch.Tensor,
spec_resource_manager: Optional[BaseResourceManager] = None):
"""
Determines if the current batch can be run with a CUDA graph.
@@ -145,17 +142,14 @@ def maybe_get_cuda_graph(
- The spec_metadata for the graph, if applicable.
- The key for the graph.
"""
engine = self._get_engine()

# Disable CUDA graphs while expert statistics are being collected.
if hasattr(engine, 'iter_counter') and ExpertStatistic.set_iter(
engine.iter_counter):
if ExpertStatistic.set_iter(iter_counter):
return False, None, None, None

can_run_cuda_graph = batch.can_run_cuda_graph
batch_size = batch.batch_size
if self.enabled and engine.enable_attention_dp and engine.mapping.tp_size > 1:
all_can_graph_batch = engine.dist.tp_allgather(
if self.enabled and self.config.enable_attention_dp and self.config.mapping.tp_size > 1:
all_can_graph_batch = self.config.dist.tp_allgather(
[can_run_cuda_graph, batch_size])
is_all_gen_only = all(all_can_graph[0]
for all_can_graph in all_can_graph_batch)
@@ -168,7 +162,8 @@

if not self.enabled or not can_run_cuda_graph:
return False, None, None, None
key = self.get_graph_key(batch_size, spec_resource_manager)
key = self.get_graph_key(batch_size, enable_spec_decode,
spec_resource_manager)

if key in self.graphs:
return True, self.graph_metadata[key][
@@ -178,29 +173,28 @@ def maybe_get_cuda_graph(
return False, None, None, None

num_sequences_in_batch = batch_size * self.max_beam_width
attn_metadata = self.attn_metadata.create_cuda_graph_metadata(
graph_attn_metadata = attn_metadata.create_cuda_graph_metadata(
num_sequences_in_batch, False, key[1], self.cuda_graph_meta_buffers)
assert attn_metadata.is_cuda_graph
assert graph_attn_metadata.is_cuda_graph

if self.enable_spec_decode:
spec_metadata = self.spec_metadata.create_cuda_graph_metadata(
if enable_spec_decode:
graph_spec_metadata = spec_metadata.create_cuda_graph_metadata(
num_sequences_in_batch)
spec_metadata.draft_tokens = self.draft_tokens_cuda
graph_spec_metadata.draft_tokens = draft_tokens_cuda
else:
spec_metadata = None
return True, attn_metadata, spec_metadata, key
graph_spec_metadata = None
return True, graph_attn_metadata, graph_spec_metadata, key

def needs_capture(self, key: Tuple[int, int, int]):
return key not in self.graph_outputs

def capture(self,
key: Tuple[int, int, int],
forward_fn: Callable,
initial_inputs: Dict[str, Any],
enable_spec_decode: bool = False,
postprocess_fn: Optional[Callable] = None):
"""Captures the forward pass for a given batch size."""
engine = self._get_engine()
batch_size = key[0]
# [CUDA graph spec decode padding]
# We pad input IDs/position IDs to the maximum draft length (token per request).
@@ -217,7 +211,7 @@ def capture(self,
self.shared_static_tensors["position_ids"]
[:, :num_tokens_for_capture],
}
if engine.use_mrope:
if self.config.use_mrope:
sliced_static_tensors["position_ids"] = self.shared_static_tensors[
"position_ids"][:, :, :num_tokens_for_capture],
sliced_static_tensors[
@@ -235,12 +229,10 @@ def capture(self,
def _setup_spec_decoding_and_forward(key: Tuple[int, int, int],
forward_fn: Callable,
capture_inputs: Dict[str, Any]):
engine = self._get_engine()
# For the first inference of the draft model, use_spec_decoding must be set to True when capturing the graph for multiple runs.
is_first_draft = key[2]
needs_kv_cache_recompute = True if engine.enable_spec_decode and engine.spec_config.spec_dec_mode.needs_kv_cache_recompute(
needs_kv_cache_recompute = True if enable_spec_decode and self.config.spec_config.spec_dec_mode.needs_kv_cache_recompute(
) else False
if is_first_draft and engine.is_draft_model and needs_kv_cache_recompute:
if is_first_draft and self.config.is_draft_model and needs_kv_cache_recompute:
capture_inputs['attn_metadata'].use_spec_decoding = True
return forward_fn(capture_inputs)

@@ -268,7 +260,6 @@ def _setup_spec_decoding_and_forward(key: Tuple[int, int, int],
def replay(self, key: Tuple[int, int, int],
current_inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
"""Replays a previously captured graph."""
engine = self._get_engine()
stored_meta = self.graph_metadata[key]
assert current_inputs["attn_metadata"] is stored_meta["attn_metadata"]
if stored_meta["spec_metadata"] is not None:
@@ -282,7 +273,7 @@ def replay(self, key: Tuple[int, int, int],
static_tensors["input_ids"][:seqlen].copy_(input_ids)

position_ids = current_inputs["position_ids"]
if engine.use_mrope and current_inputs.get(
if self.config.use_mrope and current_inputs.get(
'multimodal_params') is not None:
static_tensors["position_ids"][:, :, :seqlen].copy_(position_ids)
for i, multimodal_param in enumerate(
@@ -302,16 +293,16 @@ def replay(self, key: Tuple[int, int, int],
return output_ref
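
The replay call itself falls in the collapsed region just above; given the types declared in __init__, the elided steps before the visible return are presumably along these lines (a sketch, not the elided source):

self.graphs[key].replay()  # launch the recorded kernels
output_ref = self.graph_outputs[key]()  # stored as a zero-arg output accessor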

def _get_padded_batch(self, batch: ScheduledRequests,
resource_manager: ResourceManager) -> int:
engine = self._get_engine()
resource_manager: ResourceManager,
runtime_draft_len: int) -> int:
kv_cache_manager = resource_manager.get_resource_manager(
engine.kv_cache_manager_key)
self.config.kv_cache_manager_key)
can_run_cuda_graph = batch.can_run_cuda_graph
batch_size = batch.batch_size
new_batch_size = batch_size

if self.enabled and engine.enable_attention_dp and engine.mapping.tp_size > 1:
graph_batch_size = engine.dist.tp_allgather(
if self.enabled and self.config.enable_attention_dp and self.config.mapping.tp_size > 1:
graph_batch_size = self.config.dist.tp_allgather(
[can_run_cuda_graph, batch_size])
all_can_graph = all(graph_batch[0]
for graph_batch in graph_batch_size)
@@ -329,7 +320,7 @@ def _get_padded_batch(self, batch: ScheduledRequests,
return 0

padding_size = padded_batch_size - batch_size
if padding_size + batch.batch_size > engine.batch_size:
if padding_size + batch.batch_size > self.config.batch_size:
return 0

# No padding if it would create too many concurrent requests.
@@ -344,9 +335,9 @@ def _get_padded_batch(self, batch: ScheduledRequests,
self.padding_dummy_request = kv_cache_manager.add_dummy_requests(
[CUDA_GRAPH_DUMMY_REQUEST_ID],
is_gen=True,
max_num_draft_tokens=engine.runtime_draft_len,
use_mrope=engine.use_mrope,
max_beam_width=engine.max_beam_width)[0]
max_num_draft_tokens=runtime_draft_len,
use_mrope=self.config.use_mrope,
max_beam_width=self.config.max_beam_width)[0]
self.padding_dummy_request.is_cuda_graph_dummy = True
spec_res_mgr = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)
@@ -367,12 +358,14 @@ def _round_up_batch_size(self, batch_size: int) -> int:
return self.supported_batch_sizes[idx]
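
Only the final return of _round_up_batch_size survives the fold. Given the bisect import at the top of the file, the rounding it implies is (a sketch, assuming supported_batch_sizes is kept sorted ascending; the 0 sentinel is an assumption):

idx = bisect.bisect_left(self.supported_batch_sizes, batch_size)
if idx == len(self.supported_batch_sizes):
    return 0  # no supported graph size can hold this batch
return self.supported_batch_sizes[idx]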

@contextlib.contextmanager
def pad_batch(self, scheduled_requests: ScheduledRequests,
resource_manager: ResourceManager):
def pad_batch(self,
scheduled_requests: ScheduledRequests,
resource_manager: ResourceManager,
runtime_draft_len: int = 0):
"""Context manager to pad a batch to a graph-compatible size."""

padding_size = self._get_padded_batch(scheduled_requests,
resource_manager)
resource_manager,
runtime_draft_len)
try:
yield scheduled_requests
finally:
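
The finally body is collapsed by the diff view (it presumably undoes the padding). To orient reviewers, a hypothetical caller of the reworked API, with forward_fn, inputs, and the loop-local variables as placeholders:

with runner.pad_batch(scheduled_requests, resource_manager,
                      runtime_draft_len=runtime_draft_len) as padded:
    can_graph, attn_md, spec_md, key = runner.maybe_get_cuda_graph(
        padded, iter_counter, enable_spec_decode, attn_metadata,
        spec_metadata, draft_tokens_cuda, spec_resource_manager)
    if can_graph:
        if runner.needs_capture(key):
            runner.capture(key, forward_fn, inputs, enable_spec_decode)
        output = runner.replay(key, inputs)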