 from tensorrt_llm.math_utils import ceil_div, pad_up
 from tensorrt_llm.quantization.utils import fp4_utils

-is_torch_compiling_flag = False
-is_piecewise_running_flag = False
+_torch_compiling = threading.local()
+_piecewise_running = threading.local()

 aux_stream_name_list = [
     'Attention',
@@ -46,23 +46,19 @@ class ActivationType(IntEnum):


 def set_torch_compiling(enable: bool):
-    global is_torch_compiling_flag
-    is_torch_compiling_flag = enable
+    _torch_compiling.flag = enable


 def is_torch_compiling() -> bool:
-    global is_torch_compiling_flag
-    return is_torch_compiling_flag
+    return getattr(_torch_compiling, 'flag', False)


 def set_piecewise_running(enable: bool):
-    global is_piecewise_running_flag
-    is_piecewise_running_flag = enable
+    _piecewise_running.flag = enable


 def is_piecewise_running() -> bool:
-    global is_piecewise_running_flag
-    return is_piecewise_running_flag
+    return getattr(_piecewise_running, 'flag', False)


 _global_attrs = threading.local()
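As a quick illustration of the pattern this change adopts: a module-level threading.local() object gives each thread its own independent `flag` attribute, and the getattr(..., False) read preserves the old module-level default for any thread that never called the setter. A minimal standalone sketch of that behavior (it reuses the function names from the diff but does not import tensorrt_llm):

import threading

_torch_compiling = threading.local()

def set_torch_compiling(enable: bool):
    # Writes only this thread's copy of the flag.
    _torch_compiling.flag = enable

def is_torch_compiling() -> bool:
    # getattr() supplies the False default for threads that never set the flag.
    return getattr(_torch_compiling, 'flag', False)

def worker(results: list):
    # This thread never calls the setter, so it sees the default.
    results.append(is_torch_compiling())

set_torch_compiling(True)  # set only in the main thread
results: list = []
t = threading.Thread(target=worker, args=(results,))
t.start()
t.join()
print(is_torch_compiling(), results[0])  # True False

With the old module-level globals, the worker thread would have observed True here; with thread-local storage, each thread's flag is isolated.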