34 changes: 33 additions & 1 deletion tensorrt_llm/_torch/compilation/backend.py
@@ -12,6 +12,7 @@
 import tensorrt_llm
 from tensorrt_llm import logger
 
+from .multi_stream.auto_multi_stream import multi_stream_schedule
 from .patterns.ar_residual_norm import register_ar_residual_norm
 from .patterns.residual_add_norm import register_add_norm
 from .patterns.ub_allreduce import register_ub_patterns
@@ -25,12 +26,20 @@ class Backend:
     _custom_pass_instances: List[PatternMatcherPass] = None
     _graph_pool_handle: tuple[int, int] = None
 
+    # These subclasses exist so that weakref can hold references to the stream and event list objects.
+    class Streams(list):
+        pass
+
+    class Events(list):
+        pass
+
     def __init__(
         self,
         enable_inductor=True,
         enable_userbuffers=False,
         enable_piecewise_cuda_graph: bool = False,
         cuda_graph_batch_sizes: Optional[List[int]] = None,
+        max_num_streams: int = 1,
     ) -> None:
         super().__init__()
         self.elapsed_time = 0
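
A note on the Streams/Events subclasses added above: CPython's built-in list has no __weakref__ slot, so a plain list cannot be the target of a weak reference, while a trivial subclass can. A minimal standalone sketch (not part of this PR):

    import weakref

    class Streams(list):
        pass

    streams = Streams()
    ref = weakref.ref(streams)  # works: subclass instances support weak references
    assert ref() is streams

    try:
        weakref.ref([])  # plain list: TypeError, no __weakref__ slot
    except TypeError as err:
        print(err)  # "cannot create weak reference to 'list' object"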
@@ -45,6 +54,10 @@ def __init__(
                                        else [])
         self.piecewise_cuda_graph = enable_piecewise_cuda_graph
         self.no_optimization = False
+        # Only the auxiliary streams need to be created here; the main stream already exists.
+        self.aux_streams = Backend.Streams(
+            [torch.cuda.Stream() for _ in range(max_num_streams - 1)])
+        self.events = Backend.Events()
         inductor_config.enable_auto_functionalized_v2 = False
 
         if Backend._graph_pool_handle is None:
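
Note that max_num_streams counts the total number of streams, including the main stream the model already runs on, so only max_num_streams - 1 auxiliary torch.cuda.Stream objects are created. A hedged illustration (example values, not the PR's code):

    import torch

    max_num_streams = 4  # assumed example value
    main = torch.cuda.current_stream()  # the pre-existing main stream
    aux_streams = [torch.cuda.Stream() for _ in range(max_num_streams - 1)]
    assert len(aux_streams) == 3  # main + 3 aux = 4 streams total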
@@ -77,6 +90,12 @@ def bypass_optimization(self):
     def enable_optimization(self):
         self.no_optimization = False
 
+    def generate_events(self, num_events: int):
+        if num_events > len(self.events):
+            self.events += [
+                torch.cuda.Event() for _ in range(num_events - len(self.events))
+            ]
+
     def optimize(
         self,
         gm: GraphModule,
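
generate_events grows the shared event pool monotonically: it allocates only the shortfall, so a smaller request is a no-op and existing torch.cuda.Event objects are reused across compiled graphs. Assuming a constructed backend instance on a CUDA machine, the behavior looks like:

    backend.generate_events(4)  # pool grows from 0 to 4 events
    backend.generate_events(2)  # no-op: the pool already holds 4 >= 2
    backend.generate_events(6)  # allocates only the 2 missing events
    assert len(backend.events) == 6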
@@ -90,17 +109,30 @@ def optimize(
         graph.eliminate_dead_code()
         # After this pass, dead-code elimination (DCE) must not be run again!
         remove_copy_for_mutates_args(graph)
+
+        # Do not apply multi-stream scheduling when piecewise CUDA graphs or Inductor are enabled:
+        # for piecewise CUDA graphs, the multi-stream optimization is applied in piecewise_optimizer;
+        # for Inductor, we do not control the passes that run inside it.
+        if len(
+                self.aux_streams
+        ) > 0 and not self.piecewise_cuda_graph and not self.enable_inductor:
+            num_events = multi_stream_schedule(gm, len(self.aux_streams) + 1)
+            self.generate_events(num_events)
+
         gm.recompile()
 
         if self.piecewise_cuda_graph:
-            return piecewise_optimizer(
+            gm, num_events = piecewise_optimizer(
                 gm,
                 example_inputs,
                 self.enable_inductor,
                 self.input_num_tokens,
                 self.cuda_graph_batch_sizes,
                 self._graph_pool_handle,
+                len(self.aux_streams) + 1,
             )
+            self.generate_events(num_events)
+            return gm
         elif self.enable_inductor:
             return compile_fx(gm, example_inputs)
         else:
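
The event pool exists because cross-stream scheduling needs explicit synchronization: work launched on an auxiliary stream must be fenced with a torch.cuda.Event before another stream consumes its outputs. A minimal sketch of that record/wait pattern (illustrative only, not the actual output of multi_stream_schedule):

    import torch

    main = torch.cuda.current_stream()
    aux = torch.cuda.Stream()
    done = torch.cuda.Event()

    x = torch.randn(1024, 1024, device="cuda")  # produced on the main stream

    aux.wait_stream(main)  # aux must observe x's producer before reading it
    with torch.cuda.stream(aux):
        y = x @ x  # this branch runs concurrently on the aux stream
        done.record()  # mark the branch's completion on aux

    main.wait_event(done)  # main stream blocks until the aux branch finishes
    z = y + 1  # now safe to consume y on the main stream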
Empty file.