
Commit eeefeee

tchaton, Borda, and SeanNaren authored and committed
[FEAT] Refactor logging 3/3 [v1] (#4552)
* wip
* wip check how many tests break
* wip
* resolve some bugs
* resolve more bugs
* resolve 2 bugs
* resolve
* temp fix
* update
* remove useless code
* remove result
* try to resolve bug
* update changelog
* formatting
* remove pl

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
1 parent 5ba7dcc commit eeefeee

File tree

9 files changed: +501 −214 lines changed


CHANGELOG.md

+6 −1

@@ -33,8 +33,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))


-- Added `manual_optimizer_step` which work with `AMP Native` and `accumulated_grad_batches` ([#4485](https://github.com/PyTorchLightning/pytorch-lightning/pull/4485))
+- Added logging using `self.log` in train and evaluation for most callbacks and model hooks (
+    [#4552](https://github.com/PyTorchLightning/pytorch-lightning/pull/4552),
+    [#4495](https://github.com/PyTorchLightning/pytorch-lightning/pull/4495),
+    [#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439)
+)

+- Added `manual_optimizer_step` which work with `AMP Native` and `accumulated_grad_batches` ([#4485](https://github.com/PyTorchLightning/pytorch-lightning/pull/4485))

 - Added `persistent(mode)` method to metrics, to enable and disable metric states being added to `state_dict` ([#4482](https://github.com/PyTorchLightning/pytorch-lightning/pull/4482))
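Note: the first new CHANGELOG entry is the headline of this PR series: most built-in callbacks and model hooks can now log through `self.log`, and the same call is accepted from training and evaluation code. As a rough, minimal sketch of what that user-facing API looks like (the model, metric names, and optimizer below are illustrative and not part of this diff):

import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        # logged per step, aggregated over the epoch, and shown in the progress bar
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        # evaluation metrics are reduced per epoch and exposed to callbacks (e.g. checkpointing)
        self.log("val_loss", loss, on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)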

pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py

+9 −1

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from copy import deepcopy
 from collections import defaultdict, ChainMap
 from enum import Enum
 from typing import Union, Tuple, Any, Dict, Optional, List
@@ -419,13 +420,14 @@ def update_logger_connector(self, fx_name: str = None) -> None:
         logger_connector = self.trainer.logger_connector

         callback_metrics = {}
+        is_train = self._stage in LoggerStages.TRAIN.value

         if not self._has_batch_loop_finished:
             # get pbar
             batch_pbar_metrics = self.get_latest_batch_pbar_metrics()
             logger_connector.add_progress_bar_metrics(batch_pbar_metrics)

-            if self._stage in LoggerStages.TRAIN.value:
+            if is_train:
                 # Only log and add to callback epoch step during evaluation, test.
                 batch_log_metrics = self.get_latest_batch_log_metrics()
                 logger_connector.logged_metrics.update(batch_log_metrics)
@@ -443,6 +445,9 @@ def update_logger_connector(self, fx_name: str = None) -> None:
             epoch_log_metrics = self.get_epoch_log_metrics()
             logger_connector.logged_metrics.update(epoch_log_metrics)
             logger_connector.logged_metrics.update(epoch_dict)
+            if not self.trainer.running_sanity_check and not is_train:
+                if len(epoch_log_metrics) > 0:
+                    self.trainer.dev_debugger.track_logged_metrics_history(deepcopy(epoch_log_metrics))

             # get forked_metrics
             forked_metrics = self.get_forked_metrics()
@@ -451,6 +456,9 @@ def update_logger_connector(self, fx_name: str = None) -> None:
             callback_metrics.update(epoch_log_metrics)
             callback_metrics.update(forked_metrics)

+            if not is_train:
+                logger_connector.evaluation_callback_metrics.update(callback_metrics)
+
             # update callback_metrics
             logger_connector.callback_metrics.update(callback_metrics)
             logger_connector.callback_metrics.pop("epoch", None)
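The two added blocks above carry the result-store side of the refactor: evaluation epochs now populate a dedicated `evaluation_callback_metrics` dict alongside the usual `callback_metrics`, and (outside sanity checking) the epoch's logged metrics are recorded in the dev debugger. A simplified, self-contained sketch of that split follows; the `MetricsStore` class is a toy stand-in invented for illustration, not the actual result store or logger connector:

class MetricsStore:
    """Toy stand-in for the train/eval metrics bookkeeping in this diff."""

    def __init__(self):
        self.callback_metrics = {}             # everything callbacks may monitor
        self.evaluation_callback_metrics = {}  # only metrics produced during eval/test epochs

    def update(self, metrics, is_train):
        # every metric stays visible to callbacks ...
        self.callback_metrics.update(metrics)
        # ... but evaluation metrics are additionally kept apart, so the eval loop
        # can later report its results without training-time noise
        if not is_train:
            self.evaluation_callback_metrics.update(metrics)


store = MetricsStore()
store.update({"train_loss": 0.71}, is_train=True)
store.update({"val_loss": 0.42}, is_train=False)
assert "train_loss" not in store.evaluation_callback_metrics
assert store.callback_metrics == {"train_loss": 0.71, "val_loss": 0.42}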

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

+39 −112

@@ -36,6 +36,7 @@ class LoggerConnector:
     def __init__(self, trainer):
         self.trainer = trainer
         self.callback_metrics = {}
+        self.evaluation_callback_metrics = {}
         self.logged_metrics = {}
         self.progress_bar_metrics = {}
         self.eval_loop_results = []
@@ -59,10 +60,9 @@ def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoc
            on_epoch=on_epoch)

     def on_evaluation_batch_start(self, testing, batch, dataloader_idx, num_dataloaders):
-        # reset the result of the PL module
         model = self.trainer.get_model()
+        # set dataloader_idx only if multiple ones
         model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None
-
         # track batch_size
         self.cached_results._batch_size = Result.extract_batch_size(batch)

@@ -226,19 +226,41 @@ def add_progress_bar_metrics(self, metrics):

         self.trainer.dev_debugger.track_pbar_metrics_history(metrics)

-    def on_evaluation_epoch_end(self, deprecated_eval_results, epoch_logs, using_eval_result, test_mode):
+    def track_metrics_deprecated(self, deprecated_eval_results, using_eval_result, test_mode):
         self._track_callback_metrics(deprecated_eval_results, using_eval_result)
-
-        # TODO: deprecate parts of this for 1.0 (when removing results)
         self.__process_eval_epoch_end_results_and_log_legacy(deprecated_eval_results, test_mode)

-        self._log_on_evaluation_epoch_end_metrics(epoch_logs)
+    def evaluation_epoch_end(self, testing):
+        # reset dataloader idx
+        model_ref = self.trainer.get_model()
+        model_ref._current_dataloader_idx = None
+
+        # setting `has_batch_loop_finished` to True
+        # will perform Results reduction accross entire epoch.
+        self.cached_results.has_batch_loop_finished = True
+
+    def add_to_eval_loop_results(self, dl_idx, has_been_initialized):
+        callback_metrics = deepcopy(self.evaluation_callback_metrics)
+        for key in list(callback_metrics.keys()):
+            if "dataloader_idx" in key:
+                if f"dataloader_idx_{dl_idx}" not in key:
+                    # remove dl_idx from self.callback_metrics not belonging to this dataset.
+                    del callback_metrics[key]
+        if has_been_initialized:
+            self.eval_loop_results[dl_idx].update(callback_metrics)
+        else:
+            self.eval_loop_results.append(callback_metrics)

-        # get the final loop results
-        eval_loop_results = self._get_evaluate_epoch_results(test_mode)
-        return eval_loop_results
+    def prepare_eval_loop_results(self):
+        num_dataloaders = self.trainer.evaluation_loop.num_dataloaders
+        has_been_initialized = len(self.eval_loop_results) == num_dataloaders
+        for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders):
+            self.add_to_eval_loop_results(dl_idx, has_been_initialized)
+
+    def get_evaluate_epoch_results(self, test_mode):
+
+        self.prepare_eval_loop_results()

-    def _get_evaluate_epoch_results(self, test_mode):
         # log results of test
         if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test:
             print('-' * 80)
@@ -253,106 +275,6 @@ def _get_evaluate_epoch_results(self, test_mode):
         self.eval_loop_results = []
         return results

-    def _log_on_evaluation_epoch_end_metrics(self, epoch_logs):
-        step_metrics = self.trainer.evaluation_loop.step_metrics
-
-        num_loaders = len(step_metrics)
-
-        # clear mem
-        self.trainer.evaluation_loop.step_metrics = []
-
-        if self.trainer.running_sanity_check:
-            return
-
-        # track all metrics we want to log
-        metrics_to_log = []
-
-        # ---------------------------
-        # UPDATE EPOCH LOGGED METRICS
-        # ---------------------------
-        # (ie: in methods at the val_epoch_end level)
-        # union the epoch logs with whatever was returned from loaders and reduced
-        epoch_logger_metrics = epoch_logs.get_epoch_log_metrics()
-        epoch_pbar_metrics = epoch_logs.get_epoch_pbar_metrics()
-
-        self.logged_metrics.update(epoch_logger_metrics)
-        self.add_progress_bar_metrics(epoch_pbar_metrics)
-
-        # enable the metrics to be monitored
-        self.callback_metrics.update(epoch_logger_metrics)
-        self.callback_metrics.update(epoch_pbar_metrics)
-
-        if len(epoch_logger_metrics) > 0:
-            metrics_to_log.append(epoch_logger_metrics)
-
-        # --------------------------------
-        # UPDATE METRICS PER DATALOADER
-        # --------------------------------
-        # each dataloader aggregated metrics
-        # now we log all of them
-        for dl_idx, dl_metrics in enumerate(step_metrics):
-            if len(dl_metrics) == 0:
-                # Ensure custom logged metrics are included if not included with step metrics
-                if len(epoch_logger_metrics) > 0:
-                    self.eval_loop_results.append(epoch_logger_metrics)
-                continue
-
-            reduced_epoch_metrics = dl_metrics[0].__class__.reduce_on_epoch_end(dl_metrics)
-            # track the metrics
-            logger_metrics = reduced_epoch_metrics.get_epoch_log_metrics()
-            pbar_metrics = reduced_epoch_metrics.get_epoch_pbar_metrics()
-            forked_metrics = reduced_epoch_metrics.get_forked_metrics()
-
-            # make the keys 'k/dl'
-            logger_metrics = self.__rename_keys_by_dataloader_idx(logger_metrics, dl_idx, num_loaders)
-            pbar_metrics = self.__rename_keys_by_dataloader_idx(pbar_metrics, dl_idx, num_loaders)
-            forked_metrics = self.__rename_keys_by_dataloader_idx(forked_metrics, dl_idx, num_loaders)
-
-            self.logged_metrics.update(logger_metrics)
-            self.add_progress_bar_metrics(pbar_metrics)
-
-            # enable the metrics to be monitored
-            self.callback_metrics.update(logger_metrics)
-            self.callback_metrics.update(pbar_metrics)
-
-            # forked metrics were dropped, enable them for callbacks
-            self.callback_metrics.update(forked_metrics)
-
-            # track the final results for the dataloader
-            self.add_to_eval_loop_results(dl_idx, num_loaders)
-
-            # actually log
-            if len(logger_metrics) > 0:
-                metrics_to_log.append(logger_metrics)
-
-        # log all the metrics as a s single dict
-        metrics_to_log = dict(ChainMap(*metrics_to_log))
-        if len(metrics_to_log) > 0:
-            self.log_metrics(metrics_to_log, {})
-
-    def add_to_eval_loop_results(self, dl_idx, num_loaders):
-        callback_metrics = deepcopy(self.callback_metrics)
-        if num_loaders == 1:
-            if len(self.eval_loop_results) > 0:
-                self.eval_loop_results[0].update(callback_metrics)
-            else:
-                self.eval_loop_results.append(callback_metrics)
-            return
-
-        for key in list(callback_metrics.keys()):
-            if "dataloader_idx" in key:
-                if f"dataloader_idx_{dl_idx}" not in key:
-                    # remove dl_idx from self.callback_metrics not belonging to this dataset.
-                    del callback_metrics[key]
-        self.eval_loop_results.append(callback_metrics)
-
-    def __rename_keys_by_dataloader_idx(self, metrics, dataloader_idx, num_loaders):
-        if num_loaders == 1:
-            return metrics
-
-        result = {f'{k}/dataloader_idx_{dataloader_idx}': v for k, v in metrics.items()}
-        return result
-
     def _track_callback_metrics(self, eval_results, using_eval_result):
         if (
                 len(eval_results) > 0 and
@@ -364,8 +286,10 @@ def _track_callback_metrics(self, eval_results, using_eval_result):
             if isinstance(eval_results, list):
                 for eval_result in eval_results:
                     self.trainer.logger_connector.callback_metrics.update(eval_result.callback_metrics)
+                    self.trainer.logger_connector.evaluation_callback_metrics.update(eval_result.callback_metrics)
             else:
                 self.trainer.logger_connector.callback_metrics.update(eval_results.callback_metrics)
+                self.trainer.logger_connector.evaluation_callback_metrics.update(eval_results.callback_metrics)
         else:
             flat = {}
             if isinstance(eval_results, list):
@@ -381,6 +305,7 @@ def _track_callback_metrics(self, eval_results, using_eval_result):
                        flat['checkpoint_on'] = flat['val_loss']
                        flat['early_stop_on'] = flat['val_loss']
                    self.trainer.logger_connector.callback_metrics.update(flat)
+                    self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
             else:
                 # with a scalar return, auto set it to "val_loss" for callbacks
                 if isinstance(eval_results, torch.Tensor):
@@ -393,6 +318,7 @@ def _track_callback_metrics(self, eval_results, using_eval_result):
                    flat['checkpoint_on'] = flat['val_loss']
                    flat['early_stop_on'] = flat['val_loss']
                self.trainer.logger_connector.callback_metrics.update(flat)
+                self.trainer.logger_connector.evaluation_callback_metrics.update(flat)

     def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metrics, log_metrics, callback_metrics):
         # eval loop returns all metrics
@@ -406,9 +332,10 @@ def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metric
            self.trainer.logger_connector.log_metrics(log_metrics, {})

         # track metrics for callbacks (all prog bar, logged and callback metrics)
+        callback_metrics.update(log_metrics)
+        callback_metrics.update(prog_bar_metrics)
         self.trainer.logger_connector.callback_metrics.update(callback_metrics)
-        self.trainer.logger_connector.callback_metrics.update(log_metrics)
-        self.trainer.logger_connector.callback_metrics.update(prog_bar_metrics)
+        self.trainer.logger_connector.evaluation_callback_metrics.update(callback_metrics)

         if len(dataloader_result_metrics) > 0:
             self.eval_loop_results.append(dataloader_result_metrics)
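The new `add_to_eval_loop_results` / `prepare_eval_loop_results` pair replaces the removed per-dataloader logging code: evaluation results are now built by filtering the shared `evaluation_callback_metrics` on the `dataloader_idx_<i>` key suffix. A minimal standalone sketch of that filtering, with an invented helper name and sample keys used purely for illustration:

from copy import deepcopy


def split_per_dataloader(evaluation_callback_metrics, num_dataloaders):
    """Return one metrics dict per dataloader, mirroring the key filtering in the diff."""
    results = []
    for dl_idx in range(num_dataloaders):
        metrics = deepcopy(evaluation_callback_metrics)
        for key in list(metrics.keys()):
            # drop keys tagged with a different dataloader index
            if "dataloader_idx" in key and f"dataloader_idx_{dl_idx}" not in key:
                del metrics[key]
        results.append(metrics)
    return results


metrics = {
    "val_loss/dataloader_idx_0": 0.41,
    "val_loss/dataloader_idx_1": 0.57,
    "val_acc": 0.88,  # untagged keys are kept for every dataloader
}
print(split_per_dataloader(metrics, num_dataloaders=2))
# [{'val_loss/dataloader_idx_0': 0.41, 'val_acc': 0.88},
#  {'val_loss/dataloader_idx_1': 0.57, 'val_acc': 0.88}]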
