diff --git a/CHANGELOG.md b/CHANGELOG.md index 17dbd44c3e9eea..5a6863e46497ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -153,6 +153,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed deprecated `LightningModule` `hparams` setter ([#6207](https://github.com/PyTorchLightning/pytorch-lightning/pull/6207)) +- Removed legacy code to include `step` dictionary returns in `callback_metrics` ([#6682](https://github.com/PyTorchLightning/pytorch-lightning/pull/6682)) + + - Removed `optimizer_idx` argument from `training_step` in manual optimization ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093)) diff --git a/docs/source/ecosystem/asr_nlp_tts.rst b/docs/source/ecosystem/asr_nlp_tts.rst index 49bed0a981a6e0..af9a7084583f2f 100644 --- a/docs/source/ecosystem/asr_nlp_tts.rst +++ b/docs/source/ecosystem/asr_nlp_tts.rst @@ -270,12 +270,12 @@ with PyTorch Lightning since every NeMo model is a Lightning Module. log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len ) wer_num, wer_denom = self._wer(predictions, transcript, transcript_len) - tensorboard_logs = { + self.log_dict({ 'train_loss': loss_value, 'training_batch_wer': wer_num / wer_denom, 'learning_rate': self._optimizer.param_groups[0]['lr'], - } - return {'loss': loss_value, 'log': tensorboard_logs} + }) + return loss_value Neural Types in NeMo ASR ------------------------ @@ -539,8 +539,8 @@ since every NeMo model is a Lightning Module. logits = self(input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask) loss = self.loss(logits=logits, labels=labels, loss_mask=loss_mask) - tensorboard_logs = {'train_loss': loss, 'lr': self._optimizer.param_groups[0]['lr']} - return {'loss': loss, 'log': tensorboard_logs} + self.log_dict({'train_loss': loss, 'lr': self._optimizer.param_groups[0]['lr']}) + return loss ... Neural Types in NeMo NLP diff --git a/docs/source/ecosystem/bolts.rst b/docs/source/ecosystem/bolts.rst index 9133176cab912e..f3a4ab9c858be1 100644 --- a/docs/source/ecosystem/bolts.rst +++ b/docs/source/ecosystem/bolts.rst @@ -68,8 +68,8 @@ you can trust the implementations and use them to bootstrap your research much f loss = self.criterion(logits.view(-1, logits.size(-1)), x.view(-1).long()) - logs = {"loss": loss} - return {"loss": loss, "log": logs} + self.log("loss", loss) + return loss ---------- diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 27815867301513..d9dea5979ae58a 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -590,7 +590,7 @@ def _validate_monitor_key(self, trainer): m = ( f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:" f" {list(metrics.keys())}. " - f"HINT: Did you call self.log('{self.monitor}', tensor) in the LightningModule?" + f"HINT: Did you call self.log('{self.monitor}', value) in the LightningModule?" 
) raise MisconfigurationException(m) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index 7759c8028d3258..4a57b14efd89b3 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -346,7 +346,6 @@ def update_logger_connector(self) -> Tuple[Dict, Dict]: # update callback_metrics logger_connector._callback_metrics.update(callback_metrics) - logger_connector._callback_metrics.pop("epoch", None) batch_pbar_metrics.pop("debug_epoch", None) return batch_pbar_metrics, batch_log_metrics diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index a41e589216d219..cc89d4a00c460b 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -78,7 +78,7 @@ def progress_bar_metrics(self, progress_bar_metrics: Dict) -> None: @property def cached_results(self) -> Union[EpochResultStore, None]: - return self._cached_results.get(self.trainer._running_stage) # type: ignore + return self._cached_results.get(self.trainer._running_stage) def get_metrics(self, key: str) -> Dict: metrics_holder: MetricsHolder = getattr(self, f"_{key}") @@ -121,8 +121,6 @@ def cache_logged_metrics(self): def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_steps: int, move_metrics_to_cpu: bool): # logging self.configure_logger(logger) - # todo: IDE is complaining, these shall be initialized in the Trainer init at leas as placeholders - # and assign here the desired value self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps self.trainer.log_every_n_steps = log_every_n_steps self.trainer.move_metrics_to_cpu = move_metrics_to_cpu @@ -185,9 +183,6 @@ def cache_training_step_metrics(self, opt_closure_result): batch_log_metrics = opt_closure_result.training_step_output.log_metrics logged_metrics_tmp.update(batch_log_metrics) - callback_metrics = opt_closure_result.training_step_output.callback_metrics - callback_metrics_tmp.update(callback_metrics) - batch_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end pbar_metrics_tmp.update(batch_pbar_metrics) @@ -210,9 +205,6 @@ def log_metrics(self, metrics, grad_norm_dic, step=None): metrics (dict): Metric values grad_norm_dic (dict): Gradient norms step (int): Step for which metrics should be logged. Default value corresponds to `self.global_step` - log_train_step_metrics (bool): Used to track if `log_metrics` function is being called in during training - steps. In training steps, we will log metrics on step: `total_nb_idx` (for accumulated gradients) - and global_step for the rest. 
""" # add gpu memory if self.trainer._device_type == DeviceType.GPU and self.log_gpu_memory: @@ -339,9 +331,9 @@ def _track_callback_metrics(self, eval_results): if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING): self.trainer.logger_connector.evaluation_callback_metrics.update(flat) - def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metrics, log_metrics, callback_metrics): + def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metrics, log_metrics): # eval loop returns all metrics - dataloader_result_metrics = {**prog_bar_metrics, **log_metrics, **callback_metrics} + dataloader_result_metrics = {**prog_bar_metrics, **log_metrics} # add metrics to prog bar self.trainer.logger_connector.add_progress_bar_metrics(prog_bar_metrics) @@ -350,13 +342,6 @@ def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metric if len(log_metrics) > 0: self.trainer.logger_connector.log_metrics(log_metrics, {}) - # track metrics for callbacks (all prog bar, logged and callback metrics) - callback_metrics.update(log_metrics) - callback_metrics.update(prog_bar_metrics) - self.trainer.logger_connector.callback_metrics.update(callback_metrics) - if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING): - self.trainer.logger_connector.evaluation_callback_metrics.update(callback_metrics) - if len(dataloader_result_metrics) > 0: self.eval_loop_results.append(dataloader_result_metrics) @@ -371,20 +356,16 @@ def __process_eval_epoch_end_results_and_log_legacy(self, eval_results): eval_results = [eval_results] num_loaders: int = self.trainer.evaluation_loop.num_dataloaders - prog_bar_metrics, log_metrics, callback_metrics = {}, {}, {} + prog_bar_metrics, log_metrics = {}, {} for result_idx, result in enumerate(eval_results): - _, prog_bar_metrics, log_metrics, callback_metrics = self.trainer.process_dict_result(result) + _, prog_bar_metrics, log_metrics = self.trainer.process_dict_result(result) if num_loaders > 1: - self.__process_eval_epoch_end_results_and_log_legacy_update( - prog_bar_metrics, log_metrics, callback_metrics - ) + self.__process_eval_epoch_end_results_and_log_legacy_update(prog_bar_metrics, log_metrics) if num_loaders == 1: - self.__process_eval_epoch_end_results_and_log_legacy_update( - prog_bar_metrics, log_metrics, callback_metrics - ) + self.__process_eval_epoch_end_results_and_log_legacy_update(prog_bar_metrics, log_metrics) def on_train_epoch_end(self): # inform cached logger connector epoch finished @@ -397,8 +378,6 @@ def log_train_epoch_end_metrics(self, epoch_output, num_optimizers): model = self.trainer.lightning_module - epoch_callback_metrics = {} - # ------------------------ # determine if using a result obj # ------------------------ @@ -426,10 +405,9 @@ def log_train_epoch_end_metrics(self, epoch_output, num_optimizers): # TODO: deprecate 1.0 else: - out = self.__run_legacy_training_epoch_end( - num_optimizers, epoch_output, model, is_result_obj, epoch_callback_metrics + epoch_log_metrics, epoch_progress_bar_metrics = self.__run_legacy_training_epoch_end( + num_optimizers, epoch_output, model, is_result_obj ) - epoch_log_metrics, epoch_progress_bar_metrics, epoch_callback_metrics = out # it will perform reduction over epoch and return log metrics cached_epoch_log_metrics = self.cached_results.get_epoch_log_metrics() @@ -447,9 +425,6 @@ def log_train_epoch_end_metrics(self, epoch_output, num_optimizers): self.log_metrics(epoch_log_metrics, {}) 
self._callback_metrics.update(epoch_log_metrics) - # add metrics to callbacks - self._callback_metrics.update(epoch_callback_metrics) - # add metrics to progress_bar and callbacks if len(epoch_progress_bar_metrics) > 0: self.add_progress_bar_metrics(epoch_progress_bar_metrics) @@ -481,9 +456,7 @@ def training_epoch_end(self, model, epoch_output, num_optimizers): # capture logging self.trainer.logger_connector.cache_logged_metrics() - def __run_legacy_training_epoch_end( - self, num_optimizers, epoch_output, model, is_result_obj, epoch_callback_metrics - ): + def __run_legacy_training_epoch_end(self, num_optimizers, epoch_output, model, is_result_obj): epoch_log_metrics = {} epoch_progress_bar_metrics = {} @@ -514,7 +487,6 @@ def __run_legacy_training_epoch_end( _processed_outputs = self.trainer.process_dict_result(epoch_output) epoch_progress_bar_metrics = _processed_outputs[1] epoch_log_metrics = _processed_outputs[2] - epoch_callback_metrics = _processed_outputs[3] # -------------------------- # Structured Result (auto epoch end) @@ -522,7 +494,7 @@ def __run_legacy_training_epoch_end( elif is_result_obj: epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output) - return epoch_log_metrics, epoch_progress_bar_metrics, epoch_callback_metrics + return epoch_log_metrics, epoch_progress_bar_metrics def __auto_reduce_results_on_epoch_end(self, epoch_output): epoch_log_metrics = {} diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index 4e022695bc807c..f897c37828541e 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -20,6 +20,7 @@ from pytorch_lightning.utilities import DistributedType from pytorch_lightning.utilities.distributed import rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach @@ -32,8 +33,14 @@ class TrainerLoggingMixin(ABC): def metrics_to_scalars(self, metrics): new_metrics = {} + # TODO: this is duplicated in MetricsHolder. should be unified for k, v in metrics.items(): if isinstance(v, torch.Tensor): + if v.numel() != 1: + raise MisconfigurationException( + f"The metric `{k}` does not contain a single element" + f" thus it cannot be converted to float. 
Found `{v}`" + ) v = v.item() if isinstance(v, dict): @@ -71,22 +78,7 @@ def process_dict_result(self, output, train=False): if isinstance(output, torch.Tensor): progress_bar_metrics = {} log_metrics = {} - callback_metrics = {} - return output, progress_bar_metrics, log_metrics, callback_metrics - - # --------------- - # EXTRACT CALLBACK KEYS - # --------------- - # all keys not progress_bar or log are candidates for callbacks - callback_metrics = {} - if isinstance(output, Mapping): - for k, v in output.items(): - if k not in ['progress_bar', 'log']: - callback_metrics[k] = v - - if train and self._distrib_type in (DistributedType.DP, DistributedType.DDP2): - num_gpus = self.num_gpus - callback_metrics = self.reduce_distributed_output(callback_metrics, num_gpus) + return output, progress_bar_metrics, log_metrics # --------------- # EXTRACT PROGRESS BAR KEYS @@ -143,17 +135,12 @@ def process_dict_result(self, output, train=False): if self._distrib_type in (DistributedType.DP, DistributedType.DDP2): loss = self.reduce_distributed_output(loss, self.num_gpus) - # use every metric passed in as a candidate for callback - callback_metrics.update(progress_bar_metrics) - callback_metrics.update(log_metrics) - # detach all metrics for callbacks to prevent memory leaks # no .item() because it will slow things down - callback_metrics = recursive_detach(callback_metrics) progress_bar_metrics = recursive_detach(progress_bar_metrics) log_metrics = recursive_detach(log_metrics) - return loss, progress_bar_metrics, log_metrics, callback_metrics + return loss, progress_bar_metrics, log_metrics def reduce_distributed_output(self, output, num_gpus): if num_gpus <= 1: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6de43c2b945d60..f0f1d3e6b11e14 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -823,15 +823,6 @@ def run_sanity_check(self, ref_model): # run eval step _, eval_results = self.run_evaluation() - # allow no returns from eval - if eval_results is not None and len(eval_results) > 0: - # when we get a list back, used only the last item - if isinstance(eval_results, list): - eval_results = eval_results[-1] - - _, _, _, callback_metrics = self.process_dict_result(eval_results) - self.logger_connector.callback_metrics = callback_metrics - self.on_sanity_check_end() self._running_stage = stage diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 1e655fe6e8fadc..15958e6bdab1c6 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -341,7 +341,6 @@ def _process_training_step_output(self, training_step_output, split_batch): batch_loss=training_step_output[0], pbar_on_batch_end=training_step_output[1], log_metrics=training_step_output[2], - callback_metrics=training_step_output[3], ) # if the user decides to finally reduce things in epoch_end, save raw output without graphs if isinstance(training_step_output_for_epoch_end, torch.Tensor): diff --git a/tests/base/model_valid_epoch_ends.py b/tests/base/model_valid_epoch_ends.py index dd29d355a4a98c..7b83670acacef3 100644 --- a/tests/base/model_valid_epoch_ends.py +++ b/tests/base/model_valid_epoch_ends.py @@ -43,9 +43,8 @@ def _mean(res, key): val_loss_mean = val_loss_mean.item() val_acc_mean = val_acc_mean.item() - metrics_dict = {'early_stop_on': val_loss_mean, 'val_acc': val_acc_mean} - results = {'progress_bar': metrics_dict, 'log': metrics_dict} - return 
results + self.log('early_stop_on', val_loss_mean, prog_bar=True) + self.log('val_acc', val_acc_mean, prog_bar=True) def validation_epoch_end__multiple_dataloaders(self, outputs): """ diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 2a15852fc6ee5a..cc619077ee1368 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -128,7 +128,7 @@ class ModelOverrideValidationReturn(BoringModel): def validation_epoch_end(self, outputs): loss = self.validation_return_values[self.current_epoch] - return {"test_val_loss": loss} + self.log("test_val_loss", loss) model = ModelOverrideValidationReturn() early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True) @@ -220,7 +220,7 @@ class CurrentModel(BoringModel): def validation_epoch_end(self, outputs): losses = [8, 4, 2, 3, 4, 5, 8, 10] val_loss = losses[self.current_epoch] - self.log('abc', torch.tensor(val_loss)) + self.log('abc', val_loss) model = CurrentModel() @@ -234,28 +234,6 @@ def validation_epoch_end(self, outputs): assert trainer.current_epoch == 5, 'early_stopping failed' -def test_early_stopping_functionality_arbitrary_key(tmpdir): - """Tests whether early stopping works with a custom key and dictionary results on val step.""" - - class CurrentModel(BoringModel): - - def validation_epoch_end(self, outputs): - losses = [8, 4, 2, 3, 4, 5, 8, 10] - val_loss = losses[self.current_epoch] - return {'jiraffe': torch.tensor(val_loss)} - - model = CurrentModel() - - trainer = Trainer( - default_root_dir=tmpdir, - callbacks=[EarlyStopping(monitor='jiraffe')], - overfit_batches=0.20, - max_epochs=20, - ) - trainer.fit(model) - assert trainer.current_epoch >= 5, 'early_stopping failed' - - @pytest.mark.parametrize('step_freeze, min_steps, min_epochs', [(5, 1, 1), (5, 1, 3), (3, 15, 1)]) def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze: int, min_steps: int, min_epochs: int): """Excepted Behaviour: @@ -272,7 +250,7 @@ def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze: in when `early_stopping` is being triggered, THEN the highest between `min_epochs * len(train_dataloader)` and `min_steps` would be reached. 
- Caviat: IF min_steps is divisible by len(train_dataloader), then it will do min_steps + len(train_dataloader) + Caveat: IF min_steps is divisible by len(train_dataloader), then it will do min_steps + len(train_dataloader) This test validate those expected behaviours """ @@ -309,7 +287,7 @@ def validation_epoch_end(self, outputs): self._count_decrease += 1 self._loss_value -= self._eps self._values.append(_mean) - return {"test_val_loss": _mean} + self.log('test_val_loss', _mean) model = Model(step_freeze) model.training_step_end = None diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 75f25b90fa45f0..5e50543d37e5fd 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -51,7 +51,6 @@ def training_step(self, batch, batch_idx): def validation_epoch_end(self, outputs): outs = torch.stack([x['x'] for x in outputs]).mean() - self.log('epoch', self.current_epoch) self.log('val_acc', outs) @@ -721,12 +720,7 @@ def test_model_checkpoint_topk_all(tmpdir): seed_everything(1000) epochs = 3 - class CustomModel(LogInTwoMethods): - - def validation_epoch_end(self, outputs): - return {'epoch': self.current_epoch} - - model = CustomModel() + model = BoringModel() checkpoint_callback = ModelCheckpoint( dirpath=tmpdir, filename="{epoch}", @@ -894,7 +888,7 @@ class ExtendedBoringModel(BoringModel): def validation_step(self, batch, batch_idx): output = self.layer(batch) loss = self.loss(batch, output) - return {"val_loss": loss} + self.log("val_loss", loss) model = ExtendedBoringModel() model.validation_epoch_end = None diff --git a/tests/trainer/legacy_deprecate_flow_log/test_eval_loop_dict_return.py b/tests/trainer/legacy_deprecate_flow_log/test_eval_loop_dict_return.py index 2aac7354c38f6d..0894acd5fe72d0 100644 --- a/tests/trainer/legacy_deprecate_flow_log/test_eval_loop_dict_return.py +++ b/tests/trainer/legacy_deprecate_flow_log/test_eval_loop_dict_return.py @@ -125,49 +125,6 @@ def test_validation_step_arbitrary_dict_return(tmpdir): assert not model.validation_epoch_end_called -def test_validation_step_dict_return(tmpdir): - """ - Test that val step can return a dict with all the expected keys and they end up - in the correct place - """ - - model = DeterministicModel() - model.training_step = model.training_step__dict_return - model.validation_step = model.validation_step__dict_return - model.validation_step_end = None - model.validation_epoch_end = None - - trainer = Trainer( - default_root_dir=tmpdir, - weights_summary=None, - limit_train_batches=2, - limit_val_batches=2, - max_epochs=2, - ) - trainer.fit(model) - - # out are the results of the full loop - # eval_results are output of _evaluate - callback_metrics, eval_results = trainer.run_evaluation() - assert len(callback_metrics) == 1 - assert len(callback_metrics[0]) == 5 - assert len(eval_results) == 2 - assert eval_results[0]['log']['log_acc1'] == 12 - assert eval_results[1]['log']['log_acc1'] == 13 - - for k in ['val_loss', 'log', 'progress_bar']: - assert k in eval_results[0] - assert k in eval_results[1] - - # ensure all the keys ended up as candidates for callbacks - assert len(trainer.logger_connector.callback_metrics) in [7, 8] - - # make sure correct steps were called - assert model.validation_step_called - assert not model.validation_step_end_called - assert not model.validation_epoch_end_called - - def test_val_step_step_end_no_return(tmpdir): """ Test that val step + val step end work (with no return in val step 
end) @@ -198,136 +155,3 @@ def test_val_step_step_end_no_return(tmpdir): assert model.validation_step_called assert model.validation_step_end_called assert not model.validation_epoch_end_called - - -def test_val_step_step_end(tmpdir): - """ - Test that val step + val step end work - """ - - model = DeterministicModel() - model.training_step = model.training_step__dict_return - model.validation_step = model.validation_step__dict_return - model.validation_step_end = model.validation_step_end - model.validation_epoch_end = None - - trainer = Trainer( - default_root_dir=tmpdir, - weights_summary=None, - limit_train_batches=2, - limit_val_batches=2, - max_epochs=2, - ) - trainer.fit(model) - - # out are the results of the full loop - # eval_results are output of _evaluate - callback_metrics, eval_results = trainer.run_evaluation() - assert len(callback_metrics) == 1 - assert len(callback_metrics[0]) == 6 - - callback_metrics = callback_metrics[0] - assert callback_metrics['val_step_end'] == 1802 - assert len(eval_results) == 2 - assert eval_results[0]['log']['log_acc1'] == 12 - assert eval_results[1]['log']['log_acc1'] == 13 - - for k in ['val_loss', 'log', 'progress_bar']: - assert k in eval_results[0] - assert k in eval_results[1] - - # ensure all the keys ended up as candidates for callbacks - assert len(trainer.logger_connector.callback_metrics) in [8, 9] - - # make sure correct steps were called - assert model.validation_step_called - assert model.validation_step_end_called - assert not model.validation_epoch_end_called - - -def test_no_val_step_end(tmpdir): - """ - Test that val step + val epoch end - """ - - model = DeterministicModel() - model.training_step = model.training_step__dict_return - model.validation_step = model.validation_step__dict_return - model.validation_step_end = None - model.validation_epoch_end = model.validation_epoch_end - - trainer = Trainer( - default_root_dir=tmpdir, - weights_summary=None, - limit_train_batches=2, - limit_val_batches=3, - num_sanity_val_steps=0, - max_epochs=2 - ) - trainer.fit(model) - - # out are the results of the full loop - # eval_results are output of _evaluate - callback_metrics, eval_results = trainer.run_evaluation() - assert len(callback_metrics) == 1 - assert len(callback_metrics[0]) == 6 - assert len(eval_results) == 1 - - eval_results = eval_results[0] - assert 'val_step_end' not in eval_results - assert eval_results['val_epoch_end'] == 1233 - - for k in ['val_loss', 'log', 'progress_bar']: - assert k in eval_results - - # ensure all the keys ended up as candidates for callbacks - assert len(trainer.logger_connector.callback_metrics) in [8, 9] - - # make sure correct steps were called - assert model.validation_step_called - assert not model.validation_step_end_called - assert model.validation_epoch_end_called - - -def test_full_val_loop(tmpdir): - """ - Test that val step + val step end + val epoch end - """ - - model = DeterministicModel() - model.training_step = model.training_step__dict_return - model.validation_step = model.validation_step__dict_return - model.validation_step_end = model.validation_step_end - model.validation_epoch_end = model.validation_epoch_end - - trainer = Trainer( - default_root_dir=tmpdir, - weights_summary=None, - limit_train_batches=2, - limit_val_batches=3, - num_sanity_val_steps=0, - max_epochs=2 - ) - trainer.fit(model) - - # out are the results of the full loop - # eval_results are output of _evaluate - callback_metrics, eval_results = trainer.run_evaluation() - assert len(callback_metrics) == 1 - 
assert len(callback_metrics[0]) == 7 - assert len(eval_results) == 1 - - eval_results = eval_results[0] - assert eval_results['val_step_end'] == 1802 - assert eval_results['val_epoch_end'] == 1233 - - for k in ['val_loss', 'log', 'progress_bar']: - assert k in eval_results - - # ensure all the keys ended up as candidates for callbacks - assert len(trainer.logger_connector.callback_metrics) in [9, 10] - - # make sure correct steps were called - assert model.validation_step_called - assert model.validation_step_end_called - assert model.validation_epoch_end_called diff --git a/tests/trainer/legacy_deprecate_flow_log/test_trainer_steps_dict_return.py b/tests/trainer/legacy_deprecate_flow_log/test_trainer_steps_dict_return.py index 9c114f72080d87..3f60e6060d2ae1 100644 --- a/tests/trainer/legacy_deprecate_flow_log/test_trainer_steps_dict_return.py +++ b/tests/trainer/legacy_deprecate_flow_log/test_trainer_steps_dict_return.py @@ -171,28 +171,6 @@ def test_result_obj_lr_scheduler_epoch(tmpdir): assert len(trainer.dev_debugger.saved_lr_scheduler_updates) == 3 -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) -def test_result_obj_lr_scheduler_step(tmpdir): - """ - test that the LR scheduler was called at the correct time with the correct metrics - """ - model = DeterministicModel() - model.training_step = model.training_step__for_step_end_dict - model.training_step_end = model.training_step_end__dict - model.training_epoch_end = model.training_epoch_end__dict - model.val_dataloader = None - model.configure_optimizers = model.configure_optimizers__lr_on_plateau_step - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=2, - weights_summary=None, - ) - trainer.fit(model) - - assert len(trainer.dev_debugger.saved_lr_scheduler_updates) == 8 - - def test_train_step_epoch_end(tmpdir): """ Checks train_step + training_epoch_end (NO training_step_end) diff --git a/tests/trainer/logging_/test_eval_loop_logging_1_0.py b/tests/trainer/logging_/test_eval_loop_logging_1_0.py index 674e2aeb6511b4..32bff96baf5667 100644 --- a/tests/trainer/logging_/test_eval_loop_logging_1_0.py +++ b/tests/trainer/logging_/test_eval_loop_logging_1_0.py @@ -372,11 +372,10 @@ def test_multi_dataloaders_add_suffix_properly(tmpdir): class TestModel(BoringModel): - def test_step(self, batch, batch_idx, dataloader_idx): + def test_step(self, batch, *args): output = self.layer(batch) loss = self.loss(batch, output) self.log("test_loss", loss, on_step=True, on_epoch=True) - return {"y": loss} def test_dataloader(self): return [ @@ -397,22 +396,19 @@ def test_dataloader(self): weights_summary=None, ) results = trainer.test(model) - assert "test_loss_epoch/dataloader_idx_0" in results[0] - assert "test_loss_epoch/dataloader_idx_1" in results[1] + + assert {"test_loss/dataloader_idx_0", "test_loss_epoch/dataloader_idx_0"} == set(results[0]) + assert {"test_loss/dataloader_idx_1", "test_loss_epoch/dataloader_idx_1"} == set(results[1]) def test_single_dataloader_no_suffix_added(tmpdir): class TestModel(BoringModel): - def test_step(self, batch, batch_idx): + def test_step(self, batch, *args): output = self.layer(batch) loss = self.loss(batch, output) self.log("test_loss", loss, on_step=True, on_epoch=True) - return {"y": loss} - - def test_dataloader(self): - return torch.utils.data.DataLoader(RandomDataset(32, 64)) model = TestModel() model.test_epoch_end = None @@ -427,9 +423,9 @@ def test_dataloader(self): weights_summary=None, ) results = trainer.test(model) + assert len(results) == 1 - # error : It is wrong there. 
`y` should equal test_loss_epoch - assert results[0]['test_loss'] == results[0]['y'] + assert {"test_loss", "test_loss_epoch"} == set(results[0]) @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @@ -849,7 +845,7 @@ def validation_step(self, batch, batch_idx): self.log('valid_loss_1', loss, on_step=False, on_epoch=True) self.log('valid_loss_2', loss, on_step=True, on_epoch=False) self.log('valid_loss_3', loss, on_step=False, on_epoch=False) - return {"val_loss": loss} + return {"val_loss": loss} # not added to callback_metrics def test_step(self, batch, batch_idx): output = self.layer(batch) @@ -926,7 +922,6 @@ def get_metrics_at_idx(idx): 'debug_epoch', 'valid_loss_1', 'test_loss', - 'val_loss', } assert set(trainer.callback_metrics) == expected_callback_metrics assert set(results[0]) == {'test_loss', 'debug_epoch'} diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index d14ed719403284..819248ffd6dcf8 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -461,7 +461,7 @@ class TestModel(BoringModel): def validation_step(self, batch, *args, **kwargs): output = self(batch) - return {"test": output} + self.log('test', output) def test_step(self, *args, **kwargs): return self.validation_step(*args, **kwargs) @@ -479,6 +479,27 @@ def test_step(self, *args, **kwargs): trainer.test(model) +def test_can_return_tensor_with_more_than_one_element(tmpdir): + """Ensure {validation,test}_step return values are not included as callback metrics. #6623""" + + class TestModel(BoringModel): + + def validation_step(self, batch, *args, **kwargs): + return {"val": torch.tensor([0, 1])} + + def test_step(self, batch, *args, **kwargs): + return {"test": torch.tensor([0, 1])} + + model = TestModel() + model.validation_epoch_end = None + model.test_epoch_end = None + + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, progress_bar_refresh_rate=0) + trainer.fit(model) + trainer.validate(model) + trainer.test(model) + + def test_logging_to_progress_bar_with_reserved_key(tmpdir): """ Test that logging a metric with a reserved name to the progress bar raises a warning. """
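
Note (illustrative only, not part of the patch above): the sketch below shows the logging pattern this change standardizes on — step methods call self.log / self.log_dict for anything that should reach the loggers, the progress bar, or trainer.callback_metrics, and return only the loss (or nothing), instead of the removed 'log' / 'progress_bar' dictionary returns. The model, dataset, and metric names are placeholders and not code from the repository; it assumes the post-1.0 self.log API.

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl


    class DemoModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.cross_entropy(self(x), y)
            # log explicitly; the step returns only the loss
            self.log("train_loss", loss)
            return loss

        def validation_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.cross_entropy(self(x), y)
            # logged values are what callbacks such as EarlyStopping /
            # ModelCheckpoint can monitor; raw return values are not tracked
            self.log("val_loss", loss, prog_bar=True)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    if __name__ == "__main__":
        dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        train_loader = DataLoader(dataset, batch_size=16)
        val_loader = DataLoader(dataset, batch_size=16)
        trainer = pl.Trainer(fast_dev_run=True)
        trainer.fit(DemoModel(), train_loader, val_loader)

After a run, trainer.callback_metrics should contain only the logged keys (here 'train_loss' and 'val_loss'); with this patch applied, values returned from the step methods are no longer folded into callback_metrics.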