@@ -36,6 +36,7 @@ class LoggerConnector:
36
36
def __init__(self, trainer):
    """Attach this connector to *trainer* and reset all tracked metric stores."""
    self.trainer = trainer
    # metrics exposed to callbacks (e.g. monitored by checkpoint/early-stopping)
    self.callback_metrics = {}
    # callback metrics produced specifically by evaluation loops
    self.evaluation_callback_metrics = {}
    # metrics forwarded to the logger backend
    self.logged_metrics = {}
    # metrics displayed in the progress bar
    self.progress_bar_metrics = {}
    # per-dataloader result dicts collected over one evaluation run
    self.eval_loop_results = []
@@ -59,10 +60,9 @@ def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoc
59
60
on_epoch = on_epoch )
60
61
61
62
def on_evaluation_batch_start(self, testing, batch, dataloader_idx, num_dataloaders):
    """Prepare the model and the result cache before running one evaluation batch."""
    model = self.trainer.get_model()
    # expose the dataloader index on the module only when several loaders exist
    if num_dataloaders > 1:
        model._current_dataloader_idx = dataloader_idx
    else:
        model._current_dataloader_idx = None
    # record the current batch size on the cached results
    self.cached_results._batch_size = Result.extract_batch_size(batch)
@@ -226,19 +226,41 @@ def add_progress_bar_metrics(self, metrics):
226
226
227
227
self .trainer .dev_debugger .track_pbar_metrics_history (metrics )
228
228
229
- def on_evaluation_epoch_end (self , deprecated_eval_results , epoch_logs , using_eval_result , test_mode ):
229
def track_metrics_deprecated(self, deprecated_eval_results, using_eval_result, test_mode):
    """Route legacy (Result-style) evaluation output through the old tracking paths.

    First records callback metrics from the deprecated results, then runs the
    legacy epoch-end processing/logging for them.
    """
    self._track_callback_metrics(deprecated_eval_results, using_eval_result)
    self.__process_eval_epoch_end_results_and_log_legacy(deprecated_eval_results, test_mode)
233
def evaluation_epoch_end(self, testing):
    """Finalize metric tracking once an evaluation epoch has finished."""
    # the per-dataloader index no longer applies outside the evaluation loop
    model_ref = self.trainer.get_model()
    model_ref._current_dataloader_idx = None
    # flipping `has_batch_loop_finished` to True triggers reduction of the
    # cached Results across the entire epoch
    self.cached_results.has_batch_loop_finished = True
242
def add_to_eval_loop_results(self, dl_idx, has_been_initialized):
    """Collect the evaluation callback metrics belonging to dataloader ``dl_idx``
    into ``self.eval_loop_results``.

    Args:
        dl_idx: index of the dataloader whose metrics should be kept.
        has_been_initialized: when True, merge into the existing per-dataloader
            entry at position ``dl_idx`` instead of appending a new dict.
    """
    callback_metrics = deepcopy(self.evaluation_callback_metrics)
    for key in list(callback_metrics.keys()):
        # drop metrics that belong to a different dataloader.
        # NOTE: use ``endswith`` rather than substring containment — with a
        # plain ``in`` check, "dataloader_idx_1" would wrongly match (and keep)
        # keys such as "loss/dataloader_idx_10". Suffix keys are formatted as
        # f"{k}/dataloader_idx_{idx}", so the index always sits at the end.
        if "dataloader_idx" in key and not key.endswith(f"dataloader_idx_{dl_idx}"):
            del callback_metrics[key]
    if has_been_initialized:
        self.eval_loop_results[dl_idx].update(callback_metrics)
    else:
        self.eval_loop_results.append(callback_metrics)
236
253
237
- # get the final loop results
238
- eval_loop_results = self ._get_evaluate_epoch_results (test_mode )
239
- return eval_loop_results
254
def prepare_eval_loop_results(self):
    """Populate ``self.eval_loop_results`` with one metrics dict per dataloader."""
    num_dataloaders = self.trainer.evaluation_loop.num_dataloaders
    # an entry per dataloader already exists -> update in place instead of appending
    has_been_initialized = len(self.eval_loop_results) == num_dataloaders
    for dl_idx in range(num_dataloaders):
        self.add_to_eval_loop_results(dl_idx, has_been_initialized)
260
+ def get_evaluate_epoch_results (self , test_mode ):
261
+
262
+ self .prepare_eval_loop_results ()
240
263
241
- def _get_evaluate_epoch_results (self , test_mode ):
242
264
# log results of test
243
265
if test_mode and self .trainer .is_global_zero and self .trainer .verbose_test :
244
266
print ('-' * 80 )
@@ -253,106 +275,6 @@ def _get_evaluate_epoch_results(self, test_mode):
253
275
self .eval_loop_results = []
254
276
return results
255
277
256
- def _log_on_evaluation_epoch_end_metrics (self , epoch_logs ):
257
- step_metrics = self .trainer .evaluation_loop .step_metrics
258
-
259
- num_loaders = len (step_metrics )
260
-
261
- # clear mem
262
- self .trainer .evaluation_loop .step_metrics = []
263
-
264
- if self .trainer .running_sanity_check :
265
- return
266
-
267
- # track all metrics we want to log
268
- metrics_to_log = []
269
-
270
- # ---------------------------
271
- # UPDATE EPOCH LOGGED METRICS
272
- # ---------------------------
273
- # (ie: in methods at the val_epoch_end level)
274
- # union the epoch logs with whatever was returned from loaders and reduced
275
- epoch_logger_metrics = epoch_logs .get_epoch_log_metrics ()
276
- epoch_pbar_metrics = epoch_logs .get_epoch_pbar_metrics ()
277
-
278
- self .logged_metrics .update (epoch_logger_metrics )
279
- self .add_progress_bar_metrics (epoch_pbar_metrics )
280
-
281
- # enable the metrics to be monitored
282
- self .callback_metrics .update (epoch_logger_metrics )
283
- self .callback_metrics .update (epoch_pbar_metrics )
284
-
285
- if len (epoch_logger_metrics ) > 0 :
286
- metrics_to_log .append (epoch_logger_metrics )
287
-
288
- # --------------------------------
289
- # UPDATE METRICS PER DATALOADER
290
- # --------------------------------
291
- # each dataloader aggregated metrics
292
- # now we log all of them
293
- for dl_idx , dl_metrics in enumerate (step_metrics ):
294
- if len (dl_metrics ) == 0 :
295
- # Ensure custom logged metrics are included if not included with step metrics
296
- if len (epoch_logger_metrics ) > 0 :
297
- self .eval_loop_results .append (epoch_logger_metrics )
298
- continue
299
-
300
- reduced_epoch_metrics = dl_metrics [0 ].__class__ .reduce_on_epoch_end (dl_metrics )
301
- # track the metrics
302
- logger_metrics = reduced_epoch_metrics .get_epoch_log_metrics ()
303
- pbar_metrics = reduced_epoch_metrics .get_epoch_pbar_metrics ()
304
- forked_metrics = reduced_epoch_metrics .get_forked_metrics ()
305
-
306
- # make the keys 'k/dl'
307
- logger_metrics = self .__rename_keys_by_dataloader_idx (logger_metrics , dl_idx , num_loaders )
308
- pbar_metrics = self .__rename_keys_by_dataloader_idx (pbar_metrics , dl_idx , num_loaders )
309
- forked_metrics = self .__rename_keys_by_dataloader_idx (forked_metrics , dl_idx , num_loaders )
310
-
311
- self .logged_metrics .update (logger_metrics )
312
- self .add_progress_bar_metrics (pbar_metrics )
313
-
314
- # enable the metrics to be monitored
315
- self .callback_metrics .update (logger_metrics )
316
- self .callback_metrics .update (pbar_metrics )
317
-
318
- # forked metrics were dropped, enable them for callbacks
319
- self .callback_metrics .update (forked_metrics )
320
-
321
- # track the final results for the dataloader
322
- self .add_to_eval_loop_results (dl_idx , num_loaders )
323
-
324
- # actually log
325
- if len (logger_metrics ) > 0 :
326
- metrics_to_log .append (logger_metrics )
327
-
328
- # log all the metrics as a s single dict
329
- metrics_to_log = dict (ChainMap (* metrics_to_log ))
330
- if len (metrics_to_log ) > 0 :
331
- self .log_metrics (metrics_to_log , {})
332
-
333
- def add_to_eval_loop_results (self , dl_idx , num_loaders ):
334
- callback_metrics = deepcopy (self .callback_metrics )
335
- if num_loaders == 1 :
336
- if len (self .eval_loop_results ) > 0 :
337
- self .eval_loop_results [0 ].update (callback_metrics )
338
- else :
339
- self .eval_loop_results .append (callback_metrics )
340
- return
341
-
342
- for key in list (callback_metrics .keys ()):
343
- if "dataloader_idx" in key :
344
- if f"dataloader_idx_{ dl_idx } " not in key :
345
- # remove dl_idx from self.callback_metrics not belonging to this dataset.
346
- del callback_metrics [key ]
347
- self .eval_loop_results .append (callback_metrics )
348
-
349
- def __rename_keys_by_dataloader_idx (self , metrics , dataloader_idx , num_loaders ):
350
- if num_loaders == 1 :
351
- return metrics
352
-
353
- result = {f'{ k } /dataloader_idx_{ dataloader_idx } ' : v for k , v in metrics .items ()}
354
- return result
355
-
356
278
def _track_callback_metrics (self , eval_results , using_eval_result ):
357
279
if (
358
280
len (eval_results ) > 0 and
@@ -364,8 +286,10 @@ def _track_callback_metrics(self, eval_results, using_eval_result):
364
286
if isinstance (eval_results , list ):
365
287
for eval_result in eval_results :
366
288
self .trainer .logger_connector .callback_metrics .update (eval_result .callback_metrics )
289
+ self .trainer .logger_connector .evaluation_callback_metrics .update (eval_result .callback_metrics )
367
290
else :
368
291
self .trainer .logger_connector .callback_metrics .update (eval_results .callback_metrics )
292
+ self .trainer .logger_connector .evaluation_callback_metrics .update (eval_results .callback_metrics )
369
293
else :
370
294
flat = {}
371
295
if isinstance (eval_results , list ):
@@ -381,6 +305,7 @@ def _track_callback_metrics(self, eval_results, using_eval_result):
381
305
flat ['checkpoint_on' ] = flat ['val_loss' ]
382
306
flat ['early_stop_on' ] = flat ['val_loss' ]
383
307
self .trainer .logger_connector .callback_metrics .update (flat )
308
+ self .trainer .logger_connector .evaluation_callback_metrics .update (flat )
384
309
else :
385
310
# with a scalar return, auto set it to "val_loss" for callbacks
386
311
if isinstance (eval_results , torch .Tensor ):
@@ -393,6 +318,7 @@ def _track_callback_metrics(self, eval_results, using_eval_result):
393
318
flat ['checkpoint_on' ] = flat ['val_loss' ]
394
319
flat ['early_stop_on' ] = flat ['val_loss' ]
395
320
self .trainer .logger_connector .callback_metrics .update (flat )
321
+ self .trainer .logger_connector .evaluation_callback_metrics .update (flat )
396
322
397
323
def __process_eval_epoch_end_results_and_log_legacy_update (self , prog_bar_metrics , log_metrics , callback_metrics ):
398
324
# eval loop returns all metrics
@@ -406,9 +332,10 @@ def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metric
406
332
self .trainer .logger_connector .log_metrics (log_metrics , {})
407
333
408
334
# track metrics for callbacks (all prog bar, logged and callback metrics)
335
+ callback_metrics .update (log_metrics )
336
+ callback_metrics .update (prog_bar_metrics )
409
337
self .trainer .logger_connector .callback_metrics .update (callback_metrics )
410
- self .trainer .logger_connector .callback_metrics .update (log_metrics )
411
- self .trainer .logger_connector .callback_metrics .update (prog_bar_metrics )
338
+ self .trainer .logger_connector .evaluation_callback_metrics .update (callback_metrics )
412
339
413
340
if len (dataloader_result_metrics ) > 0 :
414
341
self .eval_loop_results .append (dataloader_result_metrics )
0 commit comments