import torch
import torch.nn as nn
from torch.cuda import CUDAGraph
- from torch.utils._pytree import TreeSpec, tree_flatten
+ from torch.fx._pytree import tree_flatten_spec
+ from torch.utils._pytree import PyTree, TreeSpec, tree_flatten

from tensorrt_llm._torch.autotuner import autotune

from ...utils.cuda_graph import CudaGraphWarmUpPhase
from ...utils.logger import ad_logger
- from ..compiler import BackendCompiler, BackendRegistry, _flatten_args
+ from ..compiler import CompileBackendRegistry, CompilerBackend
+
+
+ def _args_kwargs_flatten_spec(in_spec: TreeSpec, *args, **kwargs) -> List[Any]:
+     """Flatten inputs according to provided in_spec."""
+     all_args: PyTree = (args, kwargs)
+     return tree_flatten_spec(all_args, in_spec)
+
+
+ def _args_kwargs_flatten(*args, **kwargs) -> Tuple[List[Any], TreeSpec]:
+     """Flatten inputs and return flattened inputs together with the TreeSpec."""
+     all_args: PyTree = (args, kwargs)
+     return tree_flatten(all_args)


class CapturedGraph(nn.Module):
    def __init__(
        self,
        model: nn.Module,
-         in_spec: TreeSpec,
-         out_spec: TreeSpec,
        cuda_graph_batch_sizes: List[int],
-         num_batched_inputs: Optional[int] = 1,  # number of batched, dynamic inputs...
+         num_batched_inputs: int,  # number of batched, dynamic inputs...
    ):
        super().__init__()
-         self._in_spec = in_spec
-         self._out_spec = out_spec
        self.model = model
        self.cuda_graph_max_batch_size = max(cuda_graph_batch_sizes)
        ad_logger.info(f"Setting {self.cuda_graph_max_batch_size=}")
        self.num_batched_inputs = num_batched_inputs if num_batched_inputs is not None else 1
-         self.graphs: Dict[Tuple[int, ...], CUDAGraph] = {}
+         self.cudagraphs: Dict[Tuple[int, ...], CUDAGraph] = {}
        self._input_buffers: List[torch.Tensor] = [
            torch.empty(0, 1) for _ in range(self.num_batched_inputs)
        ]
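The new module-level helpers split flattening into a spec-recording first pass (`_args_kwargs_flatten`) and a spec-reusing hot path (`_args_kwargs_flatten_spec`). A minimal standalone sketch of that round trip; the tensor values and shapes are illustrative:

import torch
from torch.fx._pytree import tree_flatten_spec
from torch.utils._pytree import tree_flatten

# First call: flatten (args, kwargs) once and record the structure.
args, kwargs = (torch.ones(4, 8),), {"scale": 2.0}
flat, in_spec = tree_flatten((args, kwargs))

# Later calls with the same structure reuse the recorded spec,
# skipping spec construction on the hot path.
flat_again = tree_flatten_spec(((torch.zeros(2, 8),), {"scale": 3.0}), in_spec)
assert len(flat) == len(flat_again) == 2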
@@ -39,6 +48,10 @@ def __init__(
        self.cuda_graph_batch_sizes = sorted(cuda_graph_batch_sizes, reverse=True)
        self._cuda_graph_mem_pool = None

+         # store the in_spec and out_spec during graph capture
+         self._in_spec = None
+         self._out_spec = None
+
    def _get_hash(self, flat_args: List[Any]) -> Tuple[int, ...]:
        return tuple(hash(a) for a in flat_args)

@@ -67,8 +80,7 @@ def _capture_one_graph(self, *args, **kwargs) -> torch.cuda.CUDAGraph:
        # compute output
        out = self.model(*args, **kwargs)
        # write out into output buffer up to out batch size
-         out_flat, out_spec = tree_flatten(out)
-         assert out_spec == self._out_spec, "Output spec mismatch."
+         out_flat = tree_flatten_spec(out, self._out_spec)
        for o_buffer, o in zip(self._out_buffer_flat, out_flat):
            o_buffer[: o.shape[0]] = o
        torch.cuda.synchronize()
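The rest of `_capture_one_graph` falls outside this hunk; it presumably follows the standard `torch.cuda.graph` capture pattern and shares one memory pool across captures via `self._cuda_graph_mem_pool`. A hedged sketch of that pattern, not the actual elided code:

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=self._cuda_graph_mem_pool):
    out = self.model(*args, **kwargs)  # kernels are recorded here, not executed
self._cuda_graph_mem_pool = self._cuda_graph_mem_pool or graph.pool()
return graph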
@@ -77,8 +89,11 @@ def _capture_one_graph(self, *args, **kwargs) -> torch.cuda.CUDAGraph:

    def capture_graph(self, *args, **kwargs):
        """Capture and pre-fetch the graph for variable batch size."""
-         # flatten args, kwargs
-         all_args_flat = _flatten_args(self._in_spec, *args, **kwargs)
+         # check this is the first time we capture the graph
+         assert not self.cudagraphs, "Graphs already captured."
+
+         # flatten args, kwargs for the first time and record in_spec
+         all_args_flat, self._in_spec = _args_kwargs_flatten(*args, **kwargs)

        # extract the batched input tensors
        args_batched = all_args_flat[: self.num_batched_inputs]
@@ -96,10 +111,8 @@ def capture_graph(self, *args, **kwargs):
            f"than the max_batch_size? It will fall back to non-CUDA graph forward pass for "
            f"batch sizes exceeding the max_batch_size."
        )
-         msg_ndim = "Expecting at least a 2D for batched input tensors."
        if any(self.cuda_graph_max_batch_size < input.shape[0] for input in args_batched):
            ad_logger.info(msg_bs)
-         assert all(input.ndim > 1 for input in args_batched), msg_ndim

        # repeat the batched input tensors to the cuda_graph_max_batch_size
        self._input_buffers = [
@@ -111,11 +124,11 @@ def capture_graph(self, *args, **kwargs):
        args, kwargs = self._in_spec.unflatten(self._input_buffers + args_static)

        # capture output once with cuda_graph_max_batch_size to capture output buffers
+         # store the out_spec at this point
        with CudaGraphWarmUpPhase():
            ad_logger.info(f"Warm up with {self.cuda_graph_max_batch_size=} before graph capture")
            out = self.model(*args, **kwargs)
-         self._out_buffer_flat, out_spec = tree_flatten(out)
-         assert out_spec == self._out_spec, "Output spec mismatch."
+         self._out_buffer_flat, self._out_spec = tree_flatten(out)

        # capture graph now for a range of batch sizes
        for bs in self.cuda_graph_batch_sizes:
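The body of this capture loop is elided between hunks. Presumably each batch size takes a bs-sized view of the shared input buffers, unflattens it with the recorded `in_spec`, and stores the captured graph under its concatenated-shape key. A hypothetical sketch; the exact key format is an assumption inferred from `forward` below:

for bs in self.cuda_graph_batch_sizes:
    # Hypothetical loop body: capture on a bs-sized slice of the shared buffers.
    inputs_truncated = [buf[:bs] for buf in self._input_buffers]
    args_bs, kwargs_bs = self._in_spec.unflatten(inputs_truncated + args_static)
    combined_shape = sum((tuple(t.shape) for t in inputs_truncated), start=())
    self.cudagraphs[combined_shape] = self._capture_one_graph(*args_bs, **kwargs_bs)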
@@ -132,7 +145,7 @@ def capture_graph(self, *args, **kwargs):
    def forward(self, *args, **kwargs) -> Any:
        """Run the compiled graph."""
        # flatten args, kwargs
-         all_args_flat = _flatten_args(self._in_spec, *args, **kwargs)
+         all_args_flat = _args_kwargs_flatten_spec(self._in_spec, *args, **kwargs)

        # extract the batched input tensors
        args_batched = all_args_flat[: self.num_batched_inputs]
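The computation of `rounded_shapes` is elided between these hunks; presumably each batched input's shape is rounded up to the smallest captured batch size that fits, falling back to the true shape when none fits. A hypothetical helper to illustrate the idea (the name `_round_up_shape` is not from the source):

# Hypothetical: round a batched tensor's shape up to the smallest captured size that fits.
def _round_up_shape(t: torch.Tensor, sizes: List[int]) -> Tuple[int, ...]:
    fitting = [s for s in sizes if s >= t.shape[0]]
    bs = min(fitting) if fitting else t.shape[0]
    return (bs, *t.shape[1:])

rounded_shapes = [_round_up_shape(t, self.cuda_graph_batch_sizes) for t in args_batched]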
@@ -150,30 +163,44 @@ def forward(self, *args, **kwargs) -> Any:
        combined_shape = sum(rounded_shapes, start=())

        # regular forward for non-matching shapes
-         if combined_shape not in self.graphs:
+         if combined_shape not in self.cudagraphs:
            return self.model(*args, **kwargs)

        # copy inputs to input buffers
        for i, input_tensor in enumerate(args_batched):
            self._input_buffers[i][: input_tensor.shape[0]].copy_(input_tensor, non_blocking=True)

        # run forward pass via graph
-         self.graphs[combined_shape].replay()
+         self.cudagraphs[combined_shape].replay()

        # retrieve output from buffer, cut to batch size, and unflatten
        bs = args_batched[0].shape[0]
        out_flat = [o_b[:bs].detach().clone() for o_b in self._out_buffer_flat]
        return self._out_spec.unflatten(out_flat)


- @BackendRegistry.register("torch-cudagraph")
- class TorchCudagraphCompiler(BackendCompiler):
+ @CompileBackendRegistry.register("torch-cudagraph")
+ class TorchCudagraphCompiler(CompilerBackend):
    """Compiler that uses only CUDA graphs."""

-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-         requested = self.compiler_kwargs.get("cuda_graph_batch_sizes")
-         if not requested:
+     def __init__(
+         self,
+         *args_for_init,
+         cuda_graph_batch_sizes: Optional[List[int]] = None,
+         num_batched_inputs: int = 1,
+         max_batch_size: Optional[int] = None,
+         **kwargs_for_init,
+     ):
+         super().__init__(*args_for_init, **kwargs_for_init)
+
+         # heuristic to identify max batch size
+         assert max_batch_size or cuda_graph_batch_sizes, (
+             "At least one of max_batch_size or cuda_graph_batch_sizes must be provided."
+         )
+         self.max_batch_size = max_batch_size or max(cuda_graph_batch_sizes)
+
+         self.num_batched_inputs = num_batched_inputs
+         if not cuda_graph_batch_sizes:
            # Use heuristic which includes commonly-used sizes like 1 and max_bs
            self.cuda_graph_batch_sizes = self._get_graph_batch_sizes(self.max_batch_size)
            ad_logger.info(f"Using heuristic cuda_graph_batch_sizes: {self.cuda_graph_batch_sizes}")
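`_get_graph_batch_sizes` itself is not shown in this diff. One plausible reading of the heuristic the comment describes, padding the commonly-used sizes 1 and max_bs with powers of two, is sketched below; the actual implementation may differ:

@staticmethod
def _get_graph_batch_sizes(max_bs: int) -> List[int]:
    # Hypothetical heuristic: 1, powers of two below max_bs, plus max_bs itself.
    sizes = {1, max_bs}
    bs = 2
    while bs < max_bs:
        sizes.add(bs)
        bs *= 2
    # Descending order, matching how the class sorts its batch sizes elsewhere.
    return sorted(sizes, reverse=True)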
@@ -182,39 +209,34 @@ def __init__(self, *args, **kwargs):
            # No point capturing CUDA graphs for batch sizes larger than max_batch_size
            effective = {
                min(max(1, int(b)), int(self.max_batch_size))
-                 for b in requested
+                 for b in cuda_graph_batch_sizes
                if isinstance(b, (int, float)) and b > 0
            }
            self.cuda_graph_batch_sizes = sorted(effective, reverse=True)

            # Log if we clamped any values
-             original_values = [int(b) for b in requested if isinstance(b, (int, float)) and b > 0]
+             original_values = [
+                 int(b) for b in cuda_graph_batch_sizes if isinstance(b, (int, float)) and b > 0
+             ]
            clamped_values = [v for v in original_values if v > self.max_batch_size]
            if clamped_values:
                ad_logger.info(
                    f"Clamped CUDA graph batch sizes {clamped_values} to max_batch_size={self.max_batch_size}"
                )

            ad_logger.info(
-                 f"Using explicit cuda_graph_batch_sizes: requested={requested} "
+                 f"Using explicit cuda_graph_batch_sizes: requested={cuda_graph_batch_sizes} "
                f" -> effective={self.cuda_graph_batch_sizes} "
                f" (clamped to [1, {self.max_batch_size}])"
            )

-     def _init_captured_graph(
-         self, gm: nn.Module, in_spec: TreeSpec, out_spec: TreeSpec
-     ) -> CapturedGraph:
-         return CapturedGraph(
-             gm,
-             in_spec=in_spec,
-             out_spec=out_spec,
-             cuda_graph_batch_sizes=self.cuda_graph_batch_sizes,
-             num_batched_inputs=self.compiler_kwargs.get("num_batched_inputs"),
-         )
-
    @torch.inference_mode()
    def compile(self) -> CapturedGraph:
-         captured_model = self._init_captured_graph(self.gm, self.gm._in_spec, self.gm._out_spec)
+         captured_model = CapturedGraph(
+             self.model,
+             cuda_graph_batch_sizes=self.cuda_graph_batch_sizes,
+             num_batched_inputs=self.num_batched_inputs,
+         )

        # try capturing cudagraph
        if self.args is not None or self.kwargs is not None:
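Taken together, the new constructor and capture flow are self-contained enough for a standalone usage sketch; `CapturedGraph` and its arguments come from this diff, while `MyModel` and the tensor shapes are illustrative:

model = MyModel().cuda().eval()  # hypothetical nn.Module with one batched input
captured = CapturedGraph(model, cuda_graph_batch_sizes=[32, 8, 1], num_batched_inputs=1)

with torch.inference_mode():
    # One-time capture: records in_spec/out_spec and captures one graph per batch size.
    captured.capture_graph(torch.randn(32, 64, device="cuda"))
    # Replays the bs=8 graph; shapes without a captured graph fall back to eager forward.
    out = captured(torch.randn(8, 64, device="cuda"))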