From 63d7535ffb4a20e4d577cd0718234183d898241e Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 6 Nov 2020 09:50:55 +0000 Subject: [PATCH 01/15] wip --- .../logger_connector/logger_connector.py | 3 +- pytorch_lightning/trainer/evaluation_loop.py | 34 +- pytorch_lightning/trainer/trainer.py | 55 +-- .../test_eval_loop_logging_1_0.py | 360 ++++++++++++++++++ 4 files changed, 402 insertions(+), 50 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 946064660f818..d2507ff618b36 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -59,10 +59,9 @@ def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoc on_epoch=on_epoch) def on_evaluation_batch_start(self, testing, batch, dataloader_idx, num_dataloaders): - # reset the result of the PL module model = self.trainer.get_model() + # set dataloader_idx only if multiple ones model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None - # track batch_size self.cached_results._batch_size = Result.extract_batch_size(batch) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 6ebab1ade0f1d..948714b795bdf 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -169,12 +169,17 @@ def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx): # configure args args = self.build_args(test_mode, batch, batch_idx, dataloader_idx) + model_ref = self.trainer.get_model() # run actual test step if self.testing: + model_ref._current_fx_name = "test_step" output = self.trainer.accelerator_backend.test_step(args) else: + model_ref._current_fx_name = "validation_step" output = self.trainer.accelerator_backend.validation_step(args) + # capture any logged information + self.trainer.logger_connector.cache_logged_metrics() # track batch size for weighted average is_result_obj = isinstance(output, Result) if is_result_obj: @@ -236,22 +241,22 @@ def __run_eval_epoch_end(self, num_dataloaders, using_eval_result): if self.testing: if is_overridden('test_epoch_end', model=model): - model._current_fx_name = 'test_epoch_end' if using_eval_result: eval_results = self.__gather_epoch_end_eval_results(outputs) - + model._current_fx_name = 'test_epoch_end' eval_results = model.test_epoch_end(eval_results) user_reduced = True else: if is_overridden('validation_epoch_end', model=model): - model._current_fx_name = 'validation_epoch_end' if using_eval_result: eval_results = self.__gather_epoch_end_eval_results(outputs) - + model._current_fx_name = 'validation_epoch_end' eval_results = model.validation_epoch_end(eval_results) user_reduced = True + # capture logging + self.trainer.logger_connector.cache_logged_metrics() # depre warning if eval_results is not None and user_reduced: step = 'testing_epoch_end' if self.testing else 'validation_epoch_end' @@ -266,6 +271,9 @@ def __run_eval_epoch_end(self, num_dataloaders, using_eval_result): if not isinstance(eval_results, list): eval_results = [eval_results] + # support depreceated tracking metrics + self.trainer.logger_connector.track_metrics_deprecated(eval_results, using_eval_result, self.testing) + return eval_results def __gather_epoch_end_eval_results(self, outputs): @@ -299,12 +307,7 @@ def __auto_reduce_result_objs(self, outputs): 
return eval_results def on_evaluation_batch_start(self, batch, batch_idx, dataloader_idx): - # reset the result of the PL module - model = self.trainer.get_model() - model._results = Result() - model._current_fx_name = 'evaluation_step' - - # set dataloader_idx and track batch_size + # set dataloader_idx to model and track batch_size self.trainer.logger_connector.on_evaluation_batch_start( self.testing, batch, dataloader_idx, self.num_dataloaders) @@ -313,13 +316,16 @@ def on_evaluation_batch_start(self, batch, batch_idx, dataloader_idx): else: self.trainer.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx) - def on_evaluation_batch_end(self, *args, **kwargs): + def on_evaluation_batch_end(self, output, batch, batch_idx, dataloader_idx): if self.testing: - self.trainer.call_hook('on_test_batch_end', *args, **kwargs) + self.trainer.call_hook('on_test_batch_end', output, batch, batch_idx, dataloader_idx) else: - self.trainer.call_hook('on_validation_batch_end', *args, **kwargs) + self.trainer.call_hook('on_validation_batch_end', output, batch, batch_idx, dataloader_idx) + + # store predicitons if do_write_predictions and track eval loss history + self.store_predictions(output, batch_idx, dataloader_idx) - def evaluation_batch_end_cleanup(self, output, batch_idx, dataloader_idx): + def store_predictions(self, output, batch_idx, dataloader_idx): # Add step predictions to prediction collection to write later if output is not None: do_write_predictions = isinstance(output, Result) and self.testing diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2d4e2c0d9e4bd..9c8e52b38e1ca 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -237,6 +237,8 @@ def __init__( num_nodes: number of GPU nodes for distributed training. + num_processes: number of processes for distributed training with distributed_backend="ddp_cpu" + num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine. Set it to `-1` to run all batches in all validation dataloaders. Default: 2 @@ -565,14 +567,12 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None): self.evaluation_loop.setup(model, max_batches, dataloaders) # hook - # TODO: should this be insider the dataloader loop? 
self.evaluation_loop.on_evaluation_epoch_start() # run validation/testing for dataloader_idx, dataloader in enumerate(dataloaders): # bookkeeping dl_outputs = [] - dl_step_metrics = [] dataloader = self.accelerator_backend.process_dataloader(dataloader) dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx] @@ -591,47 +591,38 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None): output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx) output = self.evaluation_loop.evaluation_step_end(output) - # hook + # hook + store predictions self.evaluation_loop.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx) - # clean up - self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx) - - # TODO: deprecate 1.0 - self.evaluation_loop.log_evaluation_step_metrics_legacy(output, batch_idx) - - # log step metrics - step_metrics = self.evaluation_loop.log_evaluation_step_metrics(batch, batch_idx) - - if step_metrics is not None: - dl_step_metrics.append(step_metrics) + # log batch metrics + self.evaluation_loop.log_evaluation_step_metrics(output, batch_idx) # track epoch level outputs if output is not None: dl_outputs.append(output) + # store batch level output per dataloader self.evaluation_loop.outputs.append(dl_outputs) - self.evaluation_loop.step_metrics.append(dl_step_metrics) - # lightning module method - deprecated_eval_results, epoch_logs = self.evaluation_loop.evaluation_epoch_end( - num_dataloaders=len(dataloaders) - ) - - # bookkeeping - eval_loop_results = self.evaluation_loop.log_epoch_metrics(deprecated_eval_results, epoch_logs, test_mode) - self.evaluation_loop.predictions.to_disk() + # lightning module method + inform logger batch loop finished + deprecated_eval_results = self.evaluation_loop.evaluation_epoch_end() # hook self.evaluation_loop.on_evaluation_epoch_end() + # hook + self.evaluation_loop.on_evaluation_end() + + # bookkeeping + eval_loop_results = self.evaluation_loop.log_epoch_metrics_on_evaluation_end() + + # save predictions to disk + self.evaluation_loop.predictions.to_disk() + # enable train mode again self.evaluation_loop.on_evaluation_model_train() torch.set_grad_enabled(True) - # hook - self.evaluation_loop.on_evaluation_end() - return eval_loop_results, deprecated_eval_results def run_test(self): @@ -852,10 +843,8 @@ def _cache_logged_metrics(self): self.logger_connector.cache_logged_metrics() def call_hook(self, hook_name, *args, **kwargs): - # temporary. Don't modify evaluation behaviour - if self.logger_connector._current_stage == "train": - # set hook_name to model + reset Result obj - self._reset_result_and_set_hook_fx_name(hook_name) + # set hook_name to model + reset Result obj + self._reset_result_and_set_hook_fx_name(hook_name) # always profile hooks with self.profiler.profile(hook_name): @@ -878,8 +867,6 @@ def call_hook(self, hook_name, *args, **kwargs): accelerator_hook = getattr(self.accelerator_backend, hook_name) output = accelerator_hook(*args, **kwargs) - # temporary. 
Don't modify evaluation behaviour - if self.logger_connector._current_stage == "train": - # capture logging - self._cache_logged_metrics() + # capture logging + self._cache_logged_metrics() return output diff --git a/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py b/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py index 0f3217b3f004c..d9ae1a89d631a 100644 --- a/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py @@ -14,12 +14,15 @@ """ Tests to ensure that the training loop works with a dict (1.0) """ +import pytorch_lightning as pl from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning import Trainer from pytorch_lightning import callbacks, seed_everything from tests.base.deterministic_model import DeterministicModel from tests.base import SimpleModule, BoringModel, RandomDataset import os +import itertools +import collections import torch import pytest @@ -419,3 +422,360 @@ def test_dataloader(self): assert len(results) == 1 # error : It is wrong there. `y` should equal test_loss_epoch assert results[0]['test_loss'] == results[0]['y'] + + +def test_log_works_in_val_callback(tmpdir): + """ + Tests that log can be called within callback + """ + os.environ['PL_DEV_DEBUG'] = '1' + + class TestCallback(callbacks.Callback): + + # helpers + count = 1 + choices = [False, True] + # used to compute expected values + callback_funcs_called = collections.defaultdict(list) + funcs_called_count = collections.defaultdict(int) + funcs_attr = {} + + def make_logging(self, pl_module: pl.LightningModule, func_name, func_idx, on_steps=[], on_epochs=[], prob_bars=[]): + self.funcs_called_count[func_name] += 1 + for idx, (on_step, on_epoch, prog_bar) in enumerate(list(itertools.product(*[on_steps, on_epochs, prob_bars]))): + # run logging + custom_func_name = f"{func_idx}_{idx}_{func_name}" + pl_module.log(custom_func_name, self.count * func_idx, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar) + # catch information for verification + self.callback_funcs_called[func_name].append([self.count * func_idx]) + self.funcs_attr[custom_func_name] = { + "on_step": on_step, + "on_epoch": on_epoch, + "prog_bar": prog_bar, + "forked": on_step and on_epoch, + "func_name": func_name} + + if on_step and on_epoch: + self.funcs_attr[f"{custom_func_name}_step"] = { + "on_step": True, + "on_epoch": False, + "prog_bar": prog_bar, + "forked": False, + "func_name": func_name} + + self.funcs_attr[f"{custom_func_name}_epoch"] = { + "on_step": False, + "on_epoch": True, + "prog_bar": prog_bar, + "forked": False, + "func_name": func_name} + + def on_validation_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_validation_start', 1, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_epoch_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_epoch_start', 2, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_validation_epoch_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_validation_epoch_start', 3, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_batch_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_batch_start', 4, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + self.make_logging(pl_module, 'on_validation_batch_start', 5, 
on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_batch_end(self, trainer, pl_module): + self.make_logging(pl_module, 'on_batch_end', 6, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + self.make_logging(pl_module, 'on_validation_batch_end', 7, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + # used to make sure aggregation works fine. + # we should obtain func[value * c for c in range(1, max_epochs * limit_validation_batches)]) + # with func = np.mean if on_epoch else func = np.max + self.count += 1 + + def on_epoch_end(self, trainer, pl_module): + self.make_logging(pl_module, 'on_epoch_end', 8, on_steps=[False], + on_epochs=self.choices, prob_bars=self.choices) + + def on_validation_epoch_end(self, trainer, pl_module): + self.make_logging(pl_module, 'on_validation_epoch_end', 9, on_steps=[False], + on_epochs=self.choices, prob_bars=self.choices) + + class TestModel(BoringModel): + + def validation_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log('val_loss', loss) + + max_epochs = 1 + model = TestModel() + model.validation_epoch_end = None + test_callback = TestCallback() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=0, + limit_val_batches=4, + limit_test_batches=0, + val_check_interval=0., + num_sanity_val_steps=0, + max_epochs=max_epochs, + callbacks=[test_callback], + ) + trainer.fit(model) + trainer.test() + + assert test_callback.funcs_called_count["on_epoch_start"] == 1 + assert test_callback.funcs_called_count["on_batch_start"] == 1 + assert test_callback.funcs_called_count["on_batch_end"] == 1 + assert test_callback.funcs_called_count["on_validation_start"] == 1 + assert test_callback.funcs_called_count["on_validation_epoch_start"] == 1 + assert test_callback.funcs_called_count["on_validation_batch_start"] == 4 + assert test_callback.funcs_called_count["on_validation_batch_end"] == 4 + assert test_callback.funcs_called_count["on_validation_epoch_end"] == 1 + assert test_callback.funcs_called_count["on_epoch_end"] == 1 + + # Make sure the func_name exists within callback_metrics. 
If not, we missed some + callback_metrics_keys = [*trainer.callback_metrics.keys()] + for func_name in test_callback.callback_funcs_called.keys(): + is_in = False + for callback_metrics_key in callback_metrics_keys: + if func_name in callback_metrics_key: + is_in = True + assert is_in, (func_name, callback_metrics_keys) + + # function used to describe expected return logic + def get_expected_output(func_attr, original_values): + + if func_attr["on_epoch"] and not func_attr["on_step"]: + # Apply mean on values + expected_output = np.mean(original_values) + else: + # Keep the latest value + expected_output = np.max(original_values) + return expected_output + + # Make sure the func_name output equals the average from all logged values when on_epoch true + # pop extra keys + trainer.callback_metrics.pop("debug_epoch") + trainer.callback_metrics.pop("val_loss") + for func_name, output_value in trainer.callback_metrics.items(): + # not sure how to handle this now + if "epoch_0" in func_name: + func_name = '/'.join(func_name.split('/')[:-1]) + continue + + if torch.is_tensor(output_value): + output_value = output_value.item() + # get creation attr + func_attr = test_callback.funcs_attr[func_name] + + # retrived orginal logged values + original_values = test_callback.callback_funcs_called[func_attr["func_name"]] + + # compute expected output and compare to actual one + expected_output = get_expected_output(func_attr, original_values) + assert float(output_value) == float(expected_output) + + for func_name, func_attr in test_callback.funcs_attr.items(): + if func_attr["prog_bar"] and (func_attr["on_step"] or func_attr["on_epoch"]) and not func_attr["forked"]: + assert func_name in trainer.logger_connector.progress_bar_metrics + else: + assert func_name not in trainer.logger_connector.progress_bar_metrics + + +def test_log_works_in_test_callback(tmpdir): + """ + Tests that log can be called within callback + """ + os.environ['PL_DEV_DEBUG'] = '1' + + class TestCallback(callbacks.Callback): + + # helpers + count = 1 + choices = [False, True] + + # used to compute expected values + callback_funcs_called = collections.defaultdict(list) + funcs_called_count = collections.defaultdict(int) + funcs_attr = {} + + def make_logging(self, pl_module: pl.LightningModule, func_name, func_idx, on_steps=[], on_epochs=[], prob_bars=[]): + original_func_name = func_name[:] + self.funcs_called_count[original_func_name] += 1 + for idx, t in enumerate(list(itertools.product(*[on_steps, on_epochs, prob_bars]))): + # run logging + func_name = original_func_name[:] + on_step, on_epoch, prog_bar = t + custom_func_name = f"{func_idx}_{idx}_{func_name}" + + pl_module.log(custom_func_name, self.count * func_idx, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar) + + num_dl_ext = '' + if pl_module._current_dataloader_idx is not None: + dl_idx = pl_module._current_dataloader_idx + num_dl_ext = f"/dataloader_idx_{dl_idx}" + func_name += num_dl_ext + + # catch information for verification + self.callback_funcs_called[func_name].append([self.count * func_idx]) + self.funcs_attr[custom_func_name + num_dl_ext] = { + "on_step": on_step, + "on_epoch": on_epoch, + "prog_bar": prog_bar, + "forked": on_step and on_epoch, + "func_name": func_name} + if on_step and on_epoch: + self.funcs_attr[f"{custom_func_name}_step" + num_dl_ext] = { + "on_step": True, + "on_epoch": False, + "prog_bar": prog_bar, + "forked": False, + "func_name": func_name} + + self.funcs_attr[f"{custom_func_name}_epoch" + num_dl_ext] = { + "on_step": False, + 
"on_epoch": True, + "prog_bar": prog_bar, + "forked": False, + "func_name": func_name} + + def on_test_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_test_start', 1, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_epoch_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_epoch_start', 2, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_test_epoch_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_test_epoch_start', 3, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + self.make_logging(pl_module, 'on_test_batch_start', 4, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + self.make_logging(pl_module, 'on_test_batch_end', 5, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + # used to make sure aggregation works fine. + # we should obtain func[value * c for c in range(1, max_epochs * limit_test_batches)]) + # with func = np.mean if on_epoch else func = np.max + self.count += 1 + + def on_epoch_end(self, trainer, pl_module): + self.make_logging(pl_module, 'on_epoch_end', 6, on_steps=[False], + on_epochs=self.choices, prob_bars=self.choices) + + def on_test_epoch_end(self, trainer, pl_module): + self.make_logging(pl_module, 'on_test_epoch_end', 7, on_steps=[False], + on_epochs=self.choices, prob_bars=self.choices) + + max_epochs = 2 + num_dataloaders = 2 + + class TestModel(BoringModel): + + manual_mean = collections.defaultdict(list) + + def test_step(self, batch, batch_idx, dataloader_idx=None): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log('test_loss', loss) + self.manual_mean[str(dataloader_idx)].append(loss) + + def test_dataloader(self): + return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(num_dataloaders)] + + model = TestModel() + model.test_epoch_end = None + test_callback = TestCallback() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=0, + limit_test_batches=2, + val_check_interval=0., + num_sanity_val_steps=0, + max_epochs=max_epochs, + callbacks=[test_callback], + ) + trainer.fit(model) + trainer.test() + + assert test_callback.funcs_called_count["on_test_start"] == 1 + assert test_callback.funcs_called_count["on_epoch_start"] == 2 + assert test_callback.funcs_called_count["on_test_epoch_start"] == 1 + assert test_callback.funcs_called_count["on_test_batch_start"] == 4 + assert test_callback.funcs_called_count["on_test_batch_end"] == 4 + assert test_callback.funcs_called_count["on_epoch_end"] == 2 + assert test_callback.funcs_called_count["on_test_epoch_end"] == 1 + + # Make sure the func_name exists within callback_metrics. 
If not, we missed some + callback_metrics_keys = [*trainer.callback_metrics.keys()] + + for func_name in test_callback.callback_funcs_called.keys(): + is_in = False + for callback_metrics_key in callback_metrics_keys: + if func_name in callback_metrics_key: + is_in = True + assert is_in, (func_name, callback_metrics_keys) + + # function used to describe expected return logic + def get_expected_output(func_attr, original_values): + # Apply mean on values + if func_attr["on_epoch"] and not func_attr["on_step"]: + expected_output = np.mean(original_values) + else: + expected_output = np.max(original_values) + return expected_output + + # Make sure the func_name output equals the average from all logged values when on_epoch true + # pop extra keys + assert "debug_epoch" in trainer.callback_metrics + trainer.callback_metrics.pop("debug_epoch") + + for dl_idx in range(num_dataloaders): + key = f"test_loss/dataloader_idx_{dl_idx}" + assert key in trainer.callback_metrics + assert torch.stack(model.manual_mean[str(dl_idx)]).mean() == trainer.callback_metrics[key] + trainer.callback_metrics.pop(key) + + for func_name, output_value in trainer.callback_metrics.items(): + # not sure how to handle this now + if "epoch_1" in func_name: + func_name = '/'.join(func_name.split('/')[:-1]) + continue + + if torch.is_tensor(output_value): + output_value = output_value.item() + + # get func attr + func_attr = test_callback.funcs_attr[func_name] + + # retrived orginal logged values + original_values = test_callback.callback_funcs_called[func_attr["func_name"]] + + # compute expected output and compare to actual one + expected_output = get_expected_output(func_attr, original_values) + assert float(output_value) == float(expected_output) + + for func_name, func_attr in test_callback.funcs_attr.items(): + if func_attr["prog_bar"] and (func_attr["on_step"] or func_attr["on_epoch"]) and not func_attr["forked"]: + assert func_name in trainer.logger_connector.progress_bar_metrics + else: + assert func_name not in trainer.logger_connector.progress_bar_metrics From 5eaa330e35fbb67e77e7e371067442de223cea48 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 6 Nov 2020 11:25:44 +0000 Subject: [PATCH 02/15] wip check how many tests break --- .../logger_connector/logger_connector.py | 15 +++++- pytorch_lightning/trainer/evaluation_loop.py | 47 ++++++++++--------- .../test_eval_loop_logging_1_0.py | 3 +- 3 files changed, 40 insertions(+), 25 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index d2507ff618b36..1df2c7bc3eaa4 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -223,6 +223,19 @@ def add_progress_bar_metrics(self, metrics): self.trainer.dev_debugger.track_pbar_metrics_history(metrics) + def track_metrics_deprecated(self, deprecated_eval_results, using_eval_result, test_mode): + self._track_callback_metrics(deprecated_eval_results, using_eval_result) + self.__process_eval_epoch_end_results_and_log_legacy(deprecated_eval_results, test_mode) + + def evaluation_epoch_end(self, testing): + # reset dataloader idx + model_ref = self.trainer.get_model() + model_ref._current_dataloader_idx = None + + # setting `has_batch_loop_finished` to True + # will perform Results reduction accross entire epoch. 
+ self.cached_results.has_batch_loop_finished = True + def on_evaluation_epoch_end(self, deprecated_eval_results, epoch_logs, using_eval_result, test_mode): self._track_callback_metrics(deprecated_eval_results, using_eval_result) @@ -235,7 +248,7 @@ def on_evaluation_epoch_end(self, deprecated_eval_results, epoch_logs, using_eva eval_loop_results = self._get_evaluate_epoch_results(test_mode) return eval_loop_results - def _get_evaluate_epoch_results(self, test_mode): + def get_evaluate_epoch_results(self, test_mode): # log results of test if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test: print('-' * 80) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 948714b795bdf..0ca30285c59c6 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -199,21 +199,21 @@ def evaluation_step_end(self, *args, **kwargs): output = self.trainer.call_hook('validation_step_end', *args, **kwargs) return output - def evaluation_epoch_end(self, num_dataloaders): + def evaluation_epoch_end(self): + # inform logger_connector batch loop is finished + self.trainer.logger_connector.evaluation_epoch_end(self.testing) + using_eval_result = self.is_using_eval_results() # call the model epoch end - deprecated_results = self.__run_eval_epoch_end(num_dataloaders, using_eval_result) - - # 1.0 - epoch_logs = self.trainer.get_model()._results + deprecated_results = self.__run_eval_epoch_end(self.num_dataloaders, using_eval_result) # enable returning anything for i, r in enumerate(deprecated_results): if not isinstance(r, (dict, Result, torch.Tensor)): deprecated_results[i] = [] - return deprecated_results, epoch_logs + return deprecated_results def log_epoch_metrics(self, deprecated_eval_results, epoch_logs, test_mode): using_eval_result = self.is_using_eval_results() @@ -225,6 +225,11 @@ def log_epoch_metrics(self, deprecated_eval_results, epoch_logs, test_mode): ) return eval_loop_results + def log_epoch_metrics_on_evaluation_end(self): + # get the final loop results + eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results(self.testing) + return eval_loop_results + def __run_eval_epoch_end(self, num_dataloaders, using_eval_result): model = self.trainer.get_model() @@ -271,7 +276,7 @@ def __run_eval_epoch_end(self, num_dataloaders, using_eval_result): if not isinstance(eval_results, list): eval_results = [eval_results] - # support depreceated tracking metrics + # track depreceated metrics self.trainer.logger_connector.track_metrics_deprecated(eval_results, using_eval_result, self.testing) return eval_results @@ -342,30 +347,26 @@ def on_evaluation_epoch_end(self, *args, **kwargs): else: self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs) - def log_evaluation_step_metrics(self, batch, batch_idx): - results = self.trainer.get_model()._results - if len(results) == 1: - return None - - results.track_batch_size(batch) - self.__log_result_step_metrics(results, batch_idx) - - return results - - # TODO: deprecate at 1.0 - def log_evaluation_step_metrics_legacy(self, output, batch_idx): + def log_evaluation_step_metrics(self, output, batch_idx): if self.trainer.running_sanity_check: return + step_log_metrics = {} + step_pbar_metrics = {} if isinstance(output, EvalResult): - self.__log_result_step_metrics(output, batch_idx) + step_log_metrics = output.get_batch_log_metrics(include_forked_originals=False) + step_pbar_metrics = 
output.get_batch_pbar_metrics(include_forked_originals=False) - def __log_result_step_metrics(self, output, batch_idx): - step_log_metrics = output.get_batch_log_metrics(include_forked_originals=False) - step_pbar_metrics = output.get_batch_pbar_metrics(include_forked_originals=False) + self.__log_result_step_metrics(step_log_metrics, step_pbar_metrics, batch_idx) + def __log_result_step_metrics(self, step_log_metrics, step_pbar_metrics, batch_idx): cached_batch_log_metrics = \ self.trainer.logger_connector.cached_results.get_latest_batch_log_metrics() + cached_batch_pbar_metrics = \ + self.trainer.logger_connector.cached_results.get_latest_batch_pbar_metrics() + + step_log_metrics.update(cached_batch_log_metrics) + step_pbar_metrics.update(cached_batch_pbar_metrics) if len(step_log_metrics) > 0: # make the metrics appear as a different line in the same graph diff --git a/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py b/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py index d9ae1a89d631a..b417db7893779 100644 --- a/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py @@ -21,6 +21,7 @@ from tests.base.deterministic_model import DeterministicModel from tests.base import SimpleModule, BoringModel, RandomDataset import os +import numpy as np import itertools import collections import torch @@ -524,7 +525,7 @@ def validation_step(self, batch, batch_idx): trainer = Trainer( default_root_dir=tmpdir, - limit_train_batches=0, + limit_train_batches=1, limit_val_batches=4, limit_test_batches=0, val_check_interval=0., From c269a8bd08e78256b4311a5b3bc48c98282d1451 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 6 Nov 2020 13:15:24 +0000 Subject: [PATCH 03/15] wip --- .../logger_connector/epoch_result_store.py | 6 +++- pytorch_lightning/trainer/evaluation_loop.py | 5 +-- pytorch_lightning/trainer/trainer.py | 4 +-- tests/models/test_hooks.py | 2 +- .../trainer/logging/test_logger_connector.py | 6 +++- .../test_eval_loop_logging_1_0.py | 34 ++++++++++--------- 6 files changed, 32 insertions(+), 25 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index 2980b037c95f7..6011d1a64308e 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from copy import deepcopy from collections import defaultdict, ChainMap from enum import Enum from typing import Union, Tuple, Any, Dict, Optional, List @@ -415,13 +416,14 @@ def update_logger_connector(self, fx_name: str = None) -> None: logger_connector = self.trainer.logger_connector callback_metrics = {} + is_train = self._stage in LoggerStages.TRAIN.value if not self._has_batch_loop_finished: # get pbar batch_pbar_metrics = self.get_latest_batch_pbar_metrics() logger_connector.add_progress_bar_metrics(batch_pbar_metrics) - if self._stage in LoggerStages.TRAIN.value: + if is_train: # Only log and add to callback epoch step during evaluation, test. 
batch_log_metrics = self.get_latest_batch_log_metrics() logger_connector.logged_metrics.update(batch_log_metrics) @@ -439,6 +441,8 @@ def update_logger_connector(self, fx_name: str = None) -> None: epoch_log_metrics = self.get_epoch_log_metrics() logger_connector.logged_metrics.update(epoch_log_metrics) logger_connector.logged_metrics.update(epoch_dict) + if not self.trainer.running_sanity_check and not is_train: + self.trainer.dev_debugger.track_logged_metrics_history(deepcopy(epoch_log_metrics)) # get forked_metrics forked_metrics = self.get_forked_metrics() diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 0ca30285c59c6..2b1d78ca9bf77 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -109,9 +109,6 @@ def on_evaluation_end(self, *args, **kwargs): else: self.trainer.call_hook('on_validation_end', *args, **kwargs) - # reset stage to train - self.trainer.logger_connector.set_stage("train") - def reload_evaluation_dataloaders(self): model = self.trainer.get_model() if self.testing: @@ -200,7 +197,7 @@ def evaluation_step_end(self, *args, **kwargs): return output def evaluation_epoch_end(self): - # inform logger_connector batch loop is finished + # unset dataloder_idx in model self.trainer.logger_connector.evaluation_epoch_end(self.testing) using_eval_result = self.is_using_eval_results() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 9c8e52b38e1ca..6c8f8d16d8557 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -604,7 +604,7 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None): # store batch level output per dataloader self.evaluation_loop.outputs.append(dl_outputs) - # lightning module method + inform logger batch loop finished + # lightning module method deprecated_eval_results = self.evaluation_loop.evaluation_epoch_end() # hook @@ -613,7 +613,7 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None): # hook self.evaluation_loop.on_evaluation_end() - # bookkeeping + # log epoch metrics eval_loop_results = self.evaluation_loop.log_epoch_metrics_on_evaluation_end() # save predictions to disk diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index bccc5262a5bda..f3af5b745a380 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -333,8 +333,8 @@ def on_test_model_train(self): 'on_validation_batch_start', 'on_validation_batch_end', 'on_validation_epoch_end', - 'on_validation_model_train', 'on_save_checkpoint', + 'on_validation_model_train', 'on_epoch_end', 'on_train_epoch_end', 'on_train_end', diff --git a/tests/trainer/logging/test_logger_connector.py b/tests/trainer/logging/test_logger_connector.py index 08936f89eb9f8..f681691cd7105 100644 --- a/tests/trainer/logging/test_logger_connector.py +++ b/tests/trainer/logging/test_logger_connector.py @@ -240,6 +240,10 @@ def test_step(self, batch, batch_idx, dl_idx=0): self.log("test_loss", loss, on_step=True, on_epoch=True) return {"test_loss": loss} + def on_test_epoch_end(self): + # save objects as it will be reset at the end of epoch. 
+ self.test_results = deepcopy(self.trainer.logger_connector.cached_results) + def test_dataloader(self): return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(num_dataloaders)] @@ -260,7 +264,7 @@ def test_dataloader(self): ) trainer.test(model) - test_results = trainer.logger_connector._cached_results["test"] + test_results = model.test_results generated = test_results(fx_name="test_step") assert len(generated) == num_dataloaders diff --git a/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py b/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py index b417db7893779..9b40d199fbca5 100644 --- a/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py @@ -303,20 +303,16 @@ def validation_epoch_end(self, outputs) -> None: # make sure correct values were logged logged_val = trainer.dev_debugger.logged_metrics - # sanity check - assert logged_val[0]['global_step'] == 0 - assert logged_val[1]['global_step'] == 0 - # 3 val batches - assert logged_val[2]['val_loss_step/epoch_0'] == model.seen_vals[0] - assert logged_val[3]['val_loss_step/epoch_0'] == model.seen_vals[1] - assert logged_val[4]['val_loss_step/epoch_0'] == model.seen_vals[2] + assert logged_val[1]['val_loss_step/epoch_0'] == model.seen_vals[0] + assert logged_val[2]['val_loss_step/epoch_0'] == model.seen_vals[1] + assert logged_val[3]['val_loss_step/epoch_0'] == model.seen_vals[2] # epoch mean - assert logged_val[5]['val_loss_epoch'] == model.manual_epoch_end_mean + assert logged_val[4]['val_loss_epoch'] == model.manual_epoch_end_mean # only those logged - assert len(logged_val) == 6 + assert len(logged_val) == 4 @pytest.mark.parametrize(['batches', 'log_interval', 'max_epochs'], [(1, 1, 1), (64, 32, 2)]) @@ -324,7 +320,7 @@ def test_eval_epoch_only_logging(tmpdir, batches, log_interval, max_epochs): """ Tests that only test_epoch_end can be used to log, and we return them in the results. 
""" - os.environ['PL_DEV_DEBUG'] = '1' + os.environ['PL_DEV_DEBUG'] = '0' class TestModel(BoringModel): def test_epoch_end(self, outputs): @@ -441,12 +437,15 @@ class TestCallback(callbacks.Callback): funcs_called_count = collections.defaultdict(int) funcs_attr = {} - def make_logging(self, pl_module: pl.LightningModule, func_name, func_idx, on_steps=[], on_epochs=[], prob_bars=[]): + def make_logging(self, pl_module: pl.LightningModule, func_name, + func_idx, on_steps=[], on_epochs=[], prob_bars=[]): self.funcs_called_count[func_name] += 1 - for idx, (on_step, on_epoch, prog_bar) in enumerate(list(itertools.product(*[on_steps, on_epochs, prob_bars]))): + product = [on_steps, on_epochs, prob_bars] + for idx, (on_step, on_epoch, prog_bar) in enumerate(list(itertools.product(*product))): # run logging custom_func_name = f"{func_idx}_{idx}_{func_name}" - pl_module.log(custom_func_name, self.count * func_idx, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar) + pl_module.log(custom_func_name, self.count * func_idx, + on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar) # catch information for verification self.callback_funcs_called[func_name].append([self.count * func_idx]) self.funcs_attr[custom_func_name] = { @@ -612,16 +611,19 @@ class TestCallback(callbacks.Callback): funcs_called_count = collections.defaultdict(int) funcs_attr = {} - def make_logging(self, pl_module: pl.LightningModule, func_name, func_idx, on_steps=[], on_epochs=[], prob_bars=[]): + def make_logging(self, pl_module: pl.LightningModule, func_name, + func_idx, on_steps=[], on_epochs=[], prob_bars=[]): original_func_name = func_name[:] self.funcs_called_count[original_func_name] += 1 - for idx, t in enumerate(list(itertools.product(*[on_steps, on_epochs, prob_bars]))): + product = [on_steps, on_epochs, prob_bars] + for idx, t in enumerate(list(itertools.product(*product))): # run logging func_name = original_func_name[:] on_step, on_epoch, prog_bar = t custom_func_name = f"{func_idx}_{idx}_{func_name}" - pl_module.log(custom_func_name, self.count * func_idx, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar) + pl_module.log(custom_func_name, self.count * func_idx, + on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar) num_dl_ext = '' if pl_module._current_dataloader_idx is not None: From 2444640769c0e4643ad828764f4c15123336a870 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 6 Nov 2020 13:34:53 +0000 Subject: [PATCH 04/15] resolve some bugs --- .../connectors/logger_connector/epoch_result_store.py | 3 ++- .../connectors/logger_connector/logger_connector.py | 10 ++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index 6011d1a64308e..c68148dc9e5d1 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -442,7 +442,8 @@ def update_logger_connector(self, fx_name: str = None) -> None: logger_connector.logged_metrics.update(epoch_log_metrics) logger_connector.logged_metrics.update(epoch_dict) if not self.trainer.running_sanity_check and not is_train: - self.trainer.dev_debugger.track_logged_metrics_history(deepcopy(epoch_log_metrics)) + if len(epoch_log_metrics) > 0: + self.trainer.dev_debugger.track_logged_metrics_history(deepcopy(epoch_log_metrics)) # get forked_metrics forked_metrics = self.get_forked_metrics() diff --git 
a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 1df2c7bc3eaa4..ca870f3b6818d 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -232,23 +232,21 @@ def evaluation_epoch_end(self, testing): model_ref = self.trainer.get_model() model_ref._current_dataloader_idx = None - # setting `has_batch_loop_finished` to True - # will perform Results reduction accross entire epoch. - self.cached_results.has_batch_loop_finished = True - def on_evaluation_epoch_end(self, deprecated_eval_results, epoch_logs, using_eval_result, test_mode): self._track_callback_metrics(deprecated_eval_results, using_eval_result) # TODO: deprecate parts of this for 1.0 (when removing results) self.__process_eval_epoch_end_results_and_log_legacy(deprecated_eval_results, test_mode) - self._log_on_evaluation_epoch_end_metrics(epoch_logs) - # get the final loop results eval_loop_results = self._get_evaluate_epoch_results(test_mode) return eval_loop_results def get_evaluate_epoch_results(self, test_mode): + # setting `has_batch_loop_finished` to True + # will perform Results reduction accross entire epoch. + self.cached_results.has_batch_loop_finished = True + # log results of test if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test: print('-' * 80) From 372657787c05788247a25e67e10c4488b9de47e9 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 6 Nov 2020 13:49:03 +0000 Subject: [PATCH 05/15] resolve more bugs --- .../connectors/logger_connector/logger_connector.py | 7 ++++--- tests/trainer/logging_tests/test_eval_loop_logging_1_0.py | 8 ++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index ca870f3b6818d..7fc5ca7882585 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -232,6 +232,10 @@ def evaluation_epoch_end(self, testing): model_ref = self.trainer.get_model() model_ref._current_dataloader_idx = None + # setting `has_batch_loop_finished` to True + # will perform Results reduction accross entire epoch. + self.cached_results.has_batch_loop_finished = True + def on_evaluation_epoch_end(self, deprecated_eval_results, epoch_logs, using_eval_result, test_mode): self._track_callback_metrics(deprecated_eval_results, using_eval_result) @@ -243,9 +247,6 @@ def on_evaluation_epoch_end(self, deprecated_eval_results, epoch_logs, using_eva return eval_loop_results def get_evaluate_epoch_results(self, test_mode): - # setting `has_batch_loop_finished` to True - # will perform Results reduction accross entire epoch. 
- self.cached_results.has_batch_loop_finished = True # log results of test if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test: diff --git a/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py b/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py index 9b40d199fbca5..89fabc76353ac 100644 --- a/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py @@ -304,12 +304,12 @@ def validation_epoch_end(self, outputs) -> None: logged_val = trainer.dev_debugger.logged_metrics # 3 val batches - assert logged_val[1]['val_loss_step/epoch_0'] == model.seen_vals[0] - assert logged_val[2]['val_loss_step/epoch_0'] == model.seen_vals[1] - assert logged_val[3]['val_loss_step/epoch_0'] == model.seen_vals[2] + assert logged_val[0]['val_loss_step/epoch_0'] == model.seen_vals[0] + assert logged_val[1]['val_loss_step/epoch_0'] == model.seen_vals[1] + assert logged_val[2]['val_loss_step/epoch_0'] == model.seen_vals[2] # epoch mean - assert logged_val[4]['val_loss_epoch'] == model.manual_epoch_end_mean + assert logged_val[3]['val_loss_epoch'] == model.manual_epoch_end_mean # only those logged assert len(logged_val) == 4 From 1855fa976ec91fe503d71bb83f1323f84532303a Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 6 Nov 2020 14:02:04 +0000 Subject: [PATCH 06/15] resolve 2 bugs --- tests/trainer/logging/test_logger_connector.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/trainer/logging/test_logger_connector.py b/tests/trainer/logging/test_logger_connector.py index f681691cd7105..38a7dee896a8c 100644 --- a/tests/trainer/logging/test_logger_connector.py +++ b/tests/trainer/logging/test_logger_connector.py @@ -240,9 +240,13 @@ def test_step(self, batch, batch_idx, dl_idx=0): self.log("test_loss", loss, on_step=True, on_epoch=True) return {"test_loss": loss} + def on_test_batch_end(self, *args, **kwargs): + # save objects as it will be reset at the end of epoch. + self.batch_results = deepcopy(self.trainer.logger_connector.cached_results) + def on_test_epoch_end(self): # save objects as it will be reset at the end of epoch. 
- self.test_results = deepcopy(self.trainer.logger_connector.cached_results) + self.reduce_results = deepcopy(self.trainer.logger_connector.cached_results) def test_dataloader(self): return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(num_dataloaders)] @@ -264,7 +268,7 @@ def test_dataloader(self): ) trainer.test(model) - test_results = model.test_results + test_results = model.batch_results generated = test_results(fx_name="test_step") assert len(generated) == num_dataloaders @@ -273,7 +277,7 @@ def test_dataloader(self): generated = len(test_results(fx_name="test_step", dl_idx=str(dl_idx))) assert generated == limit_test_batches - test_results.has_batch_loop_finished = True + test_results = model.reduce_results for dl_idx in range(num_dataloaders): expected = torch.stack(model.test_losses[str(dl_idx)]).mean() From a83a3130d75e5415a008529e036d89c7f223088c Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 6 Nov 2020 15:18:45 +0000 Subject: [PATCH 07/15] resolve --- .../logger_connector/logger_connector.py | 36 ++++++++++--------- .../test_eval_loop_logging_1_0.py | 1 - 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 7fc5ca7882585..841333a328620 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -246,8 +246,28 @@ def on_evaluation_epoch_end(self, deprecated_eval_results, epoch_logs, using_eva eval_loop_results = self._get_evaluate_epoch_results(test_mode) return eval_loop_results + def add_to_eval_loop_results(self, dl_idx, has_been_initialized): + callback_metrics = deepcopy(self.callback_metrics) + for key in list(callback_metrics.keys()): + if "dataloader_idx" in key: + if f"dataloader_idx_{dl_idx}" not in key: + # remove dl_idx from self.callback_metrics not belonging to this dataset. + del callback_metrics[key] + if has_been_initialized: + self.eval_loop_results[dl_idx].update(callback_metrics) + else: + self.eval_loop_results.append(callback_metrics) + + def prepare_eval_loop_results(self): + num_dataloaders = self.trainer.evaluation_loop.num_dataloaders + has_been_initialized = len(self.eval_loop_results) == num_dataloaders + for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders): + self.add_to_eval_loop_results(dl_idx, has_been_initialized) + def get_evaluate_epoch_results(self, test_mode): + self.prepare_eval_loop_results() + # log results of test if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test: print('-' * 80) @@ -339,22 +359,6 @@ def _log_on_evaluation_epoch_end_metrics(self, epoch_logs): if len(metrics_to_log) > 0: self.log_metrics(metrics_to_log, {}) - def add_to_eval_loop_results(self, dl_idx, num_loaders): - callback_metrics = deepcopy(self.callback_metrics) - if num_loaders == 1: - if len(self.eval_loop_results) > 0: - self.eval_loop_results[0].update(callback_metrics) - else: - self.eval_loop_results.append(callback_metrics) - return - - for key in list(callback_metrics.keys()): - if "dataloader_idx" in key: - if f"dataloader_idx_{dl_idx}" not in key: - # remove dl_idx from self.callback_metrics not belonging to this dataset. 
- del callback_metrics[key] - self.eval_loop_results.append(callback_metrics) - def __rename_keys_by_dataloader_idx(self, metrics, dataloader_idx, num_loaders): if num_loaders == 1: return metrics diff --git a/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py b/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py index 89fabc76353ac..152d5b5f607c6 100644 --- a/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py @@ -386,7 +386,6 @@ def test_dataloader(self): weights_summary=None, ) results = trainer.test(model) - assert len(results[0]) == len(results[1]) assert "test_loss_epoch/dataloader_idx_0" in results[0] assert "test_loss_epoch/dataloader_idx_1" in results[1] From fec1b20d459523ccff6e864f8540faadf95b4c6c Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 6 Nov 2020 15:41:09 +0000 Subject: [PATCH 08/15] temp fix --- .../test_eval_loop_dict_return.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py index 6329480e10a11..2902309ebb413 100644 --- a/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py +++ b/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py @@ -44,7 +44,7 @@ def backward(self, loss, optimizer, optimizer_idx): # out are the results of the full loop # eval_results are output of _evaluate out, eval_results = trainer.run_evaluation(test_mode=False) - assert len(out) == 0 + assert len(out) == 1 assert len(eval_results) == 0 # make sure correct steps were called @@ -75,7 +75,7 @@ def test_validation_step_scalar_return(tmpdir): # out are the results of the full loop # eval_results are output of _evaluate out, eval_results = trainer.run_evaluation(test_mode=False) - assert len(out) == 0 + assert len(out) == 1 assert len(eval_results) == 2 assert eval_results[0] == 171 and eval_results[1] == 171 @@ -148,7 +148,7 @@ def test_validation_step_dict_return(tmpdir): # eval_results are output of _evaluate callback_metrics, eval_results = trainer.run_evaluation(test_mode=False) assert len(callback_metrics) == 1 - assert len(callback_metrics[0]) == 5 + assert len(callback_metrics[0]) == 7 assert len(eval_results) == 2 assert eval_results[0]['log']['log_acc1'] == 12 assert eval_results[1]['log']['log_acc1'] == 13 @@ -225,7 +225,7 @@ def test_val_step_step_end(tmpdir): # eval_results are output of _evaluate callback_metrics, eval_results = trainer.run_evaluation(test_mode=False) assert len(callback_metrics) == 1 - assert len(callback_metrics[0]) == 6 + assert len(callback_metrics[0]) == 8 callback_metrics = callback_metrics[0] assert callback_metrics['val_step_end'] == 1802 @@ -273,7 +273,7 @@ def test_no_val_step_end(tmpdir): # eval_results are output of _evaluate callback_metrics, eval_results = trainer.run_evaluation(test_mode=False) assert len(callback_metrics) == 1 - assert len(callback_metrics[0]) == 6 + assert len(callback_metrics[0]) == 8 assert len(eval_results) == 1 eval_results = eval_results[0] @@ -319,7 +319,7 @@ def test_full_val_loop(tmpdir): # eval_results are output of _evaluate callback_metrics, eval_results = trainer.run_evaluation(test_mode=False) assert len(callback_metrics) == 1 - assert len(callback_metrics[0]) == 7 + assert len(callback_metrics[0]) == 9 assert len(eval_results) == 1 eval_results = eval_results[0] From 6038412356f50ca934f7ab0e45cb62b1ffed2df0 Mon 
Sep 17 00:00:00 2001 From: tchaton Date: Fri, 6 Nov 2020 15:53:20 +0000 Subject: [PATCH 09/15] update --- .../test_eval_loop_dict_return.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py index 2902309ebb413..205a5405307de 100644 --- a/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py +++ b/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py @@ -190,7 +190,7 @@ def test_val_step_step_end_no_return(tmpdir): # out are the results of the full loop # eval_results are output of _evaluate callback_metrics, eval_results = trainer.run_evaluation(test_mode=False) - assert len(callback_metrics) == 0 + assert len(callback_metrics) == 1 assert len(eval_results) == 0 # make sure correct steps were called From a7a44ce1d187333ec53c38820353ce1dfd2b5692 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 6 Nov 2020 16:08:50 +0000 Subject: [PATCH 10/15] remove useless code --- .../logger_connector/logger_connector.py | 94 ------------------- pytorch_lightning/trainer/evaluation_loop.py | 10 -- 2 files changed, 104 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 841333a328620..a429798b61851 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -236,16 +236,6 @@ def evaluation_epoch_end(self, testing): # will perform Results reduction accross entire epoch. self.cached_results.has_batch_loop_finished = True - def on_evaluation_epoch_end(self, deprecated_eval_results, epoch_logs, using_eval_result, test_mode): - self._track_callback_metrics(deprecated_eval_results, using_eval_result) - - # TODO: deprecate parts of this for 1.0 (when removing results) - self.__process_eval_epoch_end_results_and_log_legacy(deprecated_eval_results, test_mode) - - # get the final loop results - eval_loop_results = self._get_evaluate_epoch_results(test_mode) - return eval_loop_results - def add_to_eval_loop_results(self, dl_idx, has_been_initialized): callback_metrics = deepcopy(self.callback_metrics) for key in list(callback_metrics.keys()): @@ -282,90 +272,6 @@ def get_evaluate_epoch_results(self, test_mode): self.eval_loop_results = [] return results - def _log_on_evaluation_epoch_end_metrics(self, epoch_logs): - step_metrics = self.trainer.evaluation_loop.step_metrics - - num_loaders = len(step_metrics) - - # clear mem - self.trainer.evaluation_loop.step_metrics = [] - - if self.trainer.running_sanity_check: - return - - # track all metrics we want to log - metrics_to_log = [] - - # --------------------------- - # UPDATE EPOCH LOGGED METRICS - # --------------------------- - # (ie: in methods at the val_epoch_end level) - # union the epoch logs with whatever was returned from loaders and reduced - epoch_logger_metrics = epoch_logs.get_epoch_log_metrics() - epoch_pbar_metrics = epoch_logs.get_epoch_pbar_metrics() - - self.logged_metrics.update(epoch_logger_metrics) - self.add_progress_bar_metrics(epoch_pbar_metrics) - - # enable the metrics to be monitored - self.callback_metrics.update(epoch_logger_metrics) - self.callback_metrics.update(epoch_pbar_metrics) - - if len(epoch_logger_metrics) > 0: - metrics_to_log.append(epoch_logger_metrics) - - # 
-------------------------------- - # UPDATE METRICS PER DATALOADER - # -------------------------------- - # each dataloader aggregated metrics - # now we log all of them - for dl_idx, dl_metrics in enumerate(step_metrics): - if len(dl_metrics) == 0: - # Ensure custom logged metrics are included if not included with step metrics - if len(epoch_logger_metrics) > 0: - self.eval_loop_results.append(epoch_logger_metrics) - continue - - reduced_epoch_metrics = dl_metrics[0].__class__.reduce_on_epoch_end(dl_metrics) - # track the metrics - logger_metrics = reduced_epoch_metrics.get_epoch_log_metrics() - pbar_metrics = reduced_epoch_metrics.get_epoch_pbar_metrics() - forked_metrics = reduced_epoch_metrics.get_forked_metrics() - - # make the keys 'k/dl' - logger_metrics = self.__rename_keys_by_dataloader_idx(logger_metrics, dl_idx, num_loaders) - pbar_metrics = self.__rename_keys_by_dataloader_idx(pbar_metrics, dl_idx, num_loaders) - forked_metrics = self.__rename_keys_by_dataloader_idx(forked_metrics, dl_idx, num_loaders) - - self.logged_metrics.update(logger_metrics) - self.add_progress_bar_metrics(pbar_metrics) - - # enable the metrics to be monitored - self.callback_metrics.update(logger_metrics) - self.callback_metrics.update(pbar_metrics) - - # forked metrics were dropped, enable them for callbacks - self.callback_metrics.update(forked_metrics) - - # track the final results for the dataloader - self.add_to_eval_loop_results(dl_idx, num_loaders) - - # actually log - if len(logger_metrics) > 0: - metrics_to_log.append(logger_metrics) - - # log all the metrics as a s single dict - metrics_to_log = dict(ChainMap(*metrics_to_log)) - if len(metrics_to_log) > 0: - self.log_metrics(metrics_to_log, {}) - - def __rename_keys_by_dataloader_idx(self, metrics, dataloader_idx, num_loaders): - if num_loaders == 1: - return metrics - - result = {f'{k}/dataloader_idx_{dataloader_idx}': v for k, v in metrics.items()} - return result - def _track_callback_metrics(self, eval_results, using_eval_result): if ( len(eval_results) > 0 and diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 2b1d78ca9bf77..1f16c50e634d9 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -212,16 +212,6 @@ def evaluation_epoch_end(self): return deprecated_results - def log_epoch_metrics(self, deprecated_eval_results, epoch_logs, test_mode): - using_eval_result = self.is_using_eval_results() - eval_loop_results = self.trainer.logger_connector.on_evaluation_epoch_end( - deprecated_eval_results, - epoch_logs, - using_eval_result, - test_mode - ) - return eval_loop_results - def log_epoch_metrics_on_evaluation_end(self): # get the final loop results eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results(self.testing) From fef917ebb8758680d65c96f9f4c6f1959f949fda Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 6 Nov 2020 16:09:47 +0000 Subject: [PATCH 11/15] remove result --- pytorch_lightning/trainer/evaluation_loop.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 1f16c50e634d9..e3a0f1108f1f9 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -220,9 +220,6 @@ def log_epoch_metrics_on_evaluation_end(self): def __run_eval_epoch_end(self, num_dataloaders, using_eval_result): model = self.trainer.get_model() - # reset results - model._results 
= Result() - # with a single dataloader don't pass an array outputs = self.outputs eval_results = outputs From f3e47c93fb89388d8ec9f1bb271048145be05b56 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 6 Nov 2020 16:27:26 +0000 Subject: [PATCH 12/15] try to resolve bug --- .../logger_connector/epoch_result_store.py | 3 +++ .../connectors/logger_connector/logger_connector.py | 12 +++++++++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index c68148dc9e5d1..36d5f69f4b59a 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -452,6 +452,9 @@ def update_logger_connector(self, fx_name: str = None) -> None: callback_metrics.update(epoch_log_metrics) callback_metrics.update(forked_metrics) + if not is_train: + logger_connector.evaluation_callback_metrics.update(callback_metrics) + # update callback_metrics logger_connector.callback_metrics.update(callback_metrics) logger_connector.callback_metrics.pop("epoch", None) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index a429798b61851..1d6f43c5169fc 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -36,6 +36,7 @@ class LoggerConnector: def __init__(self, trainer): self.trainer = trainer self.callback_metrics = {} + self.evaluation_callback_metrics = {} self.logged_metrics = {} self.progress_bar_metrics = {} self.eval_loop_results = [] @@ -237,7 +238,7 @@ def evaluation_epoch_end(self, testing): self.cached_results.has_batch_loop_finished = True def add_to_eval_loop_results(self, dl_idx, has_been_initialized): - callback_metrics = deepcopy(self.callback_metrics) + callback_metrics = deepcopy(self.evaluation_callback_metrics) for key in list(callback_metrics.keys()): if "dataloader_idx" in key: if f"dataloader_idx_{dl_idx}" not in key: @@ -283,8 +284,10 @@ def _track_callback_metrics(self, eval_results, using_eval_result): if isinstance(eval_results, list): for eval_result in eval_results: self.trainer.logger_connector.callback_metrics.update(eval_result.callback_metrics) + self.trainer.logger_connector.evaluation_callback_metrics.update(eval_result.callback_metrics) else: self.trainer.logger_connector.callback_metrics.update(eval_results.callback_metrics) + self.trainer.logger_connector.evaluation_callback_metrics.update(eval_result.callback_metrics) else: flat = {} if isinstance(eval_results, list): @@ -300,6 +303,7 @@ def _track_callback_metrics(self, eval_results, using_eval_result): flat['checkpoint_on'] = flat['val_loss'] flat['early_stop_on'] = flat['val_loss'] self.trainer.logger_connector.callback_metrics.update(flat) + self.trainer.logger_connector.evaluation_callback_metrics.update(flat) else: # with a scalar return, auto set it to "val_loss" for callbacks if isinstance(eval_results, torch.Tensor): @@ -312,6 +316,7 @@ def _track_callback_metrics(self, eval_results, using_eval_result): flat['checkpoint_on'] = flat['val_loss'] flat['early_stop_on'] = flat['val_loss'] self.trainer.logger_connector.callback_metrics.update(flat) + self.trainer.logger_connector.evaluation_callback_metrics.update(flat) def 
__process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metrics, log_metrics, callback_metrics): # eval loop returns all metrics @@ -325,9 +330,10 @@ def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metric self.trainer.logger_connector.log_metrics(log_metrics, {}) # track metrics for callbacks (all prog bar, logged and callback metrics) + callback_metrics.update(log_metrics) + callback_metrics.update(prog_bar_metrics) self.trainer.logger_connector.callback_metrics.update(callback_metrics) - self.trainer.logger_connector.callback_metrics.update(log_metrics) - self.trainer.logger_connector.callback_metrics.update(prog_bar_metrics) + self.trainer.logger_connector.evaluation_callback_metrics.update(callback_metrics) if len(dataloader_result_metrics) > 0: self.eval_loop_results.append(dataloader_result_metrics) From 442d9e79b2c236777eae7ed82b760e5bd0abbbee Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 6 Nov 2020 16:40:56 +0000 Subject: [PATCH 13/15] update changelog --- CHANGELOG.md | 5 ++++- .../connectors/logger_connector/logger_connector.py | 2 +- .../test_eval_loop_dict_return.py | 8 ++++---- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1f4defbd5cc30..6b9cc1e356623 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,7 +30,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `fsspec` to tuner ([#4458](https://github.com/PyTorchLightning/pytorch-lightning/pull/4458)) -- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775)) +- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775)) + + +- Added logging using `self.log` in train and evaluation for most callbacks and model hooks ([#4552](https://github.com/PyTorchLightning/pytorch-lightning/pull/4552)) ([#4495](https://github.com/PyTorchLightning/pytorch-lightning/pull/4495)) ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439)) ### Changed diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 1d6f43c5169fc..792a03b1fe0f7 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -287,7 +287,7 @@ def _track_callback_metrics(self, eval_results, using_eval_result): self.trainer.logger_connector.evaluation_callback_metrics.update(eval_result.callback_metrics) else: self.trainer.logger_connector.callback_metrics.update(eval_results.callback_metrics) - self.trainer.logger_connector.evaluation_callback_metrics.update(eval_result.callback_metrics) + self.trainer.logger_connector.evaluation_callback_metrics.update(eval_results.callback_metrics) else: flat = {} if isinstance(eval_results, list): diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py index 205a5405307de..8168f09c68e00 100644 --- a/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py +++ b/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py @@ -148,7 +148,7 @@ def test_validation_step_dict_return(tmpdir): # eval_results are output of _evaluate callback_metrics, eval_results = 
trainer.run_evaluation(test_mode=False) assert len(callback_metrics) == 1 - assert len(callback_metrics[0]) == 7 + assert len(callback_metrics[0]) == 5 assert len(eval_results) == 2 assert eval_results[0]['log']['log_acc1'] == 12 assert eval_results[1]['log']['log_acc1'] == 13 @@ -225,7 +225,7 @@ def test_val_step_step_end(tmpdir): # eval_results are output of _evaluate callback_metrics, eval_results = trainer.run_evaluation(test_mode=False) assert len(callback_metrics) == 1 - assert len(callback_metrics[0]) == 8 + assert len(callback_metrics[0]) == 6 callback_metrics = callback_metrics[0] assert callback_metrics['val_step_end'] == 1802 @@ -273,7 +273,7 @@ def test_no_val_step_end(tmpdir): # eval_results are output of _evaluate callback_metrics, eval_results = trainer.run_evaluation(test_mode=False) assert len(callback_metrics) == 1 - assert len(callback_metrics[0]) == 8 + assert len(callback_metrics[0]) == 6 assert len(eval_results) == 1 eval_results = eval_results[0] @@ -319,7 +319,7 @@ def test_full_val_loop(tmpdir): # eval_results are output of _evaluate callback_metrics, eval_results = trainer.run_evaluation(test_mode=False) assert len(callback_metrics) == 1 - assert len(callback_metrics[0]) == 9 + assert len(callback_metrics[0]) == 7 assert len(eval_results) == 1 eval_results = eval_results[0] From c21e745e512fe3a3ca8572fe94ab9c6123f6dc02 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 6 Nov 2020 19:32:11 +0100 Subject: [PATCH 14/15] formatting --- CHANGELOG.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b9cc1e356623..16594880827c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,7 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775)) -- Added logging using `self.log` in train and evaluation for most callbacks and model hooks ([#4552](https://github.com/PyTorchLightning/pytorch-lightning/pull/4552)) ([#4495](https://github.com/PyTorchLightning/pytorch-lightning/pull/4495)) ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439)) +- Added logging using `self.log` in train and evaluation for most callbacks and model hooks ([#4552](https://github.com/PyTorchLightning/pytorch-lightning/pull/4552), + [#4495](https://github.com/PyTorchLightning/pytorch-lightning/pull/4495), + [#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439)) ### Changed From 3100edc5dab14e62d45cdc0271dc3143aaab6934 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 7 Nov 2020 11:54:26 +0000 Subject: [PATCH 15/15] remove pl --- tests/trainer/logging_tests/test_eval_loop_logging_1_0.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py b/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py index 152d5b5f607c6..12f53328ec98a 100644 --- a/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py @@ -14,7 +14,6 @@ """ Tests to ensure that the training loop works with a dict (1.0) """ -import pytorch_lightning as pl from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning import Trainer from pytorch_lightning import callbacks, seed_everything @@ -436,7 +435,7 @@ class TestCallback(callbacks.Callback): funcs_called_count = collections.defaultdict(int) funcs_attr = {} - def 
make_logging(self, pl_module: pl.LightningModule, func_name, + def make_logging(self, pl_module, func_name, func_idx, on_steps=[], on_epochs=[], prob_bars=[]): self.funcs_called_count[func_name] += 1 product = [on_steps, on_epochs, prob_bars] @@ -610,7 +609,7 @@ class TestCallback(callbacks.Callback): funcs_called_count = collections.defaultdict(int) funcs_attr = {} - def make_logging(self, pl_module: pl.LightningModule, func_name, + def make_logging(self, pl_module, func_name, func_idx, on_steps=[], on_epochs=[], prob_bars=[]): original_func_name = func_name[:] self.funcs_called_count[original_func_name] += 1
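
A minimal sketch of the `self.log` / `pl_module.log` pattern that `test_eval_loop_logging_1_0.py` above exercises through its `make_logging` helper — the module, callback, and metric names below are illustrative only and are not part of the patch:

    import torch
    from pytorch_lightning import Callback, LightningModule


    class IllustrativeModel(LightningModule):
        def validation_step(self, batch, batch_idx):
            loss = torch.tensor(0.5)  # placeholder value for illustration
            # logged from the LightningModule: reduced on epoch end and
            # forwarded to the logger and the progress bar
            self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
            return loss


    class IllustrativeCallback(Callback):
        def on_validation_batch_end(self, trainer, pl_module, *args, **kwargs):
            # callbacks log through the attached module, which is the pattern
            # the tests drive for each supported hook
            pl_module.log("callback_metric", 1.0, on_step=True, on_epoch=True)

After a `trainer.fit(...)` run with such a callback attached, both metrics would show up in `trainer.callback_metrics`, which is the kind of bookkeeping these tests assert on.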