[FEAT] Refactor logging 3/3 [v1] #4552

Merged
27 commits merged on Nov 11, 2020
Changes from 20 commits
Commits (27)
63d7535  wip (tchaton, Nov 6, 2020)
5eaa330  wip check how many tests break (tchaton, Nov 6, 2020)
95ec670  Merge branch 'master' into feat/logging_val (tchaton, Nov 6, 2020)
c269a8b  wip (tchaton, Nov 6, 2020)
2444640  resolve some bugs (tchaton, Nov 6, 2020)
3726577  resolve more bugs (tchaton, Nov 6, 2020)
1855fa9  resolve 2 bugs (tchaton, Nov 6, 2020)
a83a313  resolve (tchaton, Nov 6, 2020)
93fd11a  Merge branch 'master' into feat/logging_val_1 (tchaton, Nov 6, 2020)
fec1b20  temp fix (tchaton, Nov 6, 2020)
f75cff6  Merge branch 'feat/logging_val_1' of https://github.com/PyTorchLightn… (tchaton, Nov 6, 2020)
6038412  update (tchaton, Nov 6, 2020)
a7a44ce  remove useless code (tchaton, Nov 6, 2020)
fef917e  remove result (tchaton, Nov 6, 2020)
f3e47c9  try to resolve bug (tchaton, Nov 6, 2020)
442d9e7  update changelog (tchaton, Nov 6, 2020)
c21e745  formatting (Borda, Nov 6, 2020)
3100edc  remove pl (tchaton, Nov 7, 2020)
1f3a469  Merge branch 'master' into feat/logging_val_1 (SeanNaren, Nov 7, 2020)
c04897e  Merge branch 'master' into feat/logging_val_1 (tchaton, Nov 9, 2020)
c8cd585  Merge branch 'master' into feat/logging_val_1 (SeanNaren, Nov 9, 2020)
24b0871  Merge branch 'master' into feat/logging_val_1 (tchaton, Nov 10, 2020)
496020d  Merge branch 'feat/logging_val_1' of https://github.com/PyTorchLightn… (tchaton, Nov 10, 2020)
67ad75b  Merge branch 'master' into feat/logging_val_1 (tchaton, Nov 10, 2020)
4828ee4  Merge branch 'master' into feat/logging_val_1 (tchaton, Nov 10, 2020)
1520bcc  Merge branch 'master' into feat/logging_val_1 (tchaton, Nov 10, 2020)
8592000  Merge branch 'master' into feat/logging_val_1 (tchaton, Nov 11, 2020)
CHANGELOG.md (7 changes: 6 additions & 1 deletion)
@@ -30,7 +30,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `fsspec` to tuner ([#4458](https://github.com/PyTorchLightning/pytorch-lightning/pull/4458))


- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))
- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))


- Added logging using `self.log` in train and evaluation for most callbacks and model hooks ([#4552](https://github.com/PyTorchLightning/pytorch-lightning/pull/4552),
[#4495](https://github.com/PyTorchLightning/pytorch-lightning/pull/4495),
[#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))


### Changed
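The changelog entry above refers to `self.log`, the logging call that this PR series makes available from most model hooks and callbacks. Below is a minimal, hypothetical sketch of the pattern it enables; the model, the callback, and all metric names are illustrative and not taken from this PR.

```python
# Illustrative sketch (not code from this PR): logging with self.log from
# LightningModule hooks, and reading the resulting metrics from a Callback.
import torch
import pytorch_lightning as pl
from torch.nn import functional as F


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.layer(x), y)
        # logged per step and aggregated for the epoch; also shown in the progress bar
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        # epoch-level reduction; the value ends up in trainer.callback_metrics
        self.log("val_loss", F.cross_entropy(self.layer(x), y), on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


class PrintValLoss(pl.Callback):
    def on_validation_epoch_end(self, trainer, pl_module):
        # callbacks can read (and, for most hooks, also log) the same metrics
        print(trainer.callback_metrics.get("val_loss"))
```

Which hooks may log, and with which `on_step`/`on_epoch` defaults, is validated elsewhere in this refactor series (see `check_logging_in_callbacks` in the diff below).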
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from copy import deepcopy
from collections import defaultdict, ChainMap
from enum import Enum
from typing import Union, Tuple, Any, Dict, Optional, List
@@ -415,13 +416,14 @@ def update_logger_connector(self, fx_name: str = None) -> None:
logger_connector = self.trainer.logger_connector

callback_metrics = {}
is_train = self._stage in LoggerStages.TRAIN.value

if not self._has_batch_loop_finished:
# get pbar
batch_pbar_metrics = self.get_latest_batch_pbar_metrics()
logger_connector.add_progress_bar_metrics(batch_pbar_metrics)

if self._stage in LoggerStages.TRAIN.value:
if is_train:
# Only log and add to callback epoch step during evaluation, test.
batch_log_metrics = self.get_latest_batch_log_metrics()
logger_connector.logged_metrics.update(batch_log_metrics)
@@ -439,6 +441,9 @@ def update_logger_connector(self, fx_name: str = None) -> None:
epoch_log_metrics = self.get_epoch_log_metrics()
logger_connector.logged_metrics.update(epoch_log_metrics)
logger_connector.logged_metrics.update(epoch_dict)
if not self.trainer.running_sanity_check and not is_train:
if len(epoch_log_metrics) > 0:
self.trainer.dev_debugger.track_logged_metrics_history(deepcopy(epoch_log_metrics))

# get forked_metrics
forked_metrics = self.get_forked_metrics()
@@ -447,6 +452,9 @@ def update_logger_connector(self, fx_name: str = None) -> None:
callback_metrics.update(epoch_log_metrics)
callback_metrics.update(forked_metrics)

if not is_train:
logger_connector.evaluation_callback_metrics.update(callback_metrics)

# update callback_metrics
logger_connector.callback_metrics.update(callback_metrics)
logger_connector.callback_metrics.pop("epoch", None)
@@ -36,6 +36,7 @@ class LoggerConnector:
def __init__(self, trainer):
self.trainer = trainer
self.callback_metrics = {}
self.evaluation_callback_metrics = {}
self.logged_metrics = {}
self.progress_bar_metrics = {}
self.eval_loop_results = []
@@ -59,10 +60,9 @@ def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoc
on_epoch=on_epoch)

def on_evaluation_batch_start(self, testing, batch, dataloader_idx, num_dataloaders):
# reset the result of the PL module
model = self.trainer.get_model()
# set dataloader_idx only if multiple ones
model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None

# track batch_size
self.cached_results._batch_size = Result.extract_batch_size(batch)

@@ -224,19 +224,41 @@ def add_progress_bar_metrics(self, metrics):

self.trainer.dev_debugger.track_pbar_metrics_history(metrics)

def on_evaluation_epoch_end(self, deprecated_eval_results, epoch_logs, using_eval_result, test_mode):
def track_metrics_deprecated(self, deprecated_eval_results, using_eval_result, test_mode):
self._track_callback_metrics(deprecated_eval_results, using_eval_result)

# TODO: deprecate parts of this for 1.0 (when removing results)
self.__process_eval_epoch_end_results_and_log_legacy(deprecated_eval_results, test_mode)

self._log_on_evaluation_epoch_end_metrics(epoch_logs)
def evaluation_epoch_end(self, testing):
# reset dataloader idx
model_ref = self.trainer.get_model()
model_ref._current_dataloader_idx = None

# setting `has_batch_loop_finished` to True
# will perform Results reduction across the entire epoch.
self.cached_results.has_batch_loop_finished = True

def add_to_eval_loop_results(self, dl_idx, has_been_initialized):
callback_metrics = deepcopy(self.evaluation_callback_metrics)
for key in list(callback_metrics.keys()):
if "dataloader_idx" in key:
if f"dataloader_idx_{dl_idx}" not in key:
# remove dl_idx from self.callback_metrics not belonging to this dataset.
del callback_metrics[key]
if has_been_initialized:
self.eval_loop_results[dl_idx].update(callback_metrics)
else:
self.eval_loop_results.append(callback_metrics)

# get the final loop results
eval_loop_results = self._get_evaluate_epoch_results(test_mode)
return eval_loop_results
def prepare_eval_loop_results(self):
num_dataloaders = self.trainer.evaluation_loop.num_dataloaders
has_been_initialized = len(self.eval_loop_results) == num_dataloaders
for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders):
self.add_to_eval_loop_results(dl_idx, has_been_initialized)

def get_evaluate_epoch_results(self, test_mode):

self.prepare_eval_loop_results()

def _get_evaluate_epoch_results(self, test_mode):
# log results of test
if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test:
print('-' * 80)
@@ -251,106 +273,6 @@ def _get_evaluate_epoch_results(self, test_mode):
self.eval_loop_results = []
return results

def _log_on_evaluation_epoch_end_metrics(self, epoch_logs):
step_metrics = self.trainer.evaluation_loop.step_metrics

num_loaders = len(step_metrics)

# clear mem
self.trainer.evaluation_loop.step_metrics = []

if self.trainer.running_sanity_check:
return

# track all metrics we want to log
metrics_to_log = []

# ---------------------------
# UPDATE EPOCH LOGGED METRICS
# ---------------------------
# (ie: in methods at the val_epoch_end level)
# union the epoch logs with whatever was returned from loaders and reduced
epoch_logger_metrics = epoch_logs.get_epoch_log_metrics()
epoch_pbar_metrics = epoch_logs.get_epoch_pbar_metrics()

self.logged_metrics.update(epoch_logger_metrics)
self.add_progress_bar_metrics(epoch_pbar_metrics)

# enable the metrics to be monitored
self.callback_metrics.update(epoch_logger_metrics)
self.callback_metrics.update(epoch_pbar_metrics)

if len(epoch_logger_metrics) > 0:
metrics_to_log.append(epoch_logger_metrics)

# --------------------------------
# UPDATE METRICS PER DATALOADER
# --------------------------------
# each dataloader aggregated metrics
# now we log all of them
for dl_idx, dl_metrics in enumerate(step_metrics):
if len(dl_metrics) == 0:
# Ensure custom logged metrics are included if not included with step metrics
if len(epoch_logger_metrics) > 0:
self.eval_loop_results.append(epoch_logger_metrics)
continue

reduced_epoch_metrics = dl_metrics[0].__class__.reduce_on_epoch_end(dl_metrics)
# track the metrics
logger_metrics = reduced_epoch_metrics.get_epoch_log_metrics()
pbar_metrics = reduced_epoch_metrics.get_epoch_pbar_metrics()
forked_metrics = reduced_epoch_metrics.get_forked_metrics()

# make the keys 'k/dl'
logger_metrics = self.__rename_keys_by_dataloader_idx(logger_metrics, dl_idx, num_loaders)
pbar_metrics = self.__rename_keys_by_dataloader_idx(pbar_metrics, dl_idx, num_loaders)
forked_metrics = self.__rename_keys_by_dataloader_idx(forked_metrics, dl_idx, num_loaders)

self.logged_metrics.update(logger_metrics)
self.add_progress_bar_metrics(pbar_metrics)

# enable the metrics to be monitored
self.callback_metrics.update(logger_metrics)
self.callback_metrics.update(pbar_metrics)

# forked metrics were dropped, enable them for callbacks
self.callback_metrics.update(forked_metrics)

# track the final results for the dataloader
self.add_to_eval_loop_results(dl_idx, num_loaders)

# actually log
if len(logger_metrics) > 0:
metrics_to_log.append(logger_metrics)

# log all the metrics as a single dict
metrics_to_log = dict(ChainMap(*metrics_to_log))
if len(metrics_to_log) > 0:
self.log_metrics(metrics_to_log, {})

def add_to_eval_loop_results(self, dl_idx, num_loaders):
callback_metrics = deepcopy(self.callback_metrics)
if num_loaders == 1:
if len(self.eval_loop_results) > 0:
self.eval_loop_results[0].update(callback_metrics)
else:
self.eval_loop_results.append(callback_metrics)
return

for key in list(callback_metrics.keys()):
if "dataloader_idx" in key:
if f"dataloader_idx_{dl_idx}" not in key:
# remove dl_idx from self.callback_metrics not belonging to this dataset.
del callback_metrics[key]
self.eval_loop_results.append(callback_metrics)

def __rename_keys_by_dataloader_idx(self, metrics, dataloader_idx, num_loaders):
if num_loaders == 1:
return metrics

result = {f'{k}/dataloader_idx_{dataloader_idx}': v for k, v in metrics.items()}
return result

def _track_callback_metrics(self, eval_results, using_eval_result):
if (
len(eval_results) > 0 and
@@ -362,8 +284,10 @@ def _track_callback_metrics(self, eval_results, using_eval_result):
if isinstance(eval_results, list):
for eval_result in eval_results:
self.trainer.logger_connector.callback_metrics.update(eval_result.callback_metrics)
self.trainer.logger_connector.evaluation_callback_metrics.update(eval_result.callback_metrics)
else:
self.trainer.logger_connector.callback_metrics.update(eval_results.callback_metrics)
self.trainer.logger_connector.evaluation_callback_metrics.update(eval_results.callback_metrics)
else:
flat = {}
if isinstance(eval_results, list):
@@ -379,6 +303,7 @@ def _track_callback_metrics(self, eval_results, using_eval_result):
flat['checkpoint_on'] = flat['val_loss']
flat['early_stop_on'] = flat['val_loss']
self.trainer.logger_connector.callback_metrics.update(flat)
self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
else:
# with a scalar return, auto set it to "val_loss" for callbacks
if isinstance(eval_results, torch.Tensor):
@@ -391,6 +316,7 @@ def _track_callback_metrics(self, eval_results, using_eval_result):
flat['checkpoint_on'] = flat['val_loss']
flat['early_stop_on'] = flat['val_loss']
self.trainer.logger_connector.callback_metrics.update(flat)
self.trainer.logger_connector.evaluation_callback_metrics.update(flat)

def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metrics, log_metrics, callback_metrics):
# eval loop returns all metrics
@@ -404,9 +330,10 @@ def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metric
self.trainer.logger_connector.log_metrics(log_metrics, {})

# track metrics for callbacks (all prog bar, logged and callback metrics)
callback_metrics.update(log_metrics)
callback_metrics.update(prog_bar_metrics)
self.trainer.logger_connector.callback_metrics.update(callback_metrics)
self.trainer.logger_connector.callback_metrics.update(log_metrics)
self.trainer.logger_connector.callback_metrics.update(prog_bar_metrics)
self.trainer.logger_connector.evaluation_callback_metrics.update(callback_metrics)

if len(dataloader_result_metrics) > 0:
self.eval_loop_results.append(dataloader_result_metrics)
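For readers skimming the diff, the core of the new `add_to_eval_loop_results` above is per-dataloader key filtering: metrics whose keys carry a `dataloader_idx_{n}` suffix are kept only for that dataloader, while unsuffixed metrics are shared across all of them. Below is a standalone sketch of just that filtering step; the function name, metric names, and values are made up, and it mirrors rather than copies the PR code.

```python
# Standalone sketch (assumed names, not the PR's implementation) of the
# per-dataloader filtering performed by LoggerConnector.add_to_eval_loop_results.
from copy import deepcopy
from typing import Dict, List


def split_metrics_per_dataloader(callback_metrics: Dict[str, float],
                                 num_dataloaders: int) -> List[Dict[str, float]]:
    """Keep suffixed metrics only for their own dataloader; share the rest."""
    results = []
    for dl_idx in range(num_dataloaders):
        metrics = deepcopy(callback_metrics)
        for key in list(metrics.keys()):
            # drop keys tagged for a different dataloader
            if "dataloader_idx" in key and f"dataloader_idx_{dl_idx}" not in key:
                del metrics[key]
        results.append(metrics)
    return results


if __name__ == "__main__":
    # hypothetical metrics as they might appear with two validation dataloaders
    metrics = {
        "val_loss/dataloader_idx_0": 0.41,
        "val_loss/dataloader_idx_1": 0.37,
        "epoch": 3,
    }
    for dl_idx, res in enumerate(split_metrics_per_dataloader(metrics, num_dataloaders=2)):
        print(dl_idx, res)
    # 0 {'val_loss/dataloader_idx_0': 0.41, 'epoch': 3}
    # 1 {'val_loss/dataloader_idx_1': 0.37, 'epoch': 3}
```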