From 84c507c4df5f5c336deb19ce7f70fa02329f39f6 Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Tue, 28 Jul 2020 03:26:55 +0530
Subject: [PATCH] Fix max_batches with fast_dev_run. (#2581)

* Fix fast_dev_run to run for all val_dataloaders
* fast_dev_run check
* changelog
* explicit
* limit_batches with fast_dev_run in init
* add test
* whitespace and comment fix
* comment and assertion
* added tests
* update rtol
* Revert "update rtol"

This reverts commit 4320329540798c112cf45dcbd6f677993e4c6ad6.

* added tests

Co-authored-by: William Falcon
---
 CHANGELOG.md                                 |  2 +
 pytorch_lightning/callbacks/progress.py      | 19 ++----
 pytorch_lightning/trainer/__init__.py        |  2 +-
 pytorch_lightning/trainer/evaluation_loop.py | 10 ++-
 pytorch_lightning/trainer/trainer.py         | 11 +++-
 pytorch_lightning/trainer/training_loop.py   |  8 +--
 pytorch_lightning/utilities/debugging.py     | 39 ++++++++++++
 tests/callbacks/test_progress_bar.py         |  2 +
 tests/models/test_grad_norm.py               |  2 +-
 tests/trainer/test_dataloaders.py            | 64 +++++++++++++++++++-
 10 files changed, 130 insertions(+), 29 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5a8df1196f7ab..af692fdd0f938 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -32,6 +32,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `dtype` and `device` properties not getting updated in submodules ([#2657](https://github.com/PyTorchLightning/pytorch-lightning/pull/2657))
 
+- Fixed `fast_dev_run` to run for all dataloaders ([#2581](https://github.com/PyTorchLightning/pytorch-lightning/pull/2581))
+
 - Fixed `save_dir` in loggers getting ignored by default value of `weights_save_path` when user did not specify `weights_save_path` ([#2681](https://github.com/PyTorchLightning/pytorch-lightning/pull/2681))
 
 - Fixed `weights_save_path` getting ignored when `logger=False` is passed to Trainer ([#2681](https://github.com/PyTorchLightning/pytorch-lightning/pull/2681))
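Note: with multiple validation or test dataloaders, `fast_dev_run` previously capped evaluation at a single batch overall (the old `max_batches = [1]` removed below), while the progress bar counted one batch per dataloader. After this patch both agree: one batch of every dataloader runs. A minimal repro sketch using the `EvalModelTemplate` helpers that the tests in this patch use (import path assumed from the repo's test layout):

    from pytorch_lightning import Trainer
    from tests.base import EvalModelTemplate

    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__multiple_mixed_length
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders

    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model)
    # one batch per validation dataloader, not one batch total
    assert trainer.num_val_batches == [1] * len(trainer.val_dataloaders)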
""" - trainer = self.trainer total_val_batches = 0 - if trainer.fast_dev_run and trainer.val_dataloaders is not None: - total_val_batches = len(trainer.val_dataloaders) - elif self.trainer.enable_validation: - is_val_epoch = (trainer.current_epoch + 1) % trainer.check_val_every_n_epoch == 0 - total_val_batches = sum(trainer.num_val_batches) if is_val_epoch else 0 + if not self.trainer.disable_validation: + is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 + total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0 return total_val_batches @property @@ -114,12 +110,7 @@ def total_test_batches(self) -> int: Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the test dataloader is of infinite size. """ - if self.trainer.fast_dev_run: - total_test_batches = len(self.trainer.test_dataloaders) - else: - total_test_batches = self.trainer.num_test_batches - total_test_batches = sum(total_test_batches) - return total_test_batches + return sum(self.trainer.num_test_batches) def disable(self): """ diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index 7e188ab97492c..15fe2f23d3922 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -393,7 +393,7 @@ def on_train_end(self, trainer, pl_module): # default used by the Trainer trainer = Trainer(fast_dev_run=False) - # runs 1 train, val, test batch and program ends + # runs 1 train, val, test batch and program ends trainer = Trainer(fast_dev_run=True) gpus diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 97f118b4da157..f52f9c12f3c83 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -306,7 +306,7 @@ def _evaluate( if batch is None: continue - # stop short when on fast_dev_run (sets max_batch=1) + # stop short when running on limited batches if batch_idx >= dl_max_batches: break @@ -350,6 +350,9 @@ def _evaluate( self.__eval_add_step_metrics(output) + # track debug metrics + self.dev_debugger.track_eval_loss_history(test_mode, batch_idx, dataloader_idx, output) + outputs.append(dl_outputs) # --------------------- @@ -513,14 +516,9 @@ def run_evaluation(self, test_mode: bool = False): dataloaders = self.val_dataloaders max_batches = self.num_val_batches - # enable fast_dev_run without val loop if dataloaders is None: return [], [] - # cap max batches to 1 when using fast_dev_run - if self.fast_dev_run: - max_batches = [1] - # Validation/Test begin callbacks if test_mode: self.on_test_start() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 876380ee22cab..b774ce13e1ef0 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -278,7 +278,7 @@ def __init__( check_val_every_n_epoch: Check val every n train epochs. - fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test). + fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test). accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict. 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 876380ee22cab..b774ce13e1ef0 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -278,7 +278,7 @@ def __init__(
 
         check_val_every_n_epoch: Check val every n train epochs.
 
-        fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).
+        fast_dev_run: runs 1 batch of train, test and val to find any bugs (i.e. a sort of unit test).
 
         accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.
 
@@ -505,7 +505,11 @@ def __init__(
         self.max_steps = max_steps
         self.min_steps = min_steps
 
-        self.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps
+        if num_sanity_val_steps == -1:
+            self.num_sanity_val_steps = float("inf")
+        else:
+            self.num_sanity_val_steps = min(num_sanity_val_steps, limit_val_batches)
+
         # Backward compatibility, TODO: remove in v0.9.0
         if print_nan_grads:
             rank_zero_warn(
@@ -528,6 +532,9 @@ def __init__(
 
         self.fast_dev_run = fast_dev_run
         if self.fast_dev_run:
+            limit_train_batches = 1
+            limit_val_batches = 1
+            limit_test_batches = 1
             self.num_sanity_val_steps = 0
             self.max_epochs = 1
             rank_zero_info(
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index d7a5332272c04..7f12d20151c24 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -401,7 +401,7 @@ def train(self):
                 met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
 
                 if self.should_stop:
-                    if (met_min_epochs and met_min_steps) or self.fast_dev_run:
+                    if met_min_epochs and met_min_steps:
                         self.run_training_teardown()
                         return
                 else:
@@ -507,7 +507,7 @@ def run_training_epoch(self):
             # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
             # -----------------------------------------
             should_check_val = self.should_check_val(batch_idx, is_last_batch)
-            if self.fast_dev_run or should_check_val:
+            if should_check_val:
                 self.run_evaluation(test_mode=False)
 
             # -----------------------------------------
@@ -530,7 +530,7 @@ def run_training_epoch(self):
             # end epoch early
             # stop when the flag is changed or we've gone past the amount
             # requested in the batches
-            if self.fast_dev_run or self.should_stop:
+            if self.should_stop:
                 break
 
         # let ddp devices catch up when using horovod
@@ -548,7 +548,7 @@ def run_training_epoch(self):
     def check_checkpoint_callback(self, should_check_val):
         # when no val loop is present or fast-dev-run still need to call checkpoints
         # TODO bake this logic into the checkpoint callback
-        should_activate = not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val)
+        should_activate = not self.is_overridden('validation_step') and not should_check_val
         if should_activate:
             checkpoint_callbacks = [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
             [c.on_validation_end(self, self.get_model()) for c in checkpoint_callbacks]
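Note: one subtlety in the `@@ -505` hunk above is that `num_sanity_val_steps` is now clamped by `limit_val_batches`, so sanity checking can never request more validation batches than the limit allows. A small sketch of that clamp (integer limits assumed; `-1` still maps to `float("inf")`):

    from pytorch_lightning import Trainer

    trainer = Trainer(num_sanity_val_steps=5, limit_val_batches=2)
    assert trainer.num_sanity_val_steps == 2  # capped by limit_val_batches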
diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py
index 490356938fb6d..892450230da5a 100644
--- a/pytorch_lightning/utilities/debugging.py
+++ b/pytorch_lightning/utilities/debugging.py
@@ -1,4 +1,5 @@
 import os
+from collections import Counter
 
 
 class InternalDebugger(object):
@@ -10,6 +11,8 @@ def __init__(self, trainer):
         self.logged_metrics = []
         self.pbar_added_metrics = []
         self.saved_losses = []
+        self.saved_val_losses = []
+        self.saved_test_losses = []
         self.early_stopping_history = []
         self.checkpoint_callback_history = []
 
@@ -23,6 +26,21 @@ def track_train_loss_history(self, batch_idx, loss):
             loss_dict = {'batch_idx': batch_idx, 'epoch': self.trainer.current_epoch, 'loss': loss.detach()}
             self.saved_losses.append(loss_dict)
 
+    def track_eval_loss_history(self, test_mode, batch_idx, dataloader_idx, output):
+        if self.enabled:
+            loss_dict = {
+                'sanity_check': self.trainer.running_sanity_check,
+                'dataloader_idx': dataloader_idx,
+                'batch_idx': batch_idx,
+                'epoch': self.trainer.current_epoch,
+                'output': output
+            }
+
+            if test_mode:
+                self.saved_test_losses.append(loss_dict)
+            else:
+                self.saved_val_losses.append(loss_dict)
+
     def track_pbar_metrics_history(self, metrics):
         if self.enabled:
             metrics['debug_epoch'] = self.trainer.current_epoch
@@ -52,3 +70,24 @@ def track_checkpointing_history(self, filepath):
                 'filepath': filepath
             }
             self.checkpoint_callback_history.append(debug_dict)
+
+    @property
+    def num_seen_sanity_check_batches(self):
+        count = len([x for x in self.saved_val_losses if x['sanity_check']])
+        return count
+
+    @property
+    def num_seen_val_check_batches(self):
+        counts = Counter()
+        for x in self.saved_val_losses:
+            if not x['sanity_check']:
+                counts.update({x['dataloader_idx']: 1})
+        return counts
+
+    @property
+    def num_seen_test_check_batches(self):
+        counts = Counter()
+        for x in self.saved_test_losses:
+            if not x['sanity_check']:
+                counts.update({x['dataloader_idx']: 1})
+        return counts
diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py
index f621e70228012..23743dc5dcb2c 100644
--- a/tests/callbacks/test_progress_bar.py
+++ b/tests/callbacks/test_progress_bar.py
@@ -116,6 +116,8 @@ def test_progress_bar_fast_dev_run(tmpdir):
         fast_dev_run=True,
     )
 
+    trainer.fit(model)
+
     progress_bar = trainer.progress_bar_callback
     assert 1 == progress_bar.total_train_batches
     # total val batches are known only after val dataloaders have reloaded
diff --git a/tests/models/test_grad_norm.py b/tests/models/test_grad_norm.py
index d7978965a3cfe..7483167755a8f 100644
--- a/tests/models/test_grad_norm.py
+++ b/tests/models/test_grad_norm.py
@@ -43,7 +43,7 @@ def on_after_backward(self):
 
 
 @pytest.mark.parametrize("norm_type", [1., 1.25, 1.5, 2, 3, 5, 10, 'inf'])
-def test_grad_tracking(tmpdir, norm_type, rtol=5e-3):
+def test_grad_tracking(tmpdir, norm_type, rtol=1e-2):
     os.environ['PL_DEV_DEBUG'] = '1'
 
     # rtol=5e-3 respects the 3 decimals rounding in `.grad_norms` and above
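Note: the `InternalDebugger` counters added above are what the dataloader tests below assert against. A minimal usage sketch (the debugger appears to be gated by the `PL_DEV_DEBUG` environment variable, which the tests export before building the `Trainer`; `model` stands for any LightningModule, e.g. the `EvalModelTemplate` used below):

    import os
    os.environ['PL_DEV_DEBUG'] = '1'  # enable InternalDebugger tracking

    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model)
    trainer.test(ckpt_path=None)

    # Counter mapping dataloader_idx -> eval batches seen outside the sanity check
    for dataloader_idx, num_batches in trainer.dev_debugger.num_seen_val_check_batches.items():
        assert num_batches == 1  # fast_dev_run caps every loop at one batch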
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index e76ef0e556352..333c64550fbba 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -1,3 +1,4 @@
+import os
 import platform
 from unittest.mock import patch
 
@@ -306,6 +307,8 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, lim
 )
 def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
     """Verify num_batches for val & test dataloaders passed with batch limit as number"""
+    os.environ['PL_DEV_DEBUG'] = '1'
+
     model = EvalModelTemplate()
     model.val_dataloader = model.val_dataloader__multiple_mixed_length
     model.test_dataloader = model.test_dataloader__multiple_mixed_length
@@ -323,16 +326,75 @@ def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_v
         limit_test_batches=limit_test_batches,
     )
     trainer.fit(model)
+
+    # -------------------------------------------
+    # MAKE SURE THE TRAINER SET THE CORRECT VALUES
+    # -------------------------------------------
     assert trainer.num_training_batches == limit_train_batches
     assert trainer.num_val_batches == [limit_val_batches] * len(trainer.val_dataloaders)
     trainer.test(ckpt_path=None)
 
     # when the limit is greater than the number of test batches it should be the num in loaders
+    test_dataloader_lengths = [len(x) for x in model.test_dataloader()]
     if limit_test_batches > 1e10:
-        assert trainer.num_test_batches == [len(x) for x in model.test_dataloader()]
+        assert trainer.num_test_batches == test_dataloader_lengths
     else:
         assert trainer.num_test_batches == [limit_test_batches] * len(trainer.test_dataloaders)
+
+    # -------------------------------------------
+    # make sure we actually saw the expected num of batches
+    # -------------------------------------------
+    num_val_dataloaders = len(model.val_dataloader())
+    num_test_dataloaders = len(model.test_dataloader())
+    if limit_train_batches > 0:
+
+        # make sure val batches are as expected
+        assert len(trainer.dev_debugger.num_seen_val_check_batches) == num_val_dataloaders
+        for dataloader_idx, num_batches in trainer.dev_debugger.num_seen_val_check_batches.items():
+            assert num_batches == limit_val_batches
+
+        # make sure test batches are as expected
+        assert len(trainer.dev_debugger.num_seen_test_check_batches) == num_test_dataloaders
+        for dataloader_idx, num_batches in trainer.dev_debugger.num_seen_test_check_batches.items():
+            if limit_test_batches > 1e10:
+                assert num_batches == test_dataloader_lengths[dataloader_idx]
+            else:
+                assert num_batches == limit_test_batches
+
+
+def test_dataloaders_with_fast_dev_run(tmpdir):
+    """Verify num_batches for train, val & test dataloaders passed with fast_dev_run = True"""
+    os.environ['PL_DEV_DEBUG'] = '1'
+
+    model = EvalModelTemplate()
+    model.val_dataloader = model.val_dataloader__multiple_mixed_length
+    model.test_dataloader = model.test_dataloader__multiple_mixed_length
+    model.validation_step = model.validation_step__multiple_dataloaders
+    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
+    model.test_step = model.test_step__multiple_dataloaders
+    model.test_epoch_end = model.test_epoch_end__multiple_dataloaders
+
+    # train, multiple val and multiple test dataloaders passed with fast_dev_run = True
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        fast_dev_run=True,
+    )
+    assert trainer.max_epochs == 1
+    assert trainer.num_sanity_val_steps == 0
+
+    trainer.fit(model)
+    assert not trainer.disable_validation
+    assert trainer.num_training_batches == 1
+    assert trainer.num_val_batches == [1] * len(trainer.val_dataloaders)
+
+    trainer.test(ckpt_path=None)
+    assert trainer.num_test_batches == [1] * len(trainer.test_dataloaders)
+
+    # verify sanity check batches match as expected
+    num_val_dataloaders = len(model.val_dataloader())
+    assert trainer.dev_debugger.num_seen_sanity_check_batches == trainer.num_sanity_val_steps * num_val_dataloaders
+
 
 @pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
 def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
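Note: taken together, `test_dataloaders_with_fast_dev_run` covers the whole path: `fast_dev_run=True` forces `max_epochs=1`, skips the sanity check, limits every loop to one batch per dataloader, and the debugger counters confirm those batches actually ran. The final assertion also implies a simple invariant for runs that keep sanity checking, sketched here (assumes `PL_DEV_DEBUG` is set and a model with two validation dataloaders, as in the test above):

    trainer = Trainer(num_sanity_val_steps=2)
    trainer.fit(model)
    # num_sanity_val_steps batches per validation dataloader are tracked
    assert trainer.dev_debugger.num_seen_sanity_check_batches == 2 * 2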