feature: Add logging test / val callback functions #4351

Closed · wants to merge 80 commits
Commits (80)
ac40d1b
release only training callback
tchaton Oct 20, 2020
8231ae3
update for flake8
tchaton Oct 20, 2020
d6eb8d8
resolve `test_multiple_optimizers_manual_return_and_log`
tchaton Oct 20, 2020
7ae6f79
resolve `test_multiple_optimizers_manual_return`
tchaton Oct 20, 2020
9434d11
release only training callback
tchaton Oct 20, 2020
4741258
update for flake8
tchaton Oct 20, 2020
77fae0e
resolve `test_multiple_optimizers_manual_return_and_log`
tchaton Oct 20, 2020
b47b390
resolve `test_multiple_optimizers_manual_return`
tchaton Oct 20, 2020
2a3c72d
Update pytorch_lightning/core/step_result.py
tchaton Oct 21, 2020
542d8d3
Update pytorch_lightning/core/step_result.py
tchaton Oct 21, 2020
290a160
Update pytorch_lightning/core/step_result.py
tchaton Oct 21, 2020
0dfe8c9
Update pytorch_lightning/core/step_result.py
tchaton Oct 21, 2020
104c9f5
remove mixin
tchaton Oct 21, 2020
abe57d3
Merge branch 'FEATURE/logging_in_train_callbacks' of https://github.c…
tchaton Oct 21, 2020
f4e2477
remove explicit mixin
tchaton Oct 21, 2020
974e742
add support for logging within val / test
tchaton Oct 22, 2020
1313c2e
release only training callback
tchaton Oct 20, 2020
bb6554b
update for flake8
tchaton Oct 20, 2020
a8ca9f4
resolve `test_multiple_optimizers_manual_return_and_log`
tchaton Oct 20, 2020
cc3b0be
resolve `test_multiple_optimizers_manual_return`
tchaton Oct 20, 2020
559705b
remove mixin
tchaton Oct 21, 2020
455cbe3
release only training callback
tchaton Oct 20, 2020
468f235
update for flake8
tchaton Oct 20, 2020
6b33625
Update pytorch_lightning/core/step_result.py
tchaton Oct 21, 2020
bea52b9
Update pytorch_lightning/core/step_result.py
tchaton Oct 21, 2020
442f8fc
Update pytorch_lightning/core/step_result.py
tchaton Oct 21, 2020
b08497d
Update pytorch_lightning/core/step_result.py
tchaton Oct 21, 2020
8f125b5
remove explicit mixin
tchaton Oct 21, 2020
9074473
add support for logging within val / test
tchaton Oct 22, 2020
f59d10d
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 22, 2020
bddd61d
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 22, 2020
0740f1e
resolve logging bug
tchaton Oct 22, 2020
82fc4fe
repair bug
tchaton Oct 22, 2020
075d5bf
resolve pep8
tchaton Oct 22, 2020
f71e588
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 22, 2020
36dbc96
resolve formatting bug
tchaton Oct 22, 2020
81ca911
Merge branch 'FEATURE/logging_in_train_callbacks' of https://github.c…
tchaton Oct 22, 2020
25242fb
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 22, 2020
1a6b172
check if metric and grad_norm_dic is defined
tchaton Oct 22, 2020
4690ff5
Merge branch 'FEATURE/logging_in_train_callbacks' of https://github.c…
tchaton Oct 22, 2020
c7c1e7d
resolve pep8
tchaton Oct 22, 2020
1dbe60c
resolve typo
tchaton Oct 22, 2020
54e2799
convert metris and grad_norm_dic to dict when None
tchaton Oct 22, 2020
8a8b54a
resolve pep8
tchaton Oct 22, 2020
3b19cb7
Merge branch 'master' into feat/logging_in_val_test_callbacks
tchaton Oct 22, 2020
2c9397d
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
tchaton Oct 22, 2020
04ef001
Merge branch 'feat/logging_in_val_test_callbacks' of https://github.c…
tchaton Oct 22, 2020
ff46b30
remove previous test
tchaton Oct 23, 2020
2d4bda1
Merge branch 'master' into feat/logging_in_val_test_callbacks
tchaton Oct 23, 2020
e1eef34
cleanup
tchaton Oct 23, 2020
82ec5e8
resolve flake8
tchaton Oct 23, 2020
de16fee
resolve bug
tchaton Oct 23, 2020
43da42a
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 23, 2020
1549735
add type
tchaton Oct 23, 2020
e1652bf
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 23, 2020
536f85f
Merge branch 'FEATURE/logging_in_train_callbacks' of https://github.c…
tchaton Oct 23, 2020
e8c2fb7
Merge branch 'master' into feat/logging_in_val_test_callbacks
tchaton Oct 23, 2020
192d543
remove wrong merge
tchaton Oct 23, 2020
5756008
remove `debug_epoch` is `dev_debugger` is enabled
tchaton Oct 23, 2020
9118d6b
try to find bug
tchaton Oct 23, 2020
cff9644
update main eval_loop_results
tchaton Oct 23, 2020
a50d873
add ChainMap
tchaton Oct 23, 2020
c0ef818
try out
tchaton Oct 23, 2020
f7f514a
resolve pep8
tchaton Oct 23, 2020
740a107
reduce lenght
tchaton Oct 23, 2020
f8e1d49
update message error
tchaton Oct 23, 2020
18247fa
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 23, 2020
0f28e03
Merge branch 'master' into feat/logging_in_val_test_callbacks
tchaton Oct 23, 2020
d0e1e81
Merge branch 'master' into feat/logging_in_val_test_callbacks
SeanNaren Oct 23, 2020
1058e5e
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 23, 2020
6504731
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 24, 2020
d8678ff
Merge branch 'master' into feat/logging_in_val_test_callbacks
tchaton Oct 24, 2020
e90f2c0
move files ar
tchaton Oct 25, 2020
292af7d
create connector_logger_utils
tchaton Oct 25, 2020
abdbd9f
resolve flake8
tchaton Oct 25, 2020
8d1c924
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 25, 2020
39e0bdf
Merge branch 'FEATURE/logging_in_train_callbacks' into feat/logging_i…
tchaton Oct 25, 2020
54b9a93
Merge branch 'FEATURE/logging_in_train_callbacks' into feat/logging_i…
tchaton Oct 25, 2020
77b88e5
add dataloder_idx to meta
tchaton Oct 25, 2020
e327fc6
Merge branch 'master' into feat/logging_in_val_test_callbacks_2
tchaton Oct 25, 2020
33 changes: 24 additions & 9 deletions pytorch_lightning/core/lightning.py
@@ -11,33 +11,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import collections
import copy
import inspect
import os
import re
import tempfile
from abc import ABC
from argparse import Namespace
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.core.grads import GradInformation
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities import rank_zero_warn, AMPType
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities.parsing import (
AttributeDict,
collect_init_args,
get_init_args,
)
from pytorch_lightning.callbacks import Callback
from torch import ScriptModule, Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
@@ -111,6 +111,8 @@ def __init__(self, *args, **kwargs):
self._datamodule = None
self._results: Optional[Result] = None
self._current_fx_name = ''
self._current_hook_fx_name = ''
self._current_dataloader_idx = None

def optimizers(self):
opts = self.trainer.optimizers
@@ -244,6 +246,17 @@ def log(
on_step = self.__auto_choose_log_on_step(on_step)
on_epoch = self.__auto_choose_log_on_epoch(on_epoch)

if self._current_hook_fx_name != '':
self.trainer.logger_connector.callback_logging_validator\
.validate_callback_logging_arguments(self._current_hook_fx_name,
on_step=on_step,
on_epoch=on_epoch)

# make sure user doesn't introduce logic for multi-dataloaders
if "/dataloader_idx_" in name:
raise MisconfigurationException(
f"Logged key: {name} should not contain information about dataloader_idx.")

self._results.log(
name,
value,
@@ -257,7 +270,8 @@ def log(
enable_graph,
sync_dist,
sync_dist_op,
sync_dist_group
sync_dist_group,
self._current_dataloader_idx,
)

def log_dict(
@@ -950,7 +964,8 @@ def configure_optimizers(
- Single optimizer.
- List or Tuple - List of optimizers.
- Two lists - The first list has multiple optimizers, the second a list of LR schedulers (or lr_dict).
- Dictionary, with an 'optimizer' key, and (optionally) a 'lr_scheduler' key which value is a single LR scheduler or lr_dict.
- Dictionary, with an 'optimizer' key, and (optionally) a 'lr_scheduler' key which value is a single LR
scheduler or lr_dict.
- Tuple of dictionaries as described, with an optional 'frequency' key.
- None - Fit will run without any optimizer.

@@ -1278,11 +1293,11 @@ def tbptt_split_batch(self, batch, split_size):
batch_split = []
for i, x in enumerate(batch):
if isinstance(x, torch.Tensor):
split_x = x[:, t : t + split_size]
split_x = x[:, t: t + split_size]
elif isinstance(x, collections.Sequence):
split_x = [None] * len(x)
for batch_idx in range(len(x)):
split_x[batch_idx] = x[batch_idx][t : t + split_size]
split_x[batch_idx] = x[batch_idx][t: t + split_size]

batch_split.append(split_x)

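For context, a minimal sketch of what the `lightning.py` changes above enable on the user side: logging from a validation/test callback hook through `pl_module.log`, which now validates the calling hook and the logged key. The callback class and metric name below are illustrative assumptions, not part of this diff.

```python
import torch
from pytorch_lightning.callbacks import Callback


class ValLoggingCallback(Callback):
    """Hypothetical callback (sketch only): logs a metric from a validation hook."""

    def on_validation_epoch_end(self, trainer, pl_module):
        # The key must NOT already contain "/dataloader_idx_": with multiple val/test
        # dataloaders the suffix is appended automatically from _current_dataloader_idx.
        pl_module.log("val_callback_metric", torch.tensor(1.0), on_epoch=True)
```

Passing a key such as `"metric/dataloader_idx_0"` would raise the `MisconfigurationException` added above, since the trainer now appends that suffix itself.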
63 changes: 44 additions & 19 deletions pytorch_lightning/core/step_result.py
@@ -124,6 +124,7 @@ def log(
sync_dist: bool = False,
sync_dist_op: Union[Any, str] = 'mean',
sync_dist_group: Optional[Any] = None,
dataloader_idx: Optional[int] = None,
):
# no metrics should be logged with graphs
if not enable_graph and isinstance(value, torch.Tensor):
@@ -144,6 +145,7 @@

# set step version
step_name = f'{name}_step'

self.__set_meta(
step_name,
value,
@@ -154,12 +156,15 @@
reduce_fx=reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
forked=False
forked=False,
dataloader_idx=dataloader_idx,
)

self.__setitem__(step_name, value)

# set epoch version
epoch_name = f'{name}_epoch'

self.__set_meta(
epoch_name,
value,
@@ -170,7 +175,8 @@
reduce_fx=reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
forked=False
forked=False,
dataloader_idx=dataloader_idx,
)
self.__setitem__(epoch_name, value)

@@ -185,7 +191,8 @@
reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
forked=was_forked
forked=was_forked,
dataloader_idx=dataloader_idx,
)

# set the value
@@ -202,7 +209,8 @@ def __set_meta(
reduce_fx: Callable,
tbptt_pad_token: int,
tbptt_reduce_fx: Callable,
forked: bool
forked: bool,
dataloader_idx: Union[int, None]
):
# set the meta for the item
meta_value = value
@@ -215,7 +223,8 @@
value=meta_value,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
forked=forked
forked=forked,
dataloader_idx=dataloader_idx,
)

self['meta'][name] = meta
@@ -242,7 +251,13 @@ def get_callback_metrics(self) -> dict:

return result

def get_batch_log_metrics(self, include_forked_originals=True) -> dict:
def _add_dataloader_idx(self, k: str, dataloader_idx: Union[int, None], add_dataloader_idx: bool) -> str:
if dataloader_idx is not None and add_dataloader_idx:
return f"{k}/dataloader_idx_{dataloader_idx}"
else:
return k

def get_batch_log_metrics(self, include_forked_originals=True, add_dataloader_idx=False) -> dict:
"""
Gets the metrics to log at the end of the batch step

@@ -257,15 +272,17 @@ def get_batch_log_metrics(self, include_forked_originals=True) -> dict:
if options['forked'] and not include_forked_originals:
continue

dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)

if options['logger'] and options['on_step']:
if isinstance(self[k], Metric):
result[k] = self[k]._forward_cache
result[dl_key] = self[k]._forward_cache
else:
result[k] = self[k]
result[dl_key] = self[k]

return result

def get_epoch_log_metrics(self) -> dict:
def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict:
"""
Gets the metrics to log at the end of epoch
"""
@@ -279,19 +296,21 @@ def get_epoch_log_metrics(self) -> dict:
if options['forked']:
continue

dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)

if options['logger'] and options['on_epoch']:
if isinstance(self[k], Metric):
result[k] = self[k].compute()
result[dl_key] = self[k].compute()
else:
result[k] = self[k]
result[dl_key] = self[k]

if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
# compute metric on epoch anyway so state does not accumulate
self[k].compute()

return result

def get_epoch_pbar_metrics(self):
def get_epoch_pbar_metrics(self, add_dataloader_idx=False):
"""
Gets the metrics to log at the end of epoch
"""
@@ -305,19 +324,21 @@ def get_epoch_pbar_metrics(self):
if options['forked']:
continue

dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)

if options['prog_bar'] and options['on_epoch']:
if isinstance(self[k], Metric):
result[k] = self[k].compute()
result[dl_key] = self[k].compute()
else:
result[k] = self[k]
result[dl_key] = self[k]

if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
# compute metric on epoch anyway so state does not accumulate
self[k].compute()

return result

def get_forked_metrics(self):
def get_forked_metrics(self, add_dataloader_idx=False):
"""
Gets the metrics to log at the end of epoch
"""
@@ -328,12 +349,14 @@ def get_forked_metrics(self):
if k == '_internal':
continue

dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)

if options['forked']:
result[k] = self[k]
result[dl_key] = self[k]

return result

def get_batch_pbar_metrics(self, include_forked_originals=True):
def get_batch_pbar_metrics(self, include_forked_originals=True, add_dataloader_idx=False):
"""
Gets the metrics to log at the end of the batch step
"""
@@ -347,11 +370,13 @@ def get_batch_pbar_metrics(self, include_forked_originals=True):
if options['forked'] and not include_forked_originals:
continue

dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)

if options['prog_bar'] and options['on_step']:
if isinstance(self[k], Metric):
result[k] = self[k]._forward_cache
result[dl_key] = self[k]._forward_cache
else:
result[k] = self[k]
result[dl_key] = self[k]

return result

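The `dataloader_idx` plumbing added to `step_result.py` above boils down to suffixing metric keys per dataloader. A standalone re-creation of `_add_dataloader_idx` (a sketch mirroring the method above, not the class itself) shows the resulting key format:

```python
from typing import Optional


def add_dataloader_idx(key: str, dataloader_idx: Optional[int], add_dataloader_idx: bool) -> str:
    """Mirror of Result._add_dataloader_idx: append the dataloader index to a metric key."""
    if dataloader_idx is not None and add_dataloader_idx:
        return f"{key}/dataloader_idx_{dataloader_idx}"
    return key


# With several val/test dataloaders, "val_loss" logged from dataloader 1 becomes:
assert add_dataloader_idx("val_loss", 1, True) == "val_loss/dataloader_idx_1"
# With a single dataloader, or when suffixing is disabled, the key is unchanged:
assert add_dataloader_idx("val_loss", None, True) == "val_loss"
```

This is also why `LightningModule.log` now rejects user keys that already contain `/dataloader_idx_`.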
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from abc import ABC
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -61,6 +62,8 @@ def init_default_checkpoint_callback(self, checkpoint_callback):
checkpoint_callback = ModelCheckpoint(dirpath=None, filename=None)
elif checkpoint_callback is False:
checkpoint_callback = None
if checkpoint_callback:
checkpoint_callback.save_function = self.trainer.save_checkpoint

return checkpoint_callback

@@ -81,5 +84,4 @@ def configure_progress_bar(self, refresh_rate=1, process_position=0):
self.trainer.callbacks.append(progress_bar_callback)
else:
progress_bar_callback = None

return progress_bar_callback
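The checkpoint hunk above attaches the trainer's `save_checkpoint` to the freshly created default `ModelCheckpoint`, so the callback can delegate file writing back to the trainer. A toy, self-contained sketch of that injection pattern (the class names here are made up; only the `save_function` assignment mirrors the diff):

```python
class TinyCheckpoint:
    """Toy stand-in for a checkpoint callback; illustrates the save_function injection only."""

    def __init__(self):
        self.save_function = None  # injected by the trainer/connector

    def save(self, filepath):
        if self.save_function is None:
            raise RuntimeError("save_function was never attached by the trainer")
        self.save_function(filepath)


class TinyTrainer:
    def save_checkpoint(self, filepath):
        print(f"writing checkpoint to {filepath}")


trainer = TinyTrainer()
ckpt = TinyCheckpoint()
ckpt.save_function = trainer.save_checkpoint  # the line added in the hunk above, in miniature
ckpt.save("epoch=0.ckpt")
```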