From 33566083324a27071dd73f27d8b0fd97bf81a9af Mon Sep 17 00:00:00 2001
From: junq <22017000+QiJune@users.noreply.github.com>
Date: Tue, 26 Aug 2025 09:17:44 +0800
Subject: [PATCH 1/3] share input_ids buffers among different cuda graphs

Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
---
 .../_torch/pyexecutor/cuda_graph_runner.py | 64 ++++++++++++-------
 .../_torch/pyexecutor/model_engine.py      |  1 -
 2 files changed, 42 insertions(+), 23 deletions(-)

diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
index 0007b99ebd2..3962ec66c11 100644
--- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
+++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -40,14 +40,41 @@ def __init__(self, engine: "PyTorchModelEngine"):
         self.max_beam_width = engine.max_beam_width
         self.spec_config = engine.spec_config
 
+        self.max_possible_draft_len = (self.spec_config.max_draft_len
+                                       if self.enable_spec_decode else 0)
+
         self.graphs: Dict[Tuple[int, int], torch.cuda.CUDAGraph] = {}
-        self.static_inputs: Dict[Tuple[int, int], Dict[str, torch.Tensor]] = {}
         self.graph_outputs: Dict[Tuple[int, int],
                                  Callable[[], Optional[torch.Tensor]]] = {}
         self.graph_metadata: Dict[Tuple[int, int], Dict[str, Any]] = {}
         self.memory_pool = engine._cuda_graph_mem_pool
         self.padding_dummy_request: Optional["Request"] = None
 
+        self.shared_static_tensors: Dict[str, torch.Tensor] = {}
+        if self.enabled:
+            self._create_shared_static_tensors()
+
+    def _create_shared_static_tensors(self):
+        """Allocates static tensors sized for the largest possible batch."""
+        engine = self._get_engine()
+
+        token_per_request = self.max_possible_draft_len + 1
+        max_total_tokens = (self.max_supported_batch_size *
+                            self.max_beam_width * token_per_request)
+
+        self.shared_static_tensors = {
+            "input_ids":
+            torch.ones((max_total_tokens, ), device="cuda", dtype=torch.int32),
+            "position_ids":
+            torch.zeros((1, max_total_tokens), device="cuda",
+                        dtype=torch.int32),
+        }
+        if engine.use_mrope:
+            self.shared_static_tensors["mrope_position_deltas"] = torch.zeros(
+                (self.max_supported_batch_size, 1),
+                device="cuda",
+                dtype=torch.int32)
+
     @property
     def enable_spec_decode(self):
         return self._get_engine().is_spec_decode
@@ -139,38 +166,32 @@ def needs_capture(self, batch_size: int):
     def capture(self, batch_size: int, forward_fn: Callable,
                 initial_inputs: Dict[str, Any]):
         """Captures the forward pass for a given batch size."""
-        engine = self._get_engine()
         key = (batch_size, self.draft_len)
-        spec_metadata = initial_inputs.get("spec_metadata", None)
         # [CUDA graph spec decode padding]
         # We pad input IDs/position IDs to the maximum draft length (token per request).
         # We're forced to do this because we cannot reallocate inputs over many graph runs.
-        token_per_request = spec_metadata.max_draft_len + 1 if spec_metadata is not None else 1
+        token_per_request = self.max_possible_draft_len + 1
+        num_tokens_for_capture = (batch_size * self.max_beam_width *
+                                  token_per_request)
 
-        static_tensors = {
+        sliced_static_tensors = {
             "input_ids":
-            torch.ones((batch_size * self.max_beam_width * token_per_request, ),
-                       device="cuda",
-                       dtype=torch.int32),
+            self.shared_static_tensors["input_ids"][:num_tokens_for_capture],
             "position_ids":
-            torch.zeros((
-                1,
-                batch_size * self.max_beam_width * token_per_request,
-            ),
-                        device="cuda",
-                        dtype=torch.int32),
+            self.shared_static_tensors["position_ids"]
+            [:, :num_tokens_for_capture],
         }
-        if engine.use_mrope:
-            static_tensors["mrope_position_deltas"] = torch.zeros(
-                (batch_size, 1), device="cuda", dtype=torch.int32)
-        self.static_inputs[key] = static_tensors
+        if "mrope_position_deltas" in self.shared_static_tensors:
+            sliced_static_tensors["mrope_position_deltas"] = \
+                self.shared_static_tensors["mrope_position_deltas"][:batch_size]
 
+        # Use the sliced tensors for capture
         capture_inputs = initial_inputs.copy()
-        capture_inputs.update(static_tensors)
+        capture_inputs.update(sliced_static_tensors)
 
         self.graph_metadata[key] = {
             "attn_metadata": initial_inputs["attn_metadata"],
-            "spec_metadata": spec_metadata,
+            "spec_metadata": initial_inputs.get("spec_metadata", None),
         }
 
         # We have to do warm up runs to initialize PyTorch's
@@ -198,7 +219,7 @@ def replay(self, batch_size: int,
         assert current_inputs.get(
             "spec_metadata") is stored_meta["spec_metadata"]
 
-        static_tensors = self.static_inputs[key]
+        static_tensors = self.shared_static_tensors
 
         input_ids = current_inputs["input_ids"]
         seqlen = input_ids.shape[0]
@@ -301,7 +322,6 @@ def clear(self):
         for graph in self.graphs.values():
             graph.reset()
         self.graphs.clear()
-        self.static_inputs.clear()
         self.graph_outputs.clear()
         self.graph_metadata.clear()
         del self.memory_pool
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 1489fbe3593..4330ac5e8ba 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -419,7 +419,6 @@ def __init__(
         # the model engine.
         self.attn_metadata = None
         self.iter_states = {}
-        self._cuda_graphs = {}
         self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool_handle if self._torch_compile_enabled else None
 
         self._cuda_graph_padding_enabled = pytorch_backend_config.cuda_graph_padding_enabled

From 93e8a7bf87b0ac53d82e4c17d76875d318ff8bcb Mon Sep 17 00:00:00 2001
From: junq <22017000+QiJune@users.noreply.github.com>
Date: Wed, 27 Aug 2025 18:19:49 +0800
Subject: [PATCH 2/3] fix comments

Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
---
 tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
index 3962ec66c11..69ecca8dda7 100644
--- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
+++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -61,6 +61,7 @@ def _create_shared_static_tensors(self):
         token_per_request = self.max_possible_draft_len + 1
         max_total_tokens = (self.max_supported_batch_size *
                             self.max_beam_width * token_per_request)
+        max_total_tokens = min(max_total_tokens, engine.max_num_tokens)
 
         self.shared_static_tensors = {
             "input_ids":

From 2c7623e41495a34d14276f213dc9bab7cced3d9b Mon Sep 17 00:00:00 2001
From: junq <22017000+QiJune@users.noreply.github.com>
Date: Thu, 28 Aug 2025 15:25:58 +0800
Subject: [PATCH 3/3] fix ci

Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
---
 tests/unittest/_torch/helpers.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/unittest/_torch/helpers.py b/tests/unittest/_torch/helpers.py
index 86580f9b94a..a98d7d6cd7d 100644
--- a/tests/unittest/_torch/helpers.py
+++ b/tests/unittest/_torch/helpers.py
@@ -186,6 +186,7 @@ def create_mock_engine(batch_size: int):
         _cuda_graph_batch_sizes=[batch_size],
         _max_cuda_graph_batch_size=batch_size,
         max_beam_width=1,
+        max_num_tokens=8192,
         is_spec_decode=False,
         spec_config=None,
         _cuda_graph_mem_pool=None,
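
For context, a minimal, self-contained sketch of the buffer-sharing pattern the series applies: allocate the static input buffer once at the maximum size, capture each CUDA graph on a slice of that single allocation, and at replay time copy the fresh inputs into the same slice before replaying. This is an illustration only, not the TensorRT-LLM implementation; the class and function names (SharedInputGraphRunner, forward_fn) are hypothetical, and it assumes a CUDA device plus a forward function that reads solely from the static buffer passed to it.

import torch
from typing import Callable, Dict


class SharedInputGraphRunner:
    """Sketch: every captured CUDA graph reads from one shared static input buffer."""

    def __init__(self, max_total_tokens: int):
        # Single allocation sized for the largest batch; smaller batches use slices of it.
        self.input_ids = torch.zeros(max_total_tokens, device="cuda", dtype=torch.int32)
        self.graphs: Dict[int, torch.cuda.CUDAGraph] = {}
        self.outputs: Dict[int, torch.Tensor] = {}

    def capture(self, num_tokens: int,
                forward_fn: Callable[[torch.Tensor], torch.Tensor]) -> None:
        static_slice = self.input_ids[:num_tokens]  # a view, no new allocation
        forward_fn(static_slice)  # warm-up run before capture
        torch.cuda.synchronize()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            self.outputs[num_tokens] = forward_fn(static_slice)
        self.graphs[num_tokens] = graph

    def replay(self, input_ids: torch.Tensor) -> torch.Tensor:
        num_tokens = input_ids.shape[0]
        # Refresh the shared buffer in place, then replay the graph captured for this size.
        self.input_ids[:num_tokens].copy_(input_ids)
        self.graphs[num_tokens].replay()
        return self.outputs[num_tokens]

Compared with keeping a separate static tensor set per (batch_size, draft_len) key, sharing one allocation keeps device memory for graph inputs constant as more graph shapes are captured, which is the saving the first commit targets.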