
Commit ec5e2c2

lucaslie authored and dominicshanshan committed
[None][feat] AutoDeploy: compiler backends based on nn.Module (NVIDIA#8126)
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent 7b2f0e6 commit ec5e2c2

File tree

16 files changed: +274 -194 lines changed


examples/auto_deploy/build_and_run_flux.py

Lines changed: 3 additions & 2 deletions
@@ -5,7 +5,7 @@
 import torch
 from diffusers import DiffusionPipeline

-from tensorrt_llm._torch.auto_deploy.compile import compile_and_capture
+from tensorrt_llm._torch.auto_deploy.compile import CompileBackendRegistry
 from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
 from tensorrt_llm._torch.auto_deploy.transformations.library.fusion import fuse_gemms
 from tensorrt_llm._torch.auto_deploy.transformations.library.quantization import quantize
@@ -143,7 +143,8 @@ def main():

     fuse_gemms(gm)

-    gm = compile_and_capture(gm, backend="torch-opt", args=(), kwargs=flux_kwargs)
+    compiler_cls = CompileBackendRegistry.get("torch-opt")
+    gm = compiler_cls(gm, args=(), kwargs=flux_kwargs).compile()

     del model
     fx_model = gm
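
For orientation, the change above swaps the one-shot compile_and_capture helper for an explicit registry lookup followed by a compile() call. A minimal sketch of the new call pattern, assuming gm is the exported GraphModule and flux_kwargs the example inputs prepared earlier in the script:

    from tensorrt_llm._torch.auto_deploy.compile import CompileBackendRegistry

    # look up the backend class by its registered name
    # (e.g. "torch-simple", "torch-compile", "torch-cudagraph", "torch-opt")
    compiler_cls = CompileBackendRegistry.get("torch-opt")

    # construct the backend with the model and example inputs, then compile
    gm = compiler_cls(gm, args=(), kwargs=flux_kwargs).compile()
    out = gm(**flux_kwargs)  # the result is used like any other nn.Module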

tensorrt_llm/_torch/auto_deploy/compile/backends/torch_compile.py

Lines changed: 6 additions & 6 deletions
@@ -5,15 +5,15 @@

 from tensorrt_llm._torch.auto_deploy.utils.logger import ad_logger

-from ..compiler import BackendCompiler, BackendRegistry
+from ..compiler import CompileBackendRegistry, CompilerBackend


-@BackendRegistry.register("torch-compile")
-class TorchCompileCompiler(BackendCompiler):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+@CompileBackendRegistry.register("torch-compile")
+class TorchCompileCompiler(CompilerBackend):
+    def __init__(self, *args_for_init, **kwargs_for_init):
+        super().__init__(*args_for_init, **kwargs_for_init)
         ad_logger.info(f"Torch Dynamo cache size limit {torch._dynamo.config.cache_size_limit=}")

     def compile(self) -> nn.Module:
         """Compile the model using torch.compile."""
-        return torch.compile(self.gm, dynamic=True)
+        return torch.compile(self.model, dynamic=True)
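
The backends import CompileBackendRegistry and CompilerBackend from ..compiler, whose new definitions are not part of this excerpt. Purely for orientation, a hypothetical sketch of what such a registry/base-class pair could look like; the names mirror the usage in the diffs, but every signature here is an assumption, not the actual compiler.py:

    from typing import Callable, Dict, Optional, Type

    import torch.nn as nn


    class CompilerBackend:
        """Hypothetical base class: backends hold a plain nn.Module plus example inputs."""

        def __init__(self, model: nn.Module, args: Optional[tuple] = None,
                     kwargs: Optional[dict] = None):
            self.model = model
            self.args = args
            self.kwargs = kwargs

        def compile(self) -> nn.Module:
            raise NotImplementedError


    class CompileBackendRegistry:
        """Hypothetical registry mapping backend names to backend classes."""

        _backends: Dict[str, Type[CompilerBackend]] = {}

        @classmethod
        def register(cls, name: str) -> Callable[[Type[CompilerBackend]], Type[CompilerBackend]]:
            def _decorator(backend_cls: Type[CompilerBackend]) -> Type[CompilerBackend]:
                cls._backends[name] = backend_cls
                return backend_cls

            return _decorator

        @classmethod
        def get(cls, name: str) -> Type[CompilerBackend]:
            return cls._backends[name]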

tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py

Lines changed: 62 additions & 40 deletions
@@ -5,32 +5,41 @@
 import torch
 import torch.nn as nn
 from torch.cuda import CUDAGraph
-from torch.utils._pytree import TreeSpec, tree_flatten
+from torch.fx._pytree import tree_flatten_spec
+from torch.utils._pytree import PyTree, TreeSpec, tree_flatten

 from tensorrt_llm._torch.autotuner import autotune

 from ...utils.cuda_graph import CudaGraphWarmUpPhase
 from ...utils.logger import ad_logger
-from ..compiler import BackendCompiler, BackendRegistry, _flatten_args
+from ..compiler import CompileBackendRegistry, CompilerBackend
+
+
+def _args_kwargs_flatten_spec(in_spec: TreeSpec, *args, **kwargs) -> List[Any]:
+    """Flatten inputs according to provided in_spec."""
+    all_args: PyTree = (args, kwargs)
+    return tree_flatten_spec(all_args, in_spec)
+
+
+def _args_kwargs_flatten(*args, **kwargs) -> Tuple[List[Any], TreeSpec]:
+    """Flatten inputs and return flattened inputs together with the TreeSpec."""
+    all_args: PyTree = (args, kwargs)
+    return tree_flatten(all_args)


 class CapturedGraph(nn.Module):
     def __init__(
         self,
         model: nn.Module,
-        in_spec: TreeSpec,
-        out_spec: TreeSpec,
         cuda_graph_batch_sizes: List[int],
-        num_batched_inputs: Optional[int] = 1,  # number of batched, dynamic inputs...
+        num_batched_inputs: int,  # number of batched, dynamic inputs...
     ):
         super().__init__()
-        self._in_spec = in_spec
-        self._out_spec = out_spec
         self.model = model
         self.cuda_graph_max_batch_size = max(cuda_graph_batch_sizes)
         ad_logger.info(f"Setting {self.cuda_graph_max_batch_size=}")
         self.num_batched_inputs = num_batched_inputs if num_batched_inputs is not None else 1
-        self.graphs: Dict[Tuple[int, ...], CUDAGraph] = {}
+        self.cudagraphs: Dict[Tuple[int, ...], CUDAGraph] = {}
         self._input_buffers: List[torch.Tensor] = [
             torch.empty(0, 1) for _ in range(self.num_batched_inputs)
         ]
@@ -39,6 +48,10 @@ def __init__(
         self.cuda_graph_batch_sizes = sorted(cuda_graph_batch_sizes, reverse=True)
         self._cuda_graph_mem_pool = None

+        # store the in_spec and out_spec during graph capture
+        self._in_spec = None
+        self._out_spec = None
+
     def _get_hash(self, flat_args: List[Any]) -> Tuple[int, ...]:
         return tuple(hash(a) for a in flat_args)

@@ -67,8 +80,7 @@ def _capture_one_graph(self, *args, **kwargs) -> torch.cuda.CUDAGraph:
         # compute output
         out = self.model(*args, **kwargs)
         # write out into output buffer up to out batch size
-        out_flat, out_spec = tree_flatten(out)
-        assert out_spec == self._out_spec, "Output spec mismatch."
+        out_flat = tree_flatten_spec(out, self._out_spec)
         for o_buffer, o in zip(self._out_buffer_flat, out_flat):
             o_buffer[: o.shape[0]] = o
         torch.cuda.synchronize()
@@ -77,8 +89,11 @@ def _capture_one_graph(self, *args, **kwargs) -> torch.cuda.CUDAGraph:

     def capture_graph(self, *args, **kwargs):
         """Capture and pre-fetch the graph for variable batch size."""
-        # flatten args, kwargs
-        all_args_flat = _flatten_args(self._in_spec, *args, **kwargs)
+        # check this is the first time we capture the graph
+        assert not self.cudagraphs, "Graphs already captured."
+
+        # flatten args, kwargs for the first time and record in_spec
+        all_args_flat, self._in_spec = _args_kwargs_flatten(*args, **kwargs)

         # extract the batched input tensors
         args_batched = all_args_flat[: self.num_batched_inputs]
@@ -96,10 +111,8 @@ def capture_graph(self, *args, **kwargs):
             f"than the max_batch_size? It will fall back to non-CUDA graph forward pass for "
             f"batch sizes exceeding the max_batch_size."
         )
-        msg_ndim = "Expecting at least a 2D for batched input tensors."
         if any(self.cuda_graph_max_batch_size < input.shape[0] for input in args_batched):
             ad_logger.info(msg_bs)
-        assert all(input.ndim > 1 for input in args_batched), msg_ndim

         # repeat the batched input tensors to the cuda_graph_max_batch_size
         self._input_buffers = [
@@ -111,11 +124,11 @@ def capture_graph(self, *args, **kwargs):
         args, kwargs = self._in_spec.unflatten(self._input_buffers + args_static)

         # capture output once with cuda_graph_max_batch_size to capture output buffers
+        # store the out_spec at this point
         with CudaGraphWarmUpPhase():
             ad_logger.info(f"Warm up with {self.cuda_graph_max_batch_size=} before graph capture")
             out = self.model(*args, **kwargs)
-        self._out_buffer_flat, out_spec = tree_flatten(out)
-        assert out_spec == self._out_spec, "Output spec mismatch."
+        self._out_buffer_flat, self._out_spec = tree_flatten(out)

         # capture graph now for a range of batch sizes
         for bs in self.cuda_graph_batch_sizes:
@@ -132,7 +145,7 @@ def capture_graph(self, *args, **kwargs):
     def forward(self, *args, **kwargs) -> Any:
         """Run the compiled graph."""
         # flatten args, kwargs
-        all_args_flat = _flatten_args(self._in_spec, *args, **kwargs)
+        all_args_flat = _args_kwargs_flatten_spec(self._in_spec, *args, **kwargs)

         # extract the batched input tensors
         args_batched = all_args_flat[: self.num_batched_inputs]
@@ -150,30 +163,44 @@ def forward(self, *args, **kwargs) -> Any:
         combined_shape = sum(rounded_shapes, start=())

         # regular forward for non-matching shapes
-        if combined_shape not in self.graphs:
+        if combined_shape not in self.cudagraphs:
             return self.model(*args, **kwargs)

         # copy inputs to input buffers
         for i, input_tensor in enumerate(args_batched):
             self._input_buffers[i][: input_tensor.shape[0]].copy_(input_tensor, non_blocking=True)

         # run forward pass via graph
-        self.graphs[combined_shape].replay()
+        self.cudagraphs[combined_shape].replay()

         # retrieve output from buffer, cut to batch size, and unflatten
         bs = args_batched[0].shape[0]
         out_flat = [o_b[:bs].detach().clone() for o_b in self._out_buffer_flat]
         return self._out_spec.unflatten(out_flat)


-@BackendRegistry.register("torch-cudagraph")
-class TorchCudagraphCompiler(BackendCompiler):
+@CompileBackendRegistry.register("torch-cudagraph")
+class TorchCudagraphCompiler(CompilerBackend):
     """Compiler that uses only CUDA graphs."""

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        requested = self.compiler_kwargs.get("cuda_graph_batch_sizes")
-        if not requested:
+    def __init__(
+        self,
+        *args_for_init,
+        cuda_graph_batch_sizes: Optional[List[int]] = None,
+        num_batched_inputs: int = 1,
+        max_batch_size: Optional[int] = None,
+        **kwargs_for_init,
+    ):
+        super().__init__(*args_for_init, **kwargs_for_init)
+
+        # heuristic to identify max batch size
+        assert max_batch_size or cuda_graph_batch_sizes, (
+            "At least one of max_batch_size or cuda_graph_batch_sizes must be provided."
+        )
+        self.max_batch_size = max_batch_size or max(cuda_graph_batch_sizes)
+
+        self.num_batched_inputs = num_batched_inputs
+        if not cuda_graph_batch_sizes:
             # Use heuristic which includes commonly-used sizes like 1 and max_bs
             self.cuda_graph_batch_sizes = self._get_graph_batch_sizes(self.max_batch_size)
             ad_logger.info(f"Using heuristic cuda_graph_batch_sizes: {self.cuda_graph_batch_sizes}")
@@ -182,39 +209,34 @@ def __init__(self, *args, **kwargs):
             # No point capturing CUDA graphs for batch sizes larger than max_batch_size
             effective = {
                 min(max(1, int(b)), int(self.max_batch_size))
-                for b in requested
+                for b in cuda_graph_batch_sizes
                 if isinstance(b, (int, float)) and b > 0
             }
             self.cuda_graph_batch_sizes = sorted(effective, reverse=True)

             # Log if we clamped any values
-            original_values = [int(b) for b in requested if isinstance(b, (int, float)) and b > 0]
+            original_values = [
+                int(b) for b in cuda_graph_batch_sizes if isinstance(b, (int, float)) and b > 0
+            ]
             clamped_values = [v for v in original_values if v > self.max_batch_size]
             if clamped_values:
                 ad_logger.info(
                     f"Clamped CUDA graph batch sizes {clamped_values} to max_batch_size={self.max_batch_size}"
                 )

             ad_logger.info(
-                f"Using explicit cuda_graph_batch_sizes: requested={requested}"
+                f"Using explicit cuda_graph_batch_sizes: requested={cuda_graph_batch_sizes}"
                 f" -> effective={self.cuda_graph_batch_sizes}"
                 f" (clamped to [1, {self.max_batch_size}])"
             )

-    def _init_captured_graph(
-        self, gm: nn.Module, in_spec: TreeSpec, out_spec: TreeSpec
-    ) -> CapturedGraph:
-        return CapturedGraph(
-            gm,
-            in_spec=in_spec,
-            out_spec=out_spec,
-            cuda_graph_batch_sizes=self.cuda_graph_batch_sizes,
-            num_batched_inputs=self.compiler_kwargs.get("num_batched_inputs"),
-        )
-
     @torch.inference_mode()
     def compile(self) -> CapturedGraph:
-        captured_model = self._init_captured_graph(self.gm, self.gm._in_spec, self.gm._out_spec)
+        captured_model = CapturedGraph(
+            self.model,
+            cuda_graph_batch_sizes=self.cuda_graph_batch_sizes,
+            num_batched_inputs=self.num_batched_inputs,
+        )

         # try capturing cudagraph
         if self.args is not None or self.kwargs is not None:
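
The two new module-level helpers mirror how CapturedGraph now discovers its specs lazily: the first capture_graph call flattens (args, kwargs) with tree_flatten and stores the resulting TreeSpec, and every later forward call flattens its inputs against that stored spec so the leaves line up with the captured input buffers. A small standalone sketch of that round trip (the tensor shapes and the "mask" key are illustrative only):

    import torch
    from torch.fx._pytree import tree_flatten_spec
    from torch.utils._pytree import tree_flatten

    # first call: flatten (args, kwargs) and remember the TreeSpec
    args, kwargs = (torch.randn(4, 8),), {"mask": torch.ones(4, 8)}
    flat, in_spec = tree_flatten((args, kwargs))

    # later calls: flatten new inputs against the stored spec so the
    # leaf order matches what was recorded during capture
    new_args, new_kwargs = (torch.randn(2, 8),), {"mask": torch.ones(2, 8)}
    flat_again = tree_flatten_spec((new_args, new_kwargs), in_spec)

    assert len(flat) == len(flat_again)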
Lines changed: 9 additions & 8 deletions
@@ -1,19 +1,20 @@
 """Mixed backend with torch.compile + Cudagraph."""

 import torch
+import torch.nn as nn

 from tensorrt_llm._torch.auto_deploy.utils.logger import ad_logger

-from ..compiler import BackendRegistry
-from .torch_cudagraph import CapturedGraph, TorchCudagraphCompiler
+from ..compiler import CompileBackendRegistry
+from .torch_cudagraph import TorchCudagraphCompiler


-@BackendRegistry.register("torch-opt")
+@CompileBackendRegistry.register("torch-opt")
 class TorchOptCompiler(TorchCudagraphCompiler):
     """Compiler that uses both torch.compile and CUDA graphs."""

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, *args_for_init, **kwargs_for_init):
+        super().__init__(*args_for_init, **kwargs_for_init)
         torch._dynamo.config.recompile_limit = max(
             len(self.cuda_graph_batch_sizes), torch._dynamo.config.recompile_limit
         )
@@ -22,6 +23,6 @@ def __init__(self, *args, **kwargs):
             f"{torch._dynamo.config.cache_size_limit=}"
         )

-    def _init_captured_graph(self, gm, in_spec, out_spec) -> CapturedGraph:
-        gm = torch.compile(gm, dynamic=True)
-        return super()._init_captured_graph(gm, in_spec, out_spec)
+    def compile(self) -> nn.Module:
+        self.model = torch.compile(self.model, dynamic=True)
+        return super().compile()
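
With the _init_captured_graph hook removed, the mixed backend composes the two stages directly in compile(): it first wraps self.model with torch.compile, then falls through to TorchCudagraphCompiler.compile() for CUDA graph capture. A usage sketch under the assumption that the CompilerBackend constructor takes the model plus example args/kwargs as in the Flux example above; model and example_kwargs stand in for a real module and its inputs:

    compiler = TorchOptCompiler(
        model,                              # any nn.Module; no longer tied to a GraphModule
        args=(),
        kwargs=example_kwargs,
        cuda_graph_batch_sizes=[1, 8, 32],  # optional: explicit sizes, clamped to max_batch_size
        num_batched_inputs=1,               # number of leading inputs carrying the batch dimension
        max_batch_size=32,
    )
    compiled = compiler.compile()           # torch.compile first, then CUDA graph capture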

tensorrt_llm/_torch/auto_deploy/compile/backends/torch_simple.py

Lines changed: 4 additions & 4 deletions
@@ -2,10 +2,10 @@

 import torch.nn as nn

-from ..compiler import BackendCompiler, BackendRegistry
+from ..compiler import CompileBackendRegistry, CompilerBackend


-@BackendRegistry.register("torch-simple")
-class TorchCompiler(BackendCompiler):
+@CompileBackendRegistry.register("torch-simple")
+class TorchCompiler(CompilerBackend):
     def compile(self) -> nn.Module:
-        return self.gm
+        return self.model
