Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Logging refactor 2/n - train #4495

Merged
merged 53 commits into from
Nov 5, 2020
Merged
Show file tree
Hide file tree
Changes from 47 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
2b3c4bc
update logging
tchaton Nov 3, 2020
e2814ad
Merge branch 'master' into feat/train_logging
tchaton Nov 3, 2020
f487a5d
Merge branch 'master' into feat/train_logging
tchaton Nov 3, 2020
ba0427f
solve more bugs
tchaton Nov 3, 2020
c68995a
Merge branch 'feat/train_logging' of https://github.com/PyTorchLightn…
tchaton Nov 3, 2020
8337394
replace Mapping by Dict
tchaton Nov 3, 2020
3862ef7
update on comments
tchaton Nov 3, 2020
23a62ac
resolve pep8
tchaton Nov 3, 2020
ebf6573
Merge branch 'master' into feat/train_logging
tchaton Nov 3, 2020
3921725
Apply suggestions from code review
Borda Nov 3, 2020
09ace23
Merge branch 'master' into feat/train_logging
tchaton Nov 4, 2020
c9308d4
Merge branch 'feat/train_logging' of https://github.com/PyTorchLightn…
tchaton Nov 4, 2020
e459131
Update pytorch_lightning/trainer/connectors/logger_connector/epoch_re…
tchaton Nov 4, 2020
abd0fc0
Merge branch 'master' into feat/train_logging
tchaton Nov 4, 2020
a8371bf
update on comments
tchaton Nov 4, 2020
92994d9
typo
tchaton Nov 4, 2020
3539faa
Merge branch 'master' into feat/train_logging
tchaton Nov 4, 2020
f3b4f1f
update for coverage
tchaton Nov 4, 2020
453abed
update test
tchaton Nov 4, 2020
fb72bff
update
tchaton Nov 4, 2020
93c596d
Merge branch 'feat/train_logging' of https://github.com/PyTorchLightn…
tchaton Nov 4, 2020
0decd22
Update tests/models/test_hooks.py
tchaton Nov 4, 2020
005e91b
Update tests/models/test_hooks.py
tchaton Nov 4, 2020
21da81f
Merge branch 'master' into feat/train_logging
tchaton Nov 4, 2020
8f879db
update on comments
tchaton Nov 4, 2020
1983cc1
Merge branch 'master' into feat/train_logging
tchaton Nov 4, 2020
7b4e9e0
Merge branch 'master' into feat/train_logging
tchaton Nov 5, 2020
0e41cad
remove deepcopy
tchaton Nov 5, 2020
fcf74e5
Merge branch 'feat/train_logging' of https://github.com/PyTorchLightn…
tchaton Nov 5, 2020
25692c8
remove useless look for
tchaton Nov 5, 2020
2859f5c
another small optim
tchaton Nov 5, 2020
f0a13bb
extra optim
tchaton Nov 5, 2020
5535b0a
remove lastest optim, can be source of bug
tchaton Nov 5, 2020
ae0c00f
resolve bug
tchaton Nov 5, 2020
3e6fc63
add docstring
tchaton Nov 5, 2020
43f5c45
optimize coverage
tchaton Nov 5, 2020
aa393c3
Update pytorch_lightning/trainer/connectors/logger_connector/epoch_re…
tchaton Nov 5, 2020
bc62cff
Update pytorch_lightning/trainer/connectors/logger_connector/epoch_re…
tchaton Nov 5, 2020
85317ad
Update tests/trainer/logging_tests/test_distributed_logging.py
tchaton Nov 5, 2020
d492d94
Update pytorch_lightning/trainer/evaluation_loop.py
tchaton Nov 5, 2020
caea74c
Update tests/trainer/logging/test_logger_connector.py
tchaton Nov 5, 2020
5bc3847
Update tests/trainer/logging_tests/test_train_loop_logging_1_0.py
tchaton Nov 5, 2020
66a89a8
Merge branch 'master' into feat/train_logging
tchaton Nov 5, 2020
6c7373a
update on comments
tchaton Nov 5, 2020
60f95a8
Merge branch 'feat/train_logging' of https://github.com/PyTorchLightn…
tchaton Nov 5, 2020
ef2065c
update
tchaton Nov 5, 2020
6a9bcc5
update on comments
tchaton Nov 5, 2020
e22d9f8
update parity speed
tchaton Nov 5, 2020
395df7f
get it down to 0.65
tchaton Nov 5, 2020
ae64091
update
tchaton Nov 5, 2020
59ca975
Merge branch 'master' into feat/train_logging
williamFalcon Nov 5, 2020
4c19f96
0.8 max_dif
tchaton Nov 5, 2020
c86643b
Merge branch 'feat/train_logging' of https://github.com/PyTorchLightn…
tchaton Nov 5, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def _on_validation_start_log():
@staticmethod
def _on_validation_end_log():
"""Called when the validation loop ends."""
return {"on_step": [False], "on_epoch": [False, True]}
return None

@staticmethod
def _on_test_start_log():
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -44,25 +44,14 @@ def __init__(self, trainer):
self._callback_hook_validator = CallbackHookNameValidator()
self._current_stage = None

def cached_results(self, stage_or_testing: Union[str, bool]) -> Union[EpochResultStore, None]:
""" Function to access cached_results using str or bool. Bool is used only for testing"""
stage_or_testing = str(stage_or_testing)
stages = self._stages
if stage_or_testing in self._stages:
return self._cached_results[stage_or_testing]
if stage_or_testing in LOOKUP_TABLE:
# Access using trainer.testing
stage = LOOKUP_TABLE[stage_or_testing]
return self._cached_results[stage]
raise MisconfigurationException(
f"Provide stage_or_testing {stage_or_testing} doesn't belong either to {self._stages}"
f" or {LOOKUP_TABLE.keys()}"
)
@property
def cached_results(self) -> Union[EpochResultStore, None]:
return self._cached_results[self._current_stage]

def set_stage(self, stage_or_testing: str, reset:bool = False) -> None:
self._current_stage = self._determine_stage(stage_or_testing)
if reset:
self.cached_results(stage_or_testing).reset()
self.cached_results.reset()

def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoch: bool = None) -> None:
self._callback_hook_validator.check_logging_in_callbacks(current_hook_fx_name=hook_fx_name,
Expand All @@ -75,17 +64,17 @@ def on_evaluation_batch_start(self, testing, batch, dataloader_idx, num_dataload
model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None

# track batch_size
self.cached_results(testing)._batch_size = Result.extract_batch_size(batch)
self.cached_results._batch_size = Result.extract_batch_size(batch)

def on_batch_start(self, split_idx: int, opt_idx: int, split_batch) -> None:
self._cached_results["train"]._split_idx = split_idx
self._cached_results["train"]._opt_idx = opt_idx
self._cached_results["train"]._batch_size = Result.extract_batch_size(split_batch)
def on_train_split_start(self, split_idx: int, opt_idx: int, split_batch) -> None:
self.cached_results._split_idx = split_idx
self.cached_results._opt_idx = opt_idx
self.cached_results._batch_size = Result.extract_batch_size(split_batch)

def on_train_batch_end(self) -> None:
self._cached_results["train"]._split_idx = None
self._cached_results["train"]._opt_idx = None
self._cached_results["train"]._batch_size = None
self.cached_results._split_idx = None
self.cached_results._opt_idx = None
self.cached_results._batch_size = None

def _determine_stage(self, stage_or_testing: Union[str, bool]) -> str:
stage_or_testing = str(stage_or_testing)
Expand All @@ -112,6 +101,16 @@ def on_trainer_init(self, logger, flush_logs_every_n_steps, log_every_n_steps):
self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
self.trainer.log_every_n_steps = log_every_n_steps

@property
def should_flush_logs(self):
should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0
return should_flush or self.trainer.should_stop

@property
def should_update_logs(self):
should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
return should_log_every_n_steps or self.trainer.should_stop

def configure_logger(self, logger):
if logger is True:
version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id)
Expand All @@ -130,6 +129,53 @@ def configure_logger(self, logger):
else:
self.trainer.logger = logger

def cache_training_step_metrics(self, opt_closure_result):
"""
This function is responsible for updating the
logger_connector's internal metrics holders for deprecated logging
"""
using_results_obj = isinstance(opt_closure_result.training_step_output, Result)

# temporary dict to collect metrics
logged_metrics_tmp = {}
pbar_metrics_tmp = {}
callback_metrics_tmp = {}

if using_results_obj:
batch_log_metrics = opt_closure_result.training_step_output.get_batch_log_metrics(
include_forked_originals=False
)
logged_metrics_tmp.update(batch_log_metrics)
tchaton marked this conversation as resolved.
Show resolved Hide resolved

batch_pbar_metrics = opt_closure_result.training_step_output.get_batch_pbar_metrics(
include_forked_originals=False
)
pbar_metrics_tmp.update(batch_pbar_metrics)

forked_metrics = opt_closure_result.training_step_output.get_forked_metrics()
callback_metrics_tmp.update(forked_metrics)
callback_metrics_tmp.update(logged_metrics_tmp)

else:
batch_log_metrics = opt_closure_result.training_step_output.log_metrics
logged_metrics_tmp.update(batch_log_metrics)

callback_metrics = opt_closure_result.training_step_output.callback_metrics
callback_metrics_tmp.update(callback_metrics)

batch_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end
pbar_metrics_tmp.update(batch_pbar_metrics)

# track progress bar metrics
if len(pbar_metrics_tmp) > 0:
self.add_progress_bar_metrics(pbar_metrics_tmp)

self.callback_metrics.update(callback_metrics_tmp)

# save legacy log metrics
self.logged_metrics.update(logged_metrics_tmp)
self.cached_results.legacy_batch_log_metrics.update(logged_metrics_tmp)

def log_metrics(self, metrics, grad_norm_dic, step=None):
"""Logs the metric dict passed in.
If `step` parameter is None and `step` key is present in metrics,
Expand Down Expand Up @@ -396,8 +442,9 @@ def __process_eval_epoch_end_results_and_log_legacy(self, eval_results, test_mod
if num_loaders == 1:
self.__process_eval_epoch_end_results_and_log_legacy_update(prog_bar_metrics, log_metrics, callback_metrics)

def on_train_epoch_end(self, epoch_output):
pass
def on_train_epoch_end(self):
# inform cached logger connector epoch finished
self.cached_results.has_batch_loop_finished = True

def log_train_epoch_end_metrics(self,
epoch_output,
Expand Down Expand Up @@ -441,12 +488,10 @@ def log_train_epoch_end_metrics(self,
# ------------------
if is_1_0_result:
# lightning module hook
epoch_end_log_result = self.training_epoch_end(model, epoch_output, num_optimizers)
self.training_epoch_end(model, epoch_output, num_optimizers)

# log/aggregate metrics automatically
epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output)
epoch_log_metrics.update(epoch_end_log_result.get_epoch_log_metrics())
epoch_progress_bar_metrics.update(epoch_end_log_result.get_epoch_pbar_metrics())

# TODO: deprecate 1.0
else:
Expand All @@ -459,6 +504,14 @@ def log_train_epoch_end_metrics(self,
)
epoch_log_metrics, epoch_progress_bar_metrics, epoch_callback_metrics = out

# it will perform reduction over epoch and return log metrics
cached_epoch_log_metrics = self.cached_results.get_epoch_log_metrics()
cached_epoch_pbar_metrics = self.cached_results.get_epoch_pbar_metrics()

# update
epoch_log_metrics.update(cached_epoch_log_metrics)
epoch_progress_bar_metrics.update(cached_epoch_pbar_metrics)

# --------------------------
# track results
# --------------------------
Expand All @@ -475,15 +528,16 @@ def log_train_epoch_end_metrics(self,
self.add_progress_bar_metrics(epoch_progress_bar_metrics)
self.callback_metrics.update(epoch_progress_bar_metrics)

# reset epoch loop result for next epoch
self.cached_results.reset()

def training_epoch_end(self, model, epoch_output, num_optimizers):
if not is_overridden('training_epoch_end', model=model):
return Result()
return

# run training_epoch_end
# refresh the result for custom logging at the epoch level
model._current_fx_name = 'training_epoch_end'
model._results = Result()

epoch_output = self.__prepare_epoch_end_inputs(epoch_output)

if num_optimizers == 1 or not self.trainer.train_loop.automatic_optimization:
Expand All @@ -492,15 +546,11 @@ def training_epoch_end(self, model, epoch_output, num_optimizers):
# lightningmodule hook
epoch_output = model.training_epoch_end(epoch_output)

model._current_fx_name = ''

if epoch_output is not None:
raise MisconfigurationException('training_epoch_end expects a return of None. '
'HINT: remove the return statement in training_epoch_end')

# user can ALSO log at the end of an epoch
new_epoch_end_logs = model._results
return new_epoch_end_logs
# capture logging
self.trainer.logger_connector.cache_logged_metrics()

def __run_legacy_training_epoch_end(
self,
Expand All @@ -527,8 +577,12 @@ def __run_legacy_training_epoch_end(

# run training_epoch_end
# a list with a result per optimizer index
model._current_fx_name = 'training_epoch_end'
epoch_output = model.training_epoch_end(epoch_output)

# capture logging
self.trainer.logger_connector.cache_logged_metrics()

if isinstance(epoch_output, Result):
epoch_log_metrics = epoch_output.epoch_log_metrics
epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics
Expand Down Expand Up @@ -563,7 +617,7 @@ def __auto_reduce_results_on_epoch_end(self, epoch_output):
# reduce across training steps
opt_outputs = time_reduced_outputs[0].__class__.reduce_on_epoch_end(time_reduced_outputs)

# with manual opt need 1+ metrics because meta is always there
# with manual opt need 1 + metrics because meta is always there
if opt_outputs.minimize is not None:
opt_outputs.minimize = opt_outputs.minimize.mean()
epoch_log_metrics.update(opt_outputs.epoch_log_metrics)
Expand Down Expand Up @@ -623,12 +677,9 @@ def __gather_result_across_time_and_optimizers(self, epoch_output):

def log_train_step_metrics(self, batch_output):
# when metrics should be logged
should_log_metrics = (
(self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 or self.trainer.should_stop
)
if should_log_metrics or self.trainer.fast_dev_run:
if self.should_update_logs or self.trainer.fast_dev_run:
# logs user requested information to logger
metrics = batch_output.batch_log_metrics
metrics = self.cached_results.get_latest_batch_log_metrics()
grad_norm_dic = batch_output.grad_norm_dic
if len(metrics) > 0 or len(grad_norm_dic) > 0:
self.log_metrics(metrics, grad_norm_dic)
Expand Down
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,9 @@ def __log_result_step_metrics(self, output, batch_idx):
step_log_metrics = output.get_batch_log_metrics(include_forked_originals=False)
step_pbar_metrics = output.get_batch_pbar_metrics(include_forked_originals=False)

cached_batch_log_metrics = \
self.trainer.logger_connector.cached_results.get_latest_batch_log_metrics()

if len(step_log_metrics) > 0:
# make the metrics appear as a different line in the same graph
metrics_by_epoch = {}
Expand Down
24 changes: 23 additions & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,25 @@ def call_setup_hook(self, model):
self.setup(stage_name)
model.setup(stage_name)

def _reset_result_and_set_hook_fx_name(self, hook_name):
model_ref = self.get_model()
if model_ref is not None:
# used to track current hook name called
model_ref._results = Result()
tchaton marked this conversation as resolved.
Show resolved Hide resolved
model_ref._current_hook_fx_name = hook_name

def _cache_logged_metrics(self):
model_ref = self.get_model()
if model_ref is not None:
# capture logging for this hook
self.logger_connector.cache_logged_metrics()

def call_hook(self, hook_name, *args, **kwargs):
# temporary. Don't modify evaluation behaviour
if self.logger_connector._current_stage == "train":
tchaton marked this conversation as resolved.
Show resolved Hide resolved
# set hook_name to model + reset Result obj
self._reset_result_and_set_hook_fx_name(hook_name)

# always profile hooks
with self.profiler.profile(hook_name):

Expand All @@ -860,4 +878,8 @@ def call_hook(self, hook_name, *args, **kwargs):
accelerator_hook = getattr(self.accelerator_backend, hook_name)
output = accelerator_hook(*args, **kwargs)

return output
# temporary. Don't modify evaluation behaviour
if self.logger_connector._current_stage == "train":
# capture logging
self._cache_logged_metrics()
return output
Loading