From 8dcba38e0e13c288c9af8746d8edfe3097714926 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Thu, 23 Sep 2021 14:54:41 +0200
Subject: [PATCH] Add `is_last_batch` to progress tracking (#9657)

---
 CHANGELOG.md                                  |  1 +
 .../loops/epoch/training_epoch_loop.py        | 23 ++++++------
 pytorch_lightning/trainer/progress.py         | 36 +++++++++++++++----
 pytorch_lightning/trainer/trainer.py          |  2 +-
 tests/loops/test_loop_state_dict.py           |  1 +
 tests/loops/test_loops.py                     | 11 +++++-
 6 files changed, 53 insertions(+), 21 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2701344a4f3e6..45c1328193ab2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Progress tracking
   * Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598))
+  * Add `BatchProgress` and integrate `TrainingEpochLoop.is_last_batch` ([#9657](https://github.com/PyTorchLightning/pytorch-lightning/pull/9657))
   * Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320))
   * Reset `current` progress counters when restarting an epoch loop that had already finished ([#9371](https://github.com/PyTorchLightning/pytorch-lightning/pull/9371))
   * Call `reset_on_restart` in the loop's `reset` hook instead of when loading a checkpoint ([#9561](https://github.com/PyTorchLightning/pytorch-lightning/pull/9561))
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 3d7f36477c55e..f829c20e557b1 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -20,7 +20,7 @@
 from pytorch_lightning.loops.optimization.closure import OutputResult
 from pytorch_lightning.loops.utilities import _prepare_dataloader_iter
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
-from pytorch_lightning.trainer.progress import Progress, SchedulerProgress
+from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -43,9 +43,7 @@ def __init__(self, min_steps: int, max_steps: int):
         self.max_steps: int = max_steps
         self.global_step: int = 0
 
-        # manually tracking which is the last batch is necessary for iterable dataset support
-        self.is_last_batch: Optional[bool] = None
-        self.batch_progress = Progress()
+        self.batch_progress = BatchProgress()
         self.scheduler_progress = SchedulerProgress()
 
         self.batch_loop: Optional[TrainingBatchLoop] = None
@@ -94,18 +92,16 @@ def reset(self) -> None:
         assert self.batch_loop is not None
         assert self.batch_loop.optimizer_loop is not None
         if self.restarting:
-            self.batch_progress.current.reset_on_restart()
+            self.batch_progress.reset_on_restart()
             self.scheduler_progress.current.reset_on_restart()
             self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()
 
-        self.is_last_batch = False
-
         # track epoch output
         self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))]
 
         if not self.restarting or self._num_training_batches_reached():
-            self.batch_progress.current.reset()
-            self.scheduler_progress.current.reset()
+            self.batch_progress.reset_on_epoch()
+            self.scheduler_progress.reset_on_epoch()
             self.batch_loop.optimizer_loop.optim_progress.reset_on_epoch()
 
     def on_run_start(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
@@ -127,6 +123,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
             StopIteration: When the epoch is canceled by the user returning -1
         """
         batch_idx, (batch, is_last) = next(self.dataloader_iter)
+        self.batch_progress.is_last_batch = is_last
 
         if not self.trainer.data_connector.train_data_fetcher.store_on_device:
             with self.trainer.profiler.profile("training_batch_to_device"):
@@ -139,8 +136,6 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
 
         self.batch_progress.increment_processed()
 
-        self.is_last_batch = is_last
-
         # when returning -1 from train_step, we end epoch early
         if batch_output.signal == -1:
             raise StopIteration
@@ -178,7 +173,7 @@ def on_advance_end(self):
         # -----------------------------------------
         # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
        # -----------------------------------------
-        should_check_val = self._should_check_val_fx(self.batch_idx, self.is_last_batch)
+        should_check_val = self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch)
         if should_check_val:
             self.trainer.validating = True
             self._run_validation()
@@ -259,7 +254,9 @@ def _accumulated_batches_reached(self) -> bool:
 
     def _num_training_batches_reached(self) -> bool:
         """Checks if we are in the last batch or if there are more batches to follow."""
-        return self.batch_progress.current.ready == self.trainer.num_training_batches or self.is_last_batch
+        return (
+            self.batch_progress.current.ready == self.trainer.num_training_batches or self.batch_progress.is_last_batch
+        )
 
     def _should_accumulate(self) -> bool:
         """Checks if the optimizer step should be performed or gradients should be accumulated for the current
diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py
index 6c2d95c6b8d56..eb6872b6a7c83 100644
--- a/pytorch_lightning/trainer/progress.py
+++ b/pytorch_lightning/trainer/progress.py
@@ -148,6 +148,9 @@ def from_defaults(cls, tracker_cls: Type[ReadyCompletedTracker], **kwargs: int)
         """Utility function to easily create an instance from keyword arguments to both ``Tracker``s."""
         return cls(total=tracker_cls(**kwargs), current=tracker_cls(**kwargs))
 
+    def reset_on_epoch(self) -> None:
+        self.current.reset()
+
     def reset_on_restart(self) -> None:
         self.current.reset_on_restart()
 
@@ -158,8 +161,9 @@ def load_state_dict(self, state_dict: dict) -> None:
 
 @dataclass
 class DataLoaderProgress(Progress):
-    """Tracks the dataloader progress These counters are local to a trainer rank. By default, they are not globally
-    synced across all ranks.
+    """Tracks dataloader progress.
+
+    These counters are local to a trainer rank. By default, they are not globally synced across all ranks.
 
     Args:
         total: Tracks the total dataloader progress.
@@ -170,10 +174,30 @@ class DataLoaderProgress(Progress):
     current: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)
 
 
+@dataclass
+class BatchProgress(Progress):
+    """Tracks batch progress.
+
+    These counters are local to a trainer rank. By default, they are not globally synced across all ranks.
+
+    Args:
+        total: Tracks the total batch progress.
+        current: Tracks the current batch progress.
+        is_last_batch: Whether the batch is the last one. This is useful for iterable datasets.
+ """ + + is_last_batch: bool = False + + def reset_on_epoch(self) -> None: + super().reset_on_epoch() + self.is_last_batch = False + + @dataclass class SchedulerProgress(Progress): - """Tracks the scheduler progress. These counters are local to a trainer rank. By default, they are not globally - synced across all ranks. + """Tracks scheduler progress. + + These counters are local to a trainer rank. By default, they are not globally synced across all ranks. Args: total: Tracks the total scheduler progress. @@ -197,8 +221,8 @@ class OptimizerProgress(BaseProgress): zero_grad: Progress = field(default_factory=lambda: Progress.from_defaults(StartedTracker)) def reset_on_epoch(self) -> None: - self.step.current.reset() - self.zero_grad.current.reset() + self.step.reset_on_epoch() + self.zero_grad.reset_on_epoch() def reset_on_restart(self) -> None: self.step.reset_on_restart() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 012fea147e47c..581ff11554cb3 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1892,7 +1892,7 @@ def min_steps(self) -> Optional[int]: @property def is_last_batch(self) -> bool: - return self.fit_loop.epoch_loop.is_last_batch + return self.fit_loop.epoch_loop.batch_progress.is_last_batch @property def fit_loop(self) -> FitLoop: diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index ad5d0159036d5..0459e0033e46a 100644 --- a/tests/loops/test_loop_state_dict.py +++ b/tests/loops/test_loop_state_dict.py @@ -51,6 +51,7 @@ def test_loops_state_dict_structure(): "epoch_loop.batch_progress": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "is_last_batch": False, }, "epoch_loop.scheduler_progress": { "total": {"ready": 0, "completed": 0}, diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index a9a24d1638522..f84062ad0e70a 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -20,7 +20,9 @@ import pytest import torch +from torch.utils.data import DataLoader +from pl_examples.bug_report_model import RandomDataset from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loops import Loop, TrainingBatchLoop @@ -443,6 +445,7 @@ def configure_optimizers_multiple(self): "processed": stop_batch, "completed": stop_batch, }, + "is_last_batch": False, }, "epoch_loop.scheduler_progress": { "total": {"ready": nbe_sch_steps + be_sch_steps, "completed": nbe_sch_steps + be_sch_steps}, @@ -548,13 +551,16 @@ def configure_optimizers_multiple(self): return optimizers, lr_schedulers + def train_dataloader(self): + # override to test the `is_last_batch` value + return DataLoader(RandomDataset(32, n_batches)) + model = TestModel() model.training_epoch_end = None trainer = Trainer( default_root_dir=tmpdir, max_epochs=n_epochs, - limit_train_batches=n_batches, limit_val_batches=0, accumulate_grad_batches=accumulate_grad_batches, progress_bar_refresh_rate=0, @@ -563,6 +569,8 @@ def configure_optimizers_multiple(self): ) trainer.fit(model) + assert trainer.num_training_batches == n_batches + ckpt_path = trainer.checkpoint_callback.best_model_path assert os.path.exists(ckpt_path) checkpoint = torch.load(ckpt_path) @@ -607,6 +615,7 @@ def configure_optimizers_multiple(self): "processed": n_batches, "completed": n_batches, }, + "is_last_batch": True, }, "epoch_loop.scheduler_progress": { 
"total": {"ready": n_sch_steps_total, "completed": n_sch_steps_total},