
Commit 3af01fa

[None][chore] Make torch compiling and piecewise running flags thread-safe

Control skip_maybe_compile behavior directly.

Signed-off-by: Jonas Li <[email protected]>
1 parent 96cfdd8 commit 3af01fa
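
The core of the change is replacing module-level boolean globals with threading.local() storage, so concurrent threads no longer race on shared flags. A minimal sketch of the pattern, distilled from the diff below (not the full patch):

    import threading

    # Before: a single module-level flag shared by every thread (racy).
    is_torch_compiling_flag = False

    # After: each thread gets its own view of the flag.
    _torch_compiling = threading.local()

    def set_torch_compiling(enable: bool):
        _torch_compiling.flag = enable

    def is_torch_compiling() -> bool:
        # getattr with a default covers threads that never set the flag.
        return getattr(_torch_compiling, 'flag', False)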

File tree

tensorrt_llm/_torch/compilation/piecewise_optimizer.py
tensorrt_llm/_torch/utils.py

2 files changed (+113, -70 lines)

tensorrt_llm/_torch/compilation/piecewise_optimizer.py

Lines changed: 62 additions & 55 deletions
@@ -14,7 +14,7 @@
 from ..utils import (get_model_extra_attrs,
                      get_per_request_piecewise_cuda_graph_flag,
                      get_piecewise_cuda_graph_flag, make_weak_ref,
-                     set_piecewise_running)
+                     skip_maybe_compile)
 from .multi_stream.auto_multi_stream import multi_stream_schedule
 from .utils import get_capture_piecewise_cuda_graph_flag, is_call_function

@@ -171,68 +171,75 @@ def __call__(self, *args):
                 or not get_per_request_piecewise_cuda_graph_flag()):
             return self.default_callable(*args)

-        if self.is_first_runner or self.is_last_runner:
-            if self.is_first_runner == self.is_last_runner:
-                set_piecewise_running(False)
-            else:
-                set_piecewise_running(self.is_first_runner)
-
-        entry = self.entries[runtime_num_of_token]
-
-        if entry.enable_inductor and not entry.compiled:
-            entry.callable = compile_fx(entry.callable, args)
-            entry.compiled = True
-
-        if entry.cuda_graph is None:
-
-            if not get_capture_piecewise_cuda_graph_flag():
-                return entry.callable(*args)
-
-            if entry.warmup_count < 3:
-                entry.warmup_count += 1
-                return entry.callable(*args)
-
-            entry.input_addresses = [
-                i.data_ptr() for i in args if isinstance(i, torch.Tensor)
-            ]
-
-            graph = torch.cuda.CUDAGraph()
-
-            # Torch's cuda graph will call gc.collect() internally. This will slow down the performance.
-            # We patch it to do nothing.
-            with patch("gc.collect", lambda: None):
-                # TODO: consider to use `make_graphed_callables()` when
-                # it's ready rather than capture it ourselves
-                # Graph Capture would override the stream. We need to setup the stream correctly.
-                extra_attrs = get_model_extra_attrs()
-                with torch.cuda.graph(graph, pool=self.graph_pool_handle):
+        # Determine if we should skip compilation in @maybe_compile decorated functions:
+        # - First runner only: skip compilation (to avoid overhead)
+        # - Last runner only: skip compilation (to avoid overhead)
+        # - Both first and last (single runner): allow compilation (normal mode)
+        # - Middle runner: allow compilation (normal mode)
+        should_skip = (self.is_first_runner or self.is_last_runner) and \
+            not (self.is_first_runner and self.is_last_runner)
+
+        # Use context manager to directly control @maybe_compile behavior
+        # This makes the relationship explicit: PiecewiseRunner → skip_maybe_compile → @maybe_compile
+        with skip_maybe_compile(should_skip):
+            entry = self.entries[runtime_num_of_token]
+
+            if entry.enable_inductor and not entry.compiled:
+                entry.callable = compile_fx(entry.callable, args)
+                entry.compiled = True
+
+            if entry.cuda_graph is None:
+
+                if not get_capture_piecewise_cuda_graph_flag():
+                    return entry.callable(*args)
+
+                if entry.warmup_count < 3:
+                    entry.warmup_count += 1
+                    return entry.callable(*args)
+
+                entry.input_addresses = [
+                    i.data_ptr() for i in args if isinstance(i, torch.Tensor)
+                ]
+
+                graph = torch.cuda.CUDAGraph()
+
+                # Torch's cuda graph will call gc.collect() internally. This will slow down the performance.
+                # We patch it to do nothing.
+                with patch("gc.collect", lambda: None):
+                    # TODO: consider to use `make_graphed_callables()` when
+                    # it's ready rather than capture it ourselves
+                    # Graph Capture would override the stream. We need to setup the stream correctly.
+                    extra_attrs = get_model_extra_attrs()
+                    with torch.cuda.graph(graph, pool=self.graph_pool_handle):
+                        extra_attrs[
+                            "global_stream"] = torch.cuda.current_stream()
+                        output = entry.callable(*args)
                     extra_attrs["global_stream"] = torch.cuda.current_stream()
-                    output = entry.callable(*args)
-                extra_attrs["global_stream"] = torch.cuda.current_stream()

-            entry.cuda_graph = graph
-            # Mark weak ref here. The intermediate activation tensor should be freed properly.
-            # Here we don't use python native weakref since we still need the object to be alive when the graph is replayed.
-            entry.output = make_weak_ref(output)
-            entry.output_addresses = [
-                i.data_ptr() for i in output if isinstance(i, torch.Tensor)
-            ]
+                entry.cuda_graph = graph
+                # Mark weak ref here. The intermediate activation tensor should be freed properly.
+                # Here we don't use python native weakref since we still need the object to be alive when the graph is replayed.
+                entry.output = make_weak_ref(output)
+                entry.output_addresses = [
+                    i.data_ptr() for i in output if isinstance(i, torch.Tensor)
+                ]

-            entry.cuda_graph.replay()
+                entry.cuda_graph.replay()

-            return output
+                return output

-        if enable_llm_debug():
-            runtime_input_addresses = [
-                i.data_ptr() for i in args if isinstance(i, torch.Tensor)
-            ]
+            if enable_llm_debug():
+                runtime_input_addresses = [
+                    i.data_ptr() for i in args if isinstance(i, torch.Tensor)
+                ]

-            assert (entry.input_addresses == runtime_input_addresses
-                    ), f"{entry.input_addresses} vs\n {runtime_input_addresses}"
+                assert (
+                    entry.input_addresses == runtime_input_addresses
+                ), f"{entry.input_addresses} vs\n {runtime_input_addresses}"

-        entry.cuda_graph.replay()
+            entry.cuda_graph.replay()

-        return entry.output
+            return entry.output


 def piecewise_optimizer(
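
The new should_skip expression is simply XOR on the runner's position, matching the comment block above. A standalone check of the four cases (illustrative values only):

    # should_skip = (first or last) and not (first and last), i.e. first XOR last
    for first, last in [(True, False), (False, True), (True, True), (False, False)]:
        should_skip = (first or last) and not (first and last)
        print(f"first={first!s:5} last={last!s:5} -> skip={should_skip}")

    # first=True  last=False -> skip=True   (first runner only)
    # first=False last=True  -> skip=True   (last runner only)
    # first=True  last=True  -> skip=False  (single runner: normal mode)
    # first=False last=False -> skip=False  (middle runner: normal mode)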

tensorrt_llm/_torch/utils.py

Lines changed: 51 additions & 15 deletions
@@ -11,8 +11,10 @@
 from tensorrt_llm.math_utils import ceil_div, pad_up
 from tensorrt_llm.quantization.utils import fp4_utils

-is_torch_compiling_flag = False
-is_piecewise_running_flag = False
+_torch_compiling = threading.local()
+# Controls whether @maybe_compile decorator should skip compilation
+# Set directly by PiecewiseRunner to avoid compilation overhead
+_skip_maybe_compile = threading.local()

 aux_stream_name_list = [
     'Attention',
@@ -46,23 +48,42 @@ class ActivationType(IntEnum):


 def set_torch_compiling(enable: bool):
-    global is_torch_compiling_flag
-    is_torch_compiling_flag = enable
+    _torch_compiling.flag = enable


 def is_torch_compiling() -> bool:
-    global is_torch_compiling_flag
-    return is_torch_compiling_flag
+    return getattr(_torch_compiling, 'flag', False)


-def set_piecewise_running(enable: bool):
-    global is_piecewise_running_flag
-    is_piecewise_running_flag = enable
+@contextlib.contextmanager
+def skip_maybe_compile(skip: bool = True):
+    """
+    Context manager to directly control @maybe_compile decorator behavior.
+
+    When skip=True, functions decorated with @maybe_compile will skip torch.compile
+    to avoid compilation overhead. Used by PiecewiseRunner to control compilation.
+
+    This makes the relationship between PiecewiseRunner and @maybe_compile explicit.
+
+    Args:
+        skip: Whether to skip compilation in @maybe_compile decorated functions
+
+    Example:
+        with skip_maybe_compile(True):
+            # Functions with @maybe_compile will NOT be compiled
+            result = some_function()
+    """
+    old_state = getattr(_skip_maybe_compile, 'flag', False)
+    _skip_maybe_compile.flag = skip
+    try:
+        yield
+    finally:
+        _skip_maybe_compile.flag = old_state


-def is_piecewise_running() -> bool:
-    global is_piecewise_running_flag
-    return is_piecewise_running_flag
+def _should_skip_maybe_compile() -> bool:
+    """Check if @maybe_compile should skip compilation."""
+    return getattr(_skip_maybe_compile, 'flag', False)


 _global_attrs = threading.local()
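
Because skip_maybe_compile saves the previous flag value and restores it in a finally block, nested contexts compose and an exception cannot leak the skip state. A quick illustration, assuming the helpers are importable from tensorrt_llm._torch.utils as added above:

    from tensorrt_llm._torch.utils import (_should_skip_maybe_compile,
                                           skip_maybe_compile)

    assert _should_skip_maybe_compile() is False      # default: flag never set
    with skip_maybe_compile(True):
        assert _should_skip_maybe_compile() is True
        with skip_maybe_compile(False):               # nested override
            assert _should_skip_maybe_compile() is False
        assert _should_skip_maybe_compile() is True   # outer value restored
    assert _should_skip_maybe_compile() is False      # original state restored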
@@ -344,19 +365,34 @@ def get_device_uuid(device_idx: int) -> str:
 def maybe_compile(func=None, **compile_kwargs):
     """
     Conditionally compile a function with torch.compile.
-    If is_piecewise_running() is True, the function will not be compiled to avoid host overhead in attention op.
+
+    Compilation is skipped when running within a skip_maybe_compile(True) context,
+    which is used by PiecewiseRunner to avoid compilation overhead.
+
     Args:
         func: The function to decorate (optional, for direct decoration).
         **compile_kwargs: Keyword arguments for torch.compile.
     Returns:
-        The conditionally compiled function..
+        The conditionally compiled function.
+
+    Example:
+        @maybe_compile
+        def my_function(x):
+            return x * 2
+
+        # Normal usage: function is compiled
+        result = my_function(tensor)
+
+        # With skip_maybe_compile: function runs uncompiled
+        with skip_maybe_compile(True):
+            result = my_function(tensor)  # Not compiled
     """

     def decorator(f):
         compiled_func = torch.compile(f, **compile_kwargs)

         def wrapper(*args, **kwargs):
-            if is_piecewise_running():
+            if _should_skip_maybe_compile():
                 return f(*args, **kwargs)
             return compiled_func(*args, **kwargs)
