Skip to content

Commit f059da2

Browse files
committed
[None][chore] Make torch compiling and piecewise running flags thread-safe
Signed-off-by: Jonas Li <[email protected]>
1 parent 96cfdd8 commit f059da2

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

tensorrt_llm/_torch/utils.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from tensorrt_llm.math_utils import ceil_div, pad_up
1212
from tensorrt_llm.quantization.utils import fp4_utils
1313

14-
is_torch_compiling_flag = False
15-
is_piecewise_running_flag = False
14+
_torch_compiling = threading.local()
15+
_piecewise_running = threading.local()
1616

1717
aux_stream_name_list = [
1818
'Attention',
@@ -46,23 +46,19 @@ class ActivationType(IntEnum):
4646

4747

4848
def set_torch_compiling(enable: bool):
49-
global is_torch_compiling_flag
50-
is_torch_compiling_flag = enable
49+
_torch_compiling.flag = enable
5150

5251

5352
def is_torch_compiling() -> bool:
54-
global is_torch_compiling_flag
55-
return is_torch_compiling_flag
53+
return getattr(_torch_compiling, 'flag', False)
5654

5755

5856
def set_piecewise_running(enable: bool):
59-
global is_piecewise_running_flag
60-
is_piecewise_running_flag = enable
57+
_piecewise_running.flag = enable
6158

6259

6360
def is_piecewise_running() -> bool:
64-
global is_piecewise_running_flag
65-
return is_piecewise_running_flag
61+
return getattr(_piecewise_running, 'flag', False)
6662

6763

6864
_global_attrs = threading.local()

0 commit comments

Comments
 (0)