|
18 | 18 | from typing import TYPE_CHECKING |
19 | 19 |
|
20 | 20 | import paddle |
| 21 | +from paddle.jit.profiler import EventGuard, event_register |
21 | 22 |
|
22 | 23 | from ..infer_meta import convert_meta_to_input_spec |
23 | | -from ..profiler import EventGuard |
24 | 24 | from ..utils import ( |
25 | 25 | ENV_SOT_EXPORT, |
26 | 26 | Cache, |
@@ -209,48 +209,49 @@ def update_compile_time_info(self, SIR, partial_program_layer): |
209 | 209 | code |
210 | 210 | ] += partial_program_layer._compile_time_counter.get_total_time() |
211 | 211 |
|
| 212 | + @event_register( |
| 213 | + lambda self, *args, **kwargs: f"FallbackWrapper: {self.SIR.name}" |
| 214 | + ) |
212 | 215 | def __call__(self, *args, **kwargs): |
213 | | - with EventGuard(f"FallbackWrapper: {self.SIR.name}"): |
214 | | - if StepInfoManager().need_back_trace: |
215 | | - trace_back_frames() |
| 216 | + if StepInfoManager().need_back_trace: |
| 217 | + trace_back_frames() |
216 | 218 |
|
217 | | - log_do( |
218 | | - 2, |
219 | | - lambda: print("[FallbackWrapper] start run SIR: \n", self.SIR), |
220 | | - ) |
221 | | - log_do( |
222 | | - 4, |
223 | | - lambda: print( |
224 | | - self.compiled_fn.get_concrete_program(*args, **kwargs)[ |
225 | | - 1 |
226 | | - ].train_program |
227 | | - ), |
228 | | - ) |
229 | | - if self.partial_program is None: |
230 | | - with EventGuard("FallbackWrapper: get_concrete_program"): |
231 | | - ( |
232 | | - self.concrete_program, |
233 | | - self.partial_program, |
234 | | - ) = self.compiled_fn.get_concrete_program(*args, **kwargs) |
235 | | - self.partial_program.training = self.is_training |
236 | | - with EventGuard("FallbackWrapper: sot call partial_program"): |
237 | | - outputs = self.partial_program.sot_call(*args, **kwargs) |
238 | | - |
239 | | - clear_eager_tensor_name(outputs) |
240 | | - log_do( |
241 | | - 4, |
242 | | - lambda: print("[CompileCache] run sir forward success."), |
243 | | - ) |
244 | | - self.collect_new_symbol_hit_rate(args, outputs) |
245 | | - self.collect_subgraph_relation(args, outputs, self.partial_program) |
246 | | - self.collect_subgraph_info(self.concrete_program.main_program) |
247 | | - self.update_compile_time_info(self.SIR, self.partial_program) |
248 | | - if ENV_SOT_EXPORT.get() != "" and not self.exported: |
249 | | - export(self.SIR, ENV_SOT_EXPORT.get()) |
250 | | - self.exported = True |
251 | | - |
252 | | - self.is_first_call = False |
253 | | - return outputs |
| 219 | + log_do( |
| 220 | + 2, |
| 221 | + lambda: print("[FallbackWrapper] start run SIR: \n", self.SIR), |
| 222 | + ) |
| 223 | + log_do( |
| 224 | + 4, |
| 225 | + lambda: print( |
| 226 | + self.compiled_fn.get_concrete_program(*args, **kwargs)[ |
| 227 | + 1 |
| 228 | + ].train_program |
| 229 | + ), |
| 230 | + ) |
| 231 | + if self.partial_program is None: |
| 232 | + with EventGuard("FallbackWrapper: get_concrete_program"): |
| 233 | + ( |
| 234 | + self.concrete_program, |
| 235 | + self.partial_program, |
| 236 | + ) = self.compiled_fn.get_concrete_program(*args, **kwargs) |
| 237 | + self.partial_program.training = self.is_training |
| 238 | + outputs = self.partial_program.sot_call(*args, **kwargs) |
| 239 | + |
| 240 | + clear_eager_tensor_name(outputs) |
| 241 | + log_do( |
| 242 | + 4, |
| 243 | + lambda: print("[CompileCache] run sir forward success."), |
| 244 | + ) |
| 245 | + self.collect_new_symbol_hit_rate(args, outputs) |
| 246 | + self.collect_subgraph_relation(args, outputs, self.partial_program) |
| 247 | + self.collect_subgraph_info(self.concrete_program.main_program) |
| 248 | + self.update_compile_time_info(self.SIR, self.partial_program) |
| 249 | + if ENV_SOT_EXPORT.get() != "" and not self.exported: |
| 250 | + export(self.SIR, ENV_SOT_EXPORT.get()) |
| 251 | + self.exported = True |
| 252 | + |
| 253 | + self.is_first_call = False |
| 254 | + return outputs |
254 | 255 |
|
255 | 256 |
|
256 | 257 | class CompileSIRCache(Cache, metaclass=Singleton): |
|
0 commit comments