Skip to content

Commit b459dbd

Browse files
authored
[SOT] Using EventGuard in decorator form (#73891)
* Using EventGuard in decorator form * fix: import EventGuard, SotProfiler, event_register in paddle.jit.sot.profiler package
1 parent 0ca88c4 commit b459dbd

File tree

6 files changed

+69
-49
lines changed

6 files changed

+69
-49
lines changed

python/paddle/jit/dy2static/pir_partial_program.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from paddle.base.dygraph.base import switch_to_static_graph
3333
from paddle.pir import Value, fake_value, get_fake_value_name, is_fake_value
3434

35+
from ..profiler import event_register
3536
from .logging_utils import TranslatorLogger
3637
from .utils import (
3738
RETURN_NO_VALUE_MAGIC_NUM,
@@ -779,6 +780,7 @@ def __call__(self, inputs):
779780
restored_nest_out = self._restore_out(out)
780781
return self._remove_no_value(restored_nest_out)
781782

783+
@event_register("sot call partial_program")
782784
def sot_call(self, inputs):
783785
"""
784786
In sot, inputs and outputs of partial program only contain tensors, so we can skip some step to speed up

python/paddle/jit/dy2static/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from paddle.utils import flatten, gast
4949
from paddle.utils.environments import (
5050
BooleanEnvironmentVariable,
51+
IntegerEnvironmentVariable,
5152
)
5253

5354
from .ast_utils import ast_to_source_code
@@ -83,6 +84,7 @@
8384
core.VarDesc.VarType.FETCH_LIST,
8485
]
8586

87+
ENV_SOT_EVENT_LEVEL = IntegerEnvironmentVariable("SOT_EVENT_LEVEL", 0)
8688
ENV_ENABLE_SOT = BooleanEnvironmentVariable("ENABLE_FALL_BACK", True)
8789
ENV_ENABLE_CINN_IN_DY2ST = BooleanEnvironmentVariable(
8890
"ENABLE_CINN_IN_DY2ST", True

python/paddle/jit/sot/profiler/profiler.py renamed to python/paddle/jit/profiler.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,20 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
from contextlib import contextmanager
1618
from functools import wraps
19+
from typing import Callable, TypeVar
20+
21+
from typing_extensions import ParamSpec
1722

1823
from paddle.framework import core
1924

20-
from ..utils.envs import ENV_SOT_EVENT_LEVEL
25+
from .dy2static.utils import ENV_SOT_EVENT_LEVEL
26+
27+
P = ParamSpec("P")
28+
T = TypeVar("T")
2129

2230

2331
class SotProfiler:
@@ -48,16 +56,23 @@ def EventGuard(event_name, event_level=1):
4856
core.nvprof_nvtx_pop()
4957

5058

51-
def event_register(event_name, event_level=1):
52-
def event_wrapper(func):
59+
def event_register(
60+
event_name_formatter: Callable[P, str] | str, event_level=1
61+
) -> Callable[[Callable[P, T]], Callable[P, T]]:
62+
def event_wrapper(func: Callable[P, T]) -> Callable[P, T]:
5363
@wraps(func)
54-
def call_with_event(*args, **kwargs):
64+
def call_with_event(*args: P.args, **kwargs: P.kwargs):
65+
event_name = (
66+
event_name_formatter(*args, **kwargs)
67+
if callable(event_name_formatter)
68+
else event_name_formatter
69+
)
5570
with EventGuard(event_name, event_level=event_level):
5671
return func(*args, **kwargs)
5772

5873
return call_with_event
5974

60-
def do_nothing(func):
75+
def do_nothing(func: Callable[P, T]) -> Callable[P, T]:
6176
return func
6277

6378
if ENV_SOT_EVENT_LEVEL.get() >= event_level:

python/paddle/jit/sot/profiler/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .kernel_stats import SotStepProfilerGuard as SotStepProfilerGuard
16-
from .profiler import (
15+
from paddle.jit.profiler import (
1716
EventGuard as EventGuard,
1817
SotProfiler as SotProfiler,
1918
event_register as event_register,
2019
)
20+
21+
from .kernel_stats import SotStepProfilerGuard as SotStepProfilerGuard

python/paddle/jit/sot/symbolic/compile_cache.py

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
from typing import TYPE_CHECKING
1919

2020
import paddle
21+
from paddle.jit.profiler import EventGuard, event_register
2122

2223
from ..infer_meta import convert_meta_to_input_spec
23-
from ..profiler import EventGuard
2424
from ..utils import (
2525
ENV_SOT_EXPORT,
2626
Cache,
@@ -209,48 +209,49 @@ def update_compile_time_info(self, SIR, partial_program_layer):
209209
code
210210
] += partial_program_layer._compile_time_counter.get_total_time()
211211

212+
@event_register(
213+
lambda self, *args, **kwargs: f"FallbackWrapper: {self.SIR.name}"
214+
)
212215
def __call__(self, *args, **kwargs):
213-
with EventGuard(f"FallbackWrapper: {self.SIR.name}"):
214-
if StepInfoManager().need_back_trace:
215-
trace_back_frames()
216+
if StepInfoManager().need_back_trace:
217+
trace_back_frames()
216218

217-
log_do(
218-
2,
219-
lambda: print("[FallbackWrapper] start run SIR: \n", self.SIR),
220-
)
221-
log_do(
222-
4,
223-
lambda: print(
224-
self.compiled_fn.get_concrete_program(*args, **kwargs)[
225-
1
226-
].train_program
227-
),
228-
)
229-
if self.partial_program is None:
230-
with EventGuard("FallbackWrapper: get_concrete_program"):
231-
(
232-
self.concrete_program,
233-
self.partial_program,
234-
) = self.compiled_fn.get_concrete_program(*args, **kwargs)
235-
self.partial_program.training = self.is_training
236-
with EventGuard("FallbackWrapper: sot call partial_program"):
237-
outputs = self.partial_program.sot_call(*args, **kwargs)
238-
239-
clear_eager_tensor_name(outputs)
240-
log_do(
241-
4,
242-
lambda: print("[CompileCache] run sir forward success."),
243-
)
244-
self.collect_new_symbol_hit_rate(args, outputs)
245-
self.collect_subgraph_relation(args, outputs, self.partial_program)
246-
self.collect_subgraph_info(self.concrete_program.main_program)
247-
self.update_compile_time_info(self.SIR, self.partial_program)
248-
if ENV_SOT_EXPORT.get() != "" and not self.exported:
249-
export(self.SIR, ENV_SOT_EXPORT.get())
250-
self.exported = True
251-
252-
self.is_first_call = False
253-
return outputs
219+
log_do(
220+
2,
221+
lambda: print("[FallbackWrapper] start run SIR: \n", self.SIR),
222+
)
223+
log_do(
224+
4,
225+
lambda: print(
226+
self.compiled_fn.get_concrete_program(*args, **kwargs)[
227+
1
228+
].train_program
229+
),
230+
)
231+
if self.partial_program is None:
232+
with EventGuard("FallbackWrapper: get_concrete_program"):
233+
(
234+
self.concrete_program,
235+
self.partial_program,
236+
) = self.compiled_fn.get_concrete_program(*args, **kwargs)
237+
self.partial_program.training = self.is_training
238+
outputs = self.partial_program.sot_call(*args, **kwargs)
239+
240+
clear_eager_tensor_name(outputs)
241+
log_do(
242+
4,
243+
lambda: print("[CompileCache] run sir forward success."),
244+
)
245+
self.collect_new_symbol_hit_rate(args, outputs)
246+
self.collect_subgraph_relation(args, outputs, self.partial_program)
247+
self.collect_subgraph_info(self.concrete_program.main_program)
248+
self.update_compile_time_info(self.SIR, self.partial_program)
249+
if ENV_SOT_EXPORT.get() != "" and not self.exported:
250+
export(self.SIR, ENV_SOT_EXPORT.get())
251+
self.exported = True
252+
253+
self.is_first_call = False
254+
return outputs
254255

255256

256257
class CompileSIRCache(Cache, metaclass=Singleton):

python/paddle/jit/sot/utils/envs.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,6 @@ def parse_parameterized_key(input_str: str) -> dict[str, list[str]]:
135135
"SOT_ENABLE_GUARD_TREE",
136136
False,
137137
)
138-
ENV_SOT_EVENT_LEVEL = IntegerEnvironmentVariable("SOT_EVENT_LEVEL", 0)
139138
ENV_ENABLE_SOT_STEP_PROFILER = BooleanEnvironmentVariable(
140139
"ENABLE_SOT_STEP_PROFILER", False
141140
)

0 commit comments

Comments
 (0)