diff --git a/tensorrt_llm/_torch/compilation/piecewise_optimizer.py b/tensorrt_llm/_torch/compilation/piecewise_optimizer.py
index 67d83c3c43f..8f79761be69 100644
--- a/tensorrt_llm/_torch/compilation/piecewise_optimizer.py
+++ b/tensorrt_llm/_torch/compilation/piecewise_optimizer.py
@@ -14,7 +14,7 @@
 from ..utils import (get_model_extra_attrs,
                      get_per_request_piecewise_cuda_graph_flag,
                      get_piecewise_cuda_graph_flag, make_weak_ref,
-                     set_piecewise_running)
+                     skip_maybe_compile)
 from .multi_stream.auto_multi_stream import multi_stream_schedule
 from .utils import get_capture_piecewise_cuda_graph_flag, is_call_function
 
@@ -171,68 +171,75 @@ def __call__(self, *args):
                 or not get_per_request_piecewise_cuda_graph_flag()):
             return self.default_callable(*args)
 
-        if self.is_first_runner or self.is_last_runner:
-            if self.is_first_runner == self.is_last_runner:
-                set_piecewise_running(False)
-            else:
-                set_piecewise_running(self.is_first_runner)
-
-        entry = self.entries[runtime_num_of_token]
-
-        if entry.enable_inductor and not entry.compiled:
-            entry.callable = compile_fx(entry.callable, args)
-            entry.compiled = True
-
-        if entry.cuda_graph is None:
-
-            if not get_capture_piecewise_cuda_graph_flag():
-                return entry.callable(*args)
-
-            if entry.warmup_count < 3:
-                entry.warmup_count += 1
-                return entry.callable(*args)
-
-            entry.input_addresses = [
-                i.data_ptr() for i in args if isinstance(i, torch.Tensor)
-            ]
-
-            graph = torch.cuda.CUDAGraph()
-
-            # Torch's cuda graph will call gc.collect() internally. This will slow down the performance.
-            # We patch it to do nothing.
-            with patch("gc.collect", lambda: None):
-                # TODO: consider to use `make_graphed_callables()` when
-                # it's ready rather than capture it ourselves
-                # Graph Capture would override the stream. We need to setup the stream correctly.
-                extra_attrs = get_model_extra_attrs()
-                with torch.cuda.graph(graph, pool=self.graph_pool_handle):
+        # Determine whether @maybe_compile-decorated functions should skip compilation:
+        # - First runner only: skip compilation (to avoid overhead)
+        # - Last runner only: skip compilation (to avoid overhead)
+        # - Both first and last (single runner): allow compilation (normal mode)
+        # - Middle runner: allow compilation (normal mode)
+        should_skip = (self.is_first_runner or self.is_last_runner) and \
+            not (self.is_first_runner and self.is_last_runner)
+
+        # Use the context manager to control @maybe_compile behavior directly.
+        # This makes the relationship explicit: PiecewiseRunner → skip_maybe_compile → @maybe_compile
+        with skip_maybe_compile(should_skip):
+            entry = self.entries[runtime_num_of_token]
+
+            if entry.enable_inductor and not entry.compiled:
+                entry.callable = compile_fx(entry.callable, args)
+                entry.compiled = True
+
+            if entry.cuda_graph is None:
+
+                if not get_capture_piecewise_cuda_graph_flag():
+                    return entry.callable(*args)
+
+                if entry.warmup_count < 3:
+                    entry.warmup_count += 1
+                    return entry.callable(*args)
+
+                entry.input_addresses = [
+                    i.data_ptr() for i in args if isinstance(i, torch.Tensor)
+                ]
+
+                graph = torch.cuda.CUDAGraph()
+
+                # Torch's cuda graph calls gc.collect() internally, which slows down performance.
+                # We patch it to do nothing.
+                with patch("gc.collect", lambda: None):
+                    # TODO: consider using `make_graphed_callables()` when
+                    # it's ready rather than capturing the graph ourselves.
+                    # Graph capture overrides the stream, so we need to set it up correctly.
+                    extra_attrs = get_model_extra_attrs()
+                    with torch.cuda.graph(graph, pool=self.graph_pool_handle):
+                        extra_attrs[
+                            "global_stream"] = torch.cuda.current_stream()
+                        output = entry.callable(*args)
                 extra_attrs["global_stream"] = torch.cuda.current_stream()
-                    output = entry.callable(*args)
-                extra_attrs["global_stream"] = torch.cuda.current_stream()
-            entry.cuda_graph = graph
-            # Mark weak ref here. The intermediate activation tensor should be freed properly.
-            # Here we don't use python native weakref since we still need the object to be alive when the graph is replayed.
-            entry.output = make_weak_ref(output)
-            entry.output_addresses = [
-                i.data_ptr() for i in output if isinstance(i, torch.Tensor)
-            ]
+                entry.cuda_graph = graph
+                # Mark weak ref here so the intermediate activation tensors can be freed properly.
+                # We don't use Python's native weakref since the object must stay alive when the graph is replayed.
+                entry.output = make_weak_ref(output)
+                entry.output_addresses = [
+                    i.data_ptr() for i in output if isinstance(i, torch.Tensor)
+                ]
 
-            entry.cuda_graph.replay()
+                entry.cuda_graph.replay()
 
-            return output
+                return output
 
-        if enable_llm_debug():
-            runtime_input_addresses = [
-                i.data_ptr() for i in args if isinstance(i, torch.Tensor)
-            ]
+            if enable_llm_debug():
+                runtime_input_addresses = [
+                    i.data_ptr() for i in args if isinstance(i, torch.Tensor)
+                ]
 
-            assert (entry.input_addresses == runtime_input_addresses
-                    ), f"{entry.input_addresses} vs\n {runtime_input_addresses}"
+                assert (
+                    entry.input_addresses == runtime_input_addresses
+                ), f"{entry.input_addresses} vs\n {runtime_input_addresses}"
 
-        entry.cuda_graph.replay()
+            entry.cuda_graph.replay()
 
-        return entry.output
+            return entry.output
 
 
 def piecewise_optimizer(
diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py
index 5beff19f710..242ebe0e2e8 100644
--- a/tensorrt_llm/_torch/utils.py
+++ b/tensorrt_llm/_torch/utils.py
@@ -11,8 +11,10 @@
 from tensorrt_llm.math_utils import ceil_div, pad_up
 from tensorrt_llm.quantization.utils import fp4_utils
 
-is_torch_compiling_flag = False
-is_piecewise_running_flag = False
+_torch_compiling = threading.local()
+# Controls whether the @maybe_compile decorator should skip compilation.
+# Set directly by PiecewiseRunner to avoid compilation overhead.
+_skip_maybe_compile = threading.local()
 
 aux_stream_name_list = [
     'Attention',
@@ -46,23 +48,42 @@ class ActivationType(IntEnum):
 
 def set_torch_compiling(enable: bool):
-    global is_torch_compiling_flag
-    is_torch_compiling_flag = enable
+    _torch_compiling.flag = enable
 
 
 def is_torch_compiling() -> bool:
-    global is_torch_compiling_flag
-    return is_torch_compiling_flag
+    return getattr(_torch_compiling, 'flag', False)
 
 
-def set_piecewise_running(enable: bool):
-    global is_piecewise_running_flag
-    is_piecewise_running_flag = enable
+@contextlib.contextmanager
+def skip_maybe_compile(skip: bool = True):
+    """
+    Context manager that directly controls the behavior of the @maybe_compile decorator.
+
+    When skip=True, functions decorated with @maybe_compile will skip torch.compile
+    to avoid compilation overhead. Used by PiecewiseRunner to control compilation.
+
+    This makes the relationship between PiecewiseRunner and @maybe_compile explicit.
+
+    Args:
+        skip: Whether to skip compilation in @maybe_compile-decorated functions.
+
+    Example:
+        with skip_maybe_compile(True):
+            # Functions decorated with @maybe_compile will NOT be compiled.
+            result = some_function()
+    """
+    old_state = getattr(_skip_maybe_compile, 'flag', False)
+    _skip_maybe_compile.flag = skip
+    try:
+        yield
+    finally:
+        _skip_maybe_compile.flag = old_state
 
 
-def is_piecewise_running() -> bool:
-    global is_piecewise_running_flag
-    return is_piecewise_running_flag
+def _should_skip_maybe_compile() -> bool:
+    """Check whether @maybe_compile should skip compilation."""
+    return getattr(_skip_maybe_compile, 'flag', False)
 
 
 _global_attrs = threading.local()
@@ -344,19 +365,34 @@ def get_device_uuid(device_idx: int) -> str:
 def maybe_compile(func=None, **compile_kwargs):
     """
     Conditionally compile a function with torch.compile.
-    If is_piecewise_running() is True, the function will not be compiled to avoid host overhead in attention op.
+
+    Compilation is skipped when running within a skip_maybe_compile(True) context,
+    which is used by PiecewiseRunner to avoid compilation overhead.
+
     Args:
         func: The function to decorate (optional, for direct decoration).
         **compile_kwargs: Keyword arguments for torch.compile.
     Returns:
-        The conditionally compiled function..
+        The conditionally compiled function.
+
+    Example:
+        @maybe_compile
+        def my_function(x):
+            return x * 2
+
+        # Normal usage: the function is compiled
+        result = my_function(tensor)
+
+        # Within skip_maybe_compile: the function runs uncompiled
+        with skip_maybe_compile(True):
+            result = my_function(tensor)  # Not compiled
     """
 
     def decorator(f):
         compiled_func = torch.compile(f, **compile_kwargs)
 
         def wrapper(*args, **kwargs):
-            if is_piecewise_running():
+            if _should_skip_maybe_compile():
                 return f(*args, **kwargs)
             return compiled_func(*args, **kwargs)
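A note on the `should_skip` expression in the piecewise_optimizer.py hunk: it is an exclusive-or over the two runner flags. A standalone check of the four cases listed in the comment (plain Python, no TensorRT-LLM imports; `first` and `last` are stand-ins for `is_first_runner` and `is_last_runner`):

```python
# Truth table for should_skip; first/last mirror is_first_runner/is_last_runner.
for first, last in [(True, False), (False, True), (True, True), (False, False)]:
    should_skip = (first or last) and not (first and last)
    assert should_skip == (first != last)  # the expression reduces to XOR
    print(f"first={first}, last={last} -> {'skip' if should_skip else 'compile'}")
```

Only the first-only and last-only runners skip compilation; a single runner (both flags set) and middle runners (neither set) compile as usual, matching the four cases in the comment.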
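Taken together, the utils.py changes implement one pattern: a thread-local flag, a context manager that sets and restores it, and a decorator that consults it. Below is a minimal self-contained sketch of that pattern, with `torch.compile` replaced by a tagging stub so it runs without torch installed; the stub and the `double` function are illustrative only. It also assumes `contextlib` is already imported at the top of utils.py, which the hunk context does not show.

```python
import contextlib
import threading

_skip_maybe_compile = threading.local()


@contextlib.contextmanager
def skip_maybe_compile(skip: bool = True):
    # Save and restore the previous value so nested contexts compose.
    old_state = getattr(_skip_maybe_compile, 'flag', False)
    _skip_maybe_compile.flag = skip
    try:
        yield
    finally:
        _skip_maybe_compile.flag = old_state


def maybe_compile(func):
    # Stand-in for torch.compile: tags the result so the example runs
    # without torch. The control flow mirrors the wrapper in the diff.
    def compiled(*args, **kwargs):
        return ('compiled', func(*args, **kwargs))

    def wrapper(*args, **kwargs):
        if getattr(_skip_maybe_compile, 'flag', False):
            return ('eager', func(*args, **kwargs))
        return compiled(*args, **kwargs)

    return wrapper


@maybe_compile
def double(x):
    return x * 2


assert double(3) == ('compiled', 6)          # normal path: compiled
with skip_maybe_compile(True):
    assert double(3) == ('eager', 6)         # skipped inside the context
    with skip_maybe_compile(False):
        assert double(3) == ('compiled', 6)  # nested contexts compose
assert double(3) == ('compiled', 6)          # state restored on exit
```

The try/finally save-and-restore is what makes nesting safe: an inner `skip_maybe_compile(False)` can re-enable compilation, and the outer state comes back on exit.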
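One design note on replacing the module-level flags with `threading.local()`: each thread now sees its own flag, so a skip state set by one runner cannot leak into other threads. The `getattr(..., 'flag', False)` default matters because a fresh thread starts with an empty namespace. A small standard-library illustration (the `flag` and `worker` names are hypothetical):

```python
import threading

flag = threading.local()
flag.value = True  # set in the main thread only

def worker(out):
    # A new thread gets an empty namespace, so the default kicks in.
    out.append(getattr(flag, 'value', False))

results = []
t = threading.Thread(target=worker, args=(results,))
t.start()
t.join()
print(getattr(flag, 'value', False))  # True  (main thread)
print(results[0])                     # False (worker thread)
```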