65 changes: 43 additions & 22 deletions tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -40,14 +40,42 @@ def __init__(self, engine: "PyTorchModelEngine"):
self.max_beam_width = engine.max_beam_width
self.spec_config = engine.spec_config

self.max_possible_draft_len = (self.spec_config.max_draft_len
if self.enable_spec_decode else 0)

self.graphs: Dict[Tuple[int, int], torch.cuda.CUDAGraph] = {}
self.static_inputs: Dict[Tuple[int, int], Dict[str, torch.Tensor]] = {}
self.graph_outputs: Dict[Tuple[int, int],
Callable[[], Optional[torch.Tensor]]] = {}
self.graph_metadata: Dict[Tuple[int, int], Dict[str, Any]] = {}
self.memory_pool = engine._cuda_graph_mem_pool
self.padding_dummy_request: Optional["Request"] = None

self.shared_static_tensors: Dict[str, torch.Tensor] = {}
if self.enabled:
self._create_shared_static_tensors()

def _create_shared_static_tensors(self):
"""Allocates static tensors sized for the largest possible batch."""
engine = self._get_engine()

token_per_request = self.max_possible_draft_len + 1
max_total_tokens = (self.max_supported_batch_size *
self.max_beam_width * token_per_request)
max_total_tokens = min(max_total_tokens, engine.max_num_tokens)

self.shared_static_tensors = {
"input_ids":
torch.ones((max_total_tokens, ), device="cuda", dtype=torch.int32),
"position_ids":
torch.zeros((1, max_total_tokens), device="cuda",
dtype=torch.int32),
}
if engine.use_mrope:
self.shared_static_tensors["mrope_position_deltas"] = torch.zeros(
(self.max_supported_batch_size, 1),
device="cuda",
dtype=torch.int32)

@property
def enable_spec_decode(self):
return self._get_engine().is_spec_decode
@@ -139,38 +167,32 @@ def needs_capture(self, batch_size: int):
def capture(self, batch_size: int, forward_fn: Callable,
initial_inputs: Dict[str, Any]):
"""Captures the forward pass for a given batch size."""
engine = self._get_engine()
key = (batch_size, self.draft_len)
spec_metadata = initial_inputs.get("spec_metadata", None)
# [CUDA graph spec decode padding]
# We pad input IDs/position IDs to the maximum draft length (token per request).
# We're forced to do this because we cannot reallocate inputs over many graph runs.
token_per_request = spec_metadata.max_draft_len + 1 if spec_metadata is not None else 1
token_per_request = self.max_possible_draft_len + 1
num_tokens_for_capture = (batch_size * self.max_beam_width *
token_per_request)

static_tensors = {
sliced_static_tensors = {
"input_ids":
torch.ones((batch_size * self.max_beam_width * token_per_request, ),
device="cuda",
dtype=torch.int32),
self.shared_static_tensors["input_ids"][:num_tokens_for_capture],
"position_ids":
torch.zeros((
1,
batch_size * self.max_beam_width * token_per_request,
),
device="cuda",
dtype=torch.int32),
self.shared_static_tensors["position_ids"]
[:, :num_tokens_for_capture],
}
if engine.use_mrope:
static_tensors["mrope_position_deltas"] = torch.zeros(
(batch_size, 1), device="cuda", dtype=torch.int32)
self.static_inputs[key] = static_tensors
if "mrope_position_deltas" in self.shared_static_tensors:
sliced_static_tensors["mrope_position_deltas"] = \
self.shared_static_tensors["mrope_position_deltas"][:batch_size]

# Use the sliced tensors for capture
capture_inputs = initial_inputs.copy()
capture_inputs.update(static_tensors)
capture_inputs.update(sliced_static_tensors)

self.graph_metadata[key] = {
"attn_metadata": initial_inputs["attn_metadata"],
"spec_metadata": spec_metadata,
"spec_metadata": initial_inputs.get("spec_metadata", None),
}

# We have to do warm up runs to initialize PyTorch's
@@ -198,7 +220,7 @@ def replay(self, batch_size: int,
assert current_inputs.get(
"spec_metadata") is stored_meta["spec_metadata"]

static_tensors = self.static_inputs[key]
static_tensors = self.shared_static_tensors

input_ids = current_inputs["input_ids"]
seqlen = input_ids.shape[0]
@@ -301,7 +323,6 @@ def clear(self):
for graph in self.graphs.values():
graph.reset()
self.graphs.clear()
self.static_inputs.clear()
self.graph_outputs.clear()
self.graph_metadata.clear()
self.padding_dummy_request = None
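In short, this change replaces the per-(batch_size, draft_len) static input tensors with a single set of shared buffers sized for the largest possible batch; each graph is then captured against a prefix slice of those buffers, and replay copies the live inputs into the front of the same storage. Below is a minimal, self-contained sketch of that sizing and slicing logic. It mirrors the names in the diff (max_supported_batch_size, max_beam_width, max_possible_draft_len, engine.max_num_tokens) but is a hypothetical stand-in, not the runner's actual API, and it runs on CPU for simplicity.

import torch

class SharedStaticBuffers:
    def __init__(self, max_batch_size: int, max_beam_width: int,
                 max_draft_len: int, max_num_tokens: int):
        tokens_per_request = max_draft_len + 1
        # One allocation sized for the largest possible batch, capped by the
        # engine's max_num_tokens, shared by every captured graph.
        max_total_tokens = min(
            max_batch_size * max_beam_width * tokens_per_request,
            max_num_tokens)
        self.input_ids = torch.ones(max_total_tokens, dtype=torch.int32)
        self.position_ids = torch.zeros(1, max_total_tokens, dtype=torch.int32)

    def slice_for(self, batch_size: int, max_beam_width: int,
                  max_draft_len: int) -> dict:
        # Each batch size captures its graph against a prefix view of the
        # shared buffers, so no per-graph tensors are allocated.
        num_tokens = batch_size * max_beam_width * (max_draft_len + 1)
        return {
            "input_ids": self.input_ids[:num_tokens],
            "position_ids": self.position_ids[:, :num_tokens],
        }

# Example: graphs for batch sizes 1 and 8 share the same underlying storage.
buffers = SharedStaticBuffers(max_batch_size=8, max_beam_width=1,
                              max_draft_len=3, max_num_tokens=8192)
small = buffers.slice_for(1, 1, 3)   # 4 tokens
large = buffers.slice_for(8, 1, 3)   # 32 tokens
assert small["input_ids"].data_ptr() == large["input_ids"].data_ptr()

Because every graph shares the same storage, the old per-key self.static_inputs dictionary and its clear() call become unnecessary, which is why they are deleted above.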
1 change: 0 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -426,7 +426,6 @@ def __init__(
# the model engine.
self.attn_metadata = None
self.iter_states = {}
self._cuda_graphs = {}
self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool_handle if self._torch_compile_enabled else None

self._cuda_graph_padding_enabled = pytorch_backend_config.cuda_graph_padding_enabled
1 change: 1 addition & 0 deletions tests/unittest/_torch/helpers.py
@@ -186,6 +186,7 @@ def create_mock_engine(batch_size: int):
_cuda_graph_batch_sizes=[batch_size],
_max_cuda_graph_batch_size=batch_size,
max_beam_width=1,
max_num_tokens=8192,
is_spec_decode=False,
spec_config=None,
_cuda_graph_mem_pool=None,
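The one-line addition above follows from the runner change: _create_shared_static_tensors now reads engine.max_num_tokens to cap the shared allocation, so the mock engine used in unit tests must expose that attribute. A hedged sketch of a minimal stand-in engine with only the fields visible in this diff follows; the use of SimpleNamespace and the use_mrope default are assumptions, not the actual helper implementation.

from types import SimpleNamespace

def make_minimal_mock_engine(batch_size: int) -> SimpleNamespace:
    # Only the attributes the CUDA graph runner reads in this PR; the real
    # create_mock_engine in tests/unittest/_torch/helpers.py sets more fields.
    return SimpleNamespace(
        _cuda_graph_batch_sizes=[batch_size],
        _max_cuda_graph_batch_size=batch_size,
        max_beam_width=1,
        max_num_tokens=8192,  # newly required: caps the shared static tensor size
        is_spec_decode=False,
        spec_config=None,
        _cuda_graph_mem_pool=None,
        use_mrope=False,  # assumed default, not shown in this hunk
    )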