Add is_last_batch to progress tracking (#9657)
carmocca authored Sep 23, 2021
1 parent fd4f2f6 commit 8dcba38
Showing 6 changed files with 53 additions and 21 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Progress tracking
  * Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598))
+  * Add `BatchProgress` and integrate `TrainingEpochLoop.is_last_batch` ([#9657](https://github.com/PyTorchLightning/pytorch-lightning/pull/9657))
  * Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320))
  * Reset `current` progress counters when restarting an epoch loop that had already finished ([#9371](https://github.com/PyTorchLightning/pytorch-lightning/pull/9371))
  * Call `reset_on_restart` in the loop's `reset` hook instead of when loading a checkpoint ([#9561](https://github.com/PyTorchLightning/pytorch-lightning/pull/9561))
23 changes: 10 additions & 13 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -20,7 +20,7 @@
from pytorch_lightning.loops.optimization.closure import OutputResult
from pytorch_lightning.loops.utilities import _prepare_dataloader_iter
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
-from pytorch_lightning.trainer.progress import Progress, SchedulerProgress
+from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -43,9 +43,7 @@ def __init__(self, min_steps: int, max_steps: int):
        self.max_steps: int = max_steps

        self.global_step: int = 0
-        # manually tracking which is the last batch is necessary for iterable dataset support
-        self.is_last_batch: Optional[bool] = None
-        self.batch_progress = Progress()
+        self.batch_progress = BatchProgress()
        self.scheduler_progress = SchedulerProgress()

        self.batch_loop: Optional[TrainingBatchLoop] = None
@@ -94,18 +92,16 @@ def reset(self) -> None:
        assert self.batch_loop is not None
        assert self.batch_loop.optimizer_loop is not None
        if self.restarting:
-            self.batch_progress.current.reset_on_restart()
+            self.batch_progress.reset_on_restart()
            self.scheduler_progress.current.reset_on_restart()
            self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()

-        self.is_last_batch = False
-
        # track epoch output
        self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))]

        if not self.restarting or self._num_training_batches_reached():
-            self.batch_progress.current.reset()
-            self.scheduler_progress.current.reset()
+            self.batch_progress.reset_on_epoch()
+            self.scheduler_progress.reset_on_epoch()
            self.batch_loop.optimizer_loop.optim_progress.reset_on_epoch()

    def on_run_start(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
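For context, `reset` distinguishes a mid-epoch restart (roll the per-epoch counters back to the last completed batch) from a fresh epoch (clear them entirely). A minimal sketch of that decision logic, using toy stand-ins rather than Lightning's actual tracker classes:

```python
# Toy stand-ins for the tracker/progress objects -- illustrative only.
class ToyTracker:
    def __init__(self) -> None:
        self.ready = self.processed = self.completed = 0

    def reset(self) -> None:
        self.ready = self.processed = self.completed = 0

    def reset_on_restart(self) -> None:
        # Roll back to the last fully completed batch so a resumed run
        # re-fetches whatever never finished.
        self.ready = self.processed = self.completed


def reset_for_epoch(tracker: ToyTracker, restarting: bool, epoch_finished: bool) -> None:
    if restarting:
        tracker.reset_on_restart()
    if not restarting or epoch_finished:
        # Fresh epoch (or the restarted epoch had already finished):
        # clear the per-epoch counters entirely.
        tracker.reset()


t = ToyTracker()
t.ready = t.processed = 3
t.completed = 2
reset_for_epoch(t, restarting=True, epoch_finished=False)
assert (t.ready, t.processed, t.completed) == (2, 2, 2)
```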
@@ -127,6 +123,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
            StopIteration: When the epoch is canceled by the user returning -1
        """
        batch_idx, (batch, is_last) = next(self.dataloader_iter)
+        self.batch_progress.is_last_batch = is_last

        if not self.trainer.data_connector.train_data_fetcher.store_on_device:
            with self.trainer.profiler.profile("training_batch_to_device"):
@@ -139,8 +136,6 @@ def advance(self, *args: Any, **kwargs: Any) -> None:

        self.batch_progress.increment_processed()

-        self.is_last_batch = is_last
-
        # when returning -1 from train_step, we end epoch early
        if batch_output.signal == -1:
            raise StopIteration
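The `is_last` flag consumed here is produced by looking one batch ahead, which is what makes last-batch detection possible for `IterableDataset`s of unknown length. A self-contained sketch of the idea (the real logic lives in Lightning's data-fetching utilities, not in this helper):

```python
from typing import Any, Iterable, Iterator, Tuple


def with_is_last(iterable: Iterable) -> Iterator[Tuple[Any, bool]]:
    """Yield (item, is_last) pairs by prefetching one element ahead."""
    it = iter(iterable)
    try:
        current = next(it)
    except StopIteration:
        return
    for nxt in it:
        yield current, False  # something follows, so not the last item
        current = nxt
    yield current, True  # the iterator is exhausted: this was the last one


assert list(with_is_last(range(3))) == [(0, False), (1, False), (2, True)]
```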
@@ -178,7 +173,7 @@ def on_advance_end(self):
        # -----------------------------------------
        # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
        # -----------------------------------------
-        should_check_val = self._should_check_val_fx(self.batch_idx, self.is_last_batch)
+        should_check_val = self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch)
        if should_check_val:
            self.trainer.validating = True
            self._run_validation()
@@ -259,7 +254,9 @@ def _accumulated_batches_reached(self) -> bool:

    def _num_training_batches_reached(self) -> bool:
        """Checks if we are in the last batch or if there are more batches to follow."""
-        return self.batch_progress.current.ready == self.trainer.num_training_batches or self.is_last_batch
+        return (
+            self.batch_progress.current.ready == self.trainer.num_training_batches or self.batch_progress.is_last_batch
+        )

    def _should_accumulate(self) -> bool:
        """Checks if the optimizer step should be performed or gradients should be accumulated for the current
36 changes: 30 additions & 6 deletions pytorch_lightning/trainer/progress.py
@@ -148,6 +148,9 @@ def from_defaults(cls, tracker_cls: Type[ReadyCompletedTracker], **kwargs: int)
"""Utility function to easily create an instance from keyword arguments to both ``Tracker``s."""
return cls(total=tracker_cls(**kwargs), current=tracker_cls(**kwargs))

def reset_on_epoch(self) -> None:
self.current.reset()

def reset_on_restart(self) -> None:
self.current.reset_on_restart()

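The new `Progress.reset_on_epoch` resets only the per-epoch `current` tracker, leaving the cumulative `total` untouched. A simplified sketch of that split (toy dataclasses, not the real `Tracker` hierarchy):

```python
from dataclasses import dataclass, field


@dataclass
class ToyTracker:
    ready: int = 0
    completed: int = 0

    def reset(self) -> None:
        self.ready = self.completed = 0


@dataclass
class ToyProgress:
    total: ToyTracker = field(default_factory=ToyTracker)
    current: ToyTracker = field(default_factory=ToyTracker)

    def reset_on_epoch(self) -> None:
        # Per-epoch counters restart; lifetime totals keep accumulating.
        self.current.reset()


p = ToyProgress()
p.total.completed = p.current.completed = 5
p.reset_on_epoch()
assert p.total.completed == 5 and p.current.completed == 0
```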
@@ -158,8 +161,9 @@ def load_state_dict(self, state_dict: dict) -> None:

@dataclass
class DataLoaderProgress(Progress):
-    """Tracks the dataloader progress These counters are local to a trainer rank. By default, they are not globally
-    synced across all ranks.
+    """Tracks dataloader progress.
+
+    These counters are local to a trainer rank. By default, they are not globally synced across all ranks.

    Args:
        total: Tracks the total dataloader progress.
@@ -170,10 +174,30 @@ class DataLoaderProgress(Progress):
    current: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)


+@dataclass
+class BatchProgress(Progress):
+    """Tracks batch progress.
+
+    These counters are local to a trainer rank. By default, they are not globally synced across all ranks.
+
+    Args:
+        total: Tracks the total batch progress.
+        current: Tracks the current batch progress.
+        is_last_batch: Whether the batch is the last one. This is useful for iterable datasets.
+    """
+
+    is_last_batch: bool = False
+
+    def reset_on_epoch(self) -> None:
+        super().reset_on_epoch()
+        self.is_last_batch = False
+
+
@dataclass
class SchedulerProgress(Progress):
-    """Tracks the scheduler progress. These counters are local to a trainer rank. By default, they are not globally
-    synced across all ranks.
+    """Tracks scheduler progress.
+
+    These counters are local to a trainer rank. By default, they are not globally synced across all ranks.

    Args:
        total: Tracks the total scheduler progress.
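`BatchProgress` is then just `Progress` plus the flag, with `reset_on_epoch` widened to clear it. Assuming a checkout at this commit, its intended behavior can be exercised directly:

```python
from pytorch_lightning.trainer.progress import BatchProgress

bp = BatchProgress()
bp.increment_ready()     # counter API inherited from Progress
bp.is_last_batch = True  # set by the training loop on the final batch

bp.reset_on_epoch()      # clears the `current` counters *and* the flag
assert bp.is_last_batch is False
assert bp.current.ready == 0
assert bp.total.ready == 1  # lifetime totals survive the epoch boundary
```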
@@ -197,8 +221,8 @@ class OptimizerProgress(BaseProgress):
    zero_grad: Progress = field(default_factory=lambda: Progress.from_defaults(StartedTracker))

    def reset_on_epoch(self) -> None:
-        self.step.current.reset()
-        self.zero_grad.current.reset()
+        self.step.reset_on_epoch()
+        self.zero_grad.reset_on_epoch()

    def reset_on_restart(self) -> None:
        self.step.reset_on_restart()
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -1892,7 +1892,7 @@ def min_steps(self) -> Optional[int]:

    @property
    def is_last_batch(self) -> bool:
-        return self.fit_loop.epoch_loop.is_last_batch
+        return self.fit_loop.epoch_loop.batch_progress.is_last_batch

    @property
    def fit_loop(self) -> FitLoop:
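Since the `Trainer.is_last_batch` property keeps its signature and merely delegates to `batch_progress`, user-facing code is unaffected. For example, a callback could still read it as before (a sketch; the hook signature below matches the 1.4-era API, not necessarily later releases):

```python
from pytorch_lightning import Callback


class LastBatchLogger(Callback):
    """Toy callback: reads the unchanged public `trainer.is_last_batch`."""

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if trainer.is_last_batch:
            print(f"epoch {trainer.current_epoch}: batch {batch_idx} was the last one")
```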
1 change: 1 addition & 0 deletions tests/loops/test_loop_state_dict.py
@@ -51,6 +51,7 @@ def test_loops_state_dict_structure():
"epoch_loop.batch_progress": {
"total": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
"current": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
"is_last_batch": False,
},
"epoch_loop.scheduler_progress": {
"total": {"ready": 0, "completed": 0},
11 changes: 10 additions & 1 deletion tests/loops/test_loops.py
@@ -20,7 +20,9 @@

import pytest
import torch
+from torch.utils.data import DataLoader

+from pl_examples.bug_report_model import RandomDataset
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loops import Loop, TrainingBatchLoop
@@ -443,6 +445,7 @@ def configure_optimizers_multiple(self):
"processed": stop_batch,
"completed": stop_batch,
},
"is_last_batch": False,
},
"epoch_loop.scheduler_progress": {
"total": {"ready": nbe_sch_steps + be_sch_steps, "completed": nbe_sch_steps + be_sch_steps},
@@ -548,13 +551,16 @@ def configure_optimizers_multiple(self):

            return optimizers, lr_schedulers

+        def train_dataloader(self):
+            # override to test the `is_last_batch` value
+            return DataLoader(RandomDataset(32, n_batches))
+
    model = TestModel()
    model.training_epoch_end = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=n_epochs,
-        limit_train_batches=n_batches,
        limit_val_batches=0,
        accumulate_grad_batches=accumulate_grad_batches,
        progress_bar_refresh_rate=0,
@@ -563,6 +569,8 @@ def configure_optimizers_multiple(self):
    )
    trainer.fit(model)

+    assert trainer.num_training_batches == n_batches
+
    ckpt_path = trainer.checkpoint_callback.best_model_path
    assert os.path.exists(ckpt_path)
    checkpoint = torch.load(ckpt_path)
@@ -607,6 +615,7 @@ def configure_optimizers_multiple(self):
"processed": n_batches,
"completed": n_batches,
},
"is_last_batch": True,
},
"epoch_loop.scheduler_progress": {
"total": {"ready": n_sch_steps_total, "completed": n_sch_steps_total},
