
Integrate progress tracking into the progress bar #11213

Merged: 15 commits, Jan 6, 2022
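This PR replaces the hand-maintained batch counters in `ProgressBarBase` with values derived from the trainer's loop progress tracking. The `train_batch_idx`, `val_batch_idx`, `test_batch_idx`, and `predict_batch_idx` properties now read `batch_progress.current.processed` from the corresponding loop, the callback hooks that existed only to increment counters are deleted, and `RichProgressBar` and `TQDMProgressBar` drop the matching `super()` calls and delta-based bar updates. In the excerpts below, removed lines are prefixed with `-` and added lines with `+`.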
CHANGELOG.md: 3 additions, 0 deletions
@@ -112,6 +112,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The `DDPSpawnPlugin` no longer overrides the `post_dispatch` plugin hook ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034))


+- Integrate the progress bar implementation with progress tracking ([#11213](https://github.com/PyTorchLightning/pytorch-lightning/pull/11213))


- The `LightningModule.{add_to_queue,get_from_queue}` hooks no longer get a `torch.multiprocessing.SimpleQueue` and instead receive a list based queue ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034))


pytorch_lightning/callbacks/progress/base.py: 24 additions, 46 deletions
@@ -47,52 +47,57 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch_idx):
"""

def __init__(self):

-        self._trainer = None
-        self._train_batch_idx = 0
-        self._val_batch_idx = 0
-        self._test_batch_idx = 0
-        self._predict_batch_idx = 0
+        self._trainer: Optional["pl.Trainer"] = None

@property
-    def trainer(self):
+    def trainer(self) -> "pl.Trainer":
return self._trainer

@property
def train_batch_idx(self) -> int:
"""The current batch index being processed during training.
"""The number of batches processed during training.

Use this to update your progress bar.
"""
-        return self._train_batch_idx
+        if self.trainer is None:
+            return 0
+        return self.trainer.fit_loop.epoch_loop.batch_progress.current.processed

@property
def val_batch_idx(self) -> int:
"""The current batch index being processed during validation.
"""The number of batches processed during validation.

Use this to update your progress bar.
"""
-        return self._val_batch_idx
+        if self.trainer is None:
+            return 0
+        if self.trainer.state.fn == "fit":
+            return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.current.processed
+        return self.trainer.validate_loop.epoch_loop.batch_progress.current.processed

@property
def test_batch_idx(self) -> int:
"""The current batch index being processed during testing.
"""The number of batches processed during testing.

Use this to update your progress bar.
"""
-        return self._test_batch_idx
+        if self.trainer is None:
+            return 0
+        return self.trainer.test_loop.epoch_loop.batch_progress.current.processed

@property
def predict_batch_idx(self) -> int:
"""The current batch index being processed during predicting.
"""The number of batches processed during prediction.

Use this to update your progress bar.
"""
-        return self._predict_batch_idx
+        if self.trainer is None:
+            return 0
+        return self.trainer.predict_loop.epoch_loop.batch_progress.current.processed

@property
def total_train_batches(self) -> int:
"""The total number of training batches during training, which may change from epoch to epoch.
carmocca marked this conversation as resolved.
Show resolved Hide resolved
"""The total number of training batches, which may change from epoch to epoch.

Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the training
dataloader is of infinite size.
@@ -101,7 +106,7 @@ def total_train_batches(self) -> int:

@property
def total_val_batches(self) -> int:
"""The total number of validation batches during validation, which may change from epoch to epoch.
"""The total number of validation batches, which may change from epoch to epoch.

Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the validation
dataloader is of infinite size.
@@ -115,7 +120,7 @@ def total_val_batches(self) -> int:

@property
def total_test_batches(self) -> int:
"""The total number of testing batches during testing, which may change from epoch to epoch.
"""The total number of testing batches, which may change from epoch to epoch.

Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the test dataloader is
of infinite size.
@@ -124,7 +129,7 @@ def total_test_batches(self) -> int:

@property
def total_predict_batches(self) -> int:
"""The total number of predicting batches during testing, which may change from epoch to epoch.
"""The total number of prediction batches, which may change from epoch to epoch.

Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the predict dataloader
is of infinite size.
@@ -155,33 +160,6 @@ def print(self, *args, **kwargs):
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
self._trainer = trainer

-    def on_train_start(self, trainer, pl_module):
-        self._train_batch_idx = 0
-
-    def on_train_epoch_start(self, trainer, pl_module):
-        self._train_batch_idx = trainer.fit_loop.epoch_loop.batch_progress.current.completed
-
-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
-        self._train_batch_idx += 1
-
-    def on_validation_start(self, trainer, pl_module):
-        self._val_batch_idx = 0
-
-    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        self._val_batch_idx += 1
-
-    def on_test_start(self, trainer, pl_module):
-        self._test_batch_idx = 0
-
-    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        self._test_batch_idx += 1
-
-    def on_predict_epoch_start(self, trainer, pl_module):
-        self._predict_batch_idx = 0
-
-    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        self._predict_batch_idx += 1

def get_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str]]:
r"""
Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics.
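For illustration, here is a minimal sketch (not part of this PR) of a custom progress bar built on the refactored base class. The class name, output format, and `_enabled` flag are invented for the example; only the `ProgressBarBase` API shown above is assumed:

```python
from pytorch_lightning.callbacks.progress.base import ProgressBarBase


class PercentProgressBar(ProgressBarBase):
    """Hypothetical callback that prints training progress as a percentage."""

    def __init__(self):
        super().__init__()
        self._enabled = True

    def enable(self) -> None:
        self._enabled = True

    def disable(self) -> None:
        self._enabled = False

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # `train_batch_idx` now resolves to
        # trainer.fit_loop.epoch_loop.batch_progress.current.processed,
        # so there is no counter to increment here.
        total = self.total_train_batches
        if self._enabled and total != float("inf"):
            percent = 100.0 * self.train_batch_idx / total
            print(f"Epoch {trainer.current_epoch}: {percent:.1f}% ({self.train_batch_idx}/{total})")
```

Attach it with `Trainer(callbacks=[PercentProgressBar()])`; the indices stay correct because they are read from the loops' progress state on every access.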
pytorch_lightning/callbacks/progress/rich_progress.py: 0 additions, 14 deletions
@@ -303,35 +303,28 @@ def refresh(self) -> None:
self.progress.refresh()

def on_train_start(self, trainer, pl_module):
-        super().on_train_start(trainer, pl_module)
self._init_progress(trainer)

def on_predict_start(self, trainer, pl_module):
-        super().on_predict_start(trainer, pl_module)
self._init_progress(trainer)

def on_test_start(self, trainer, pl_module):
-        super().on_test_start(trainer, pl_module)
self._init_progress(trainer)

def on_validation_start(self, trainer, pl_module):
-        super().on_validation_start(trainer, pl_module)
self._init_progress(trainer)

def on_sanity_check_start(self, trainer, pl_module):
-        super().on_sanity_check_start(trainer, pl_module)
self._init_progress(trainer)
self.val_sanity_progress_bar_id = self._add_task(trainer.num_sanity_val_steps, self.sanity_check_description)
self.refresh()

def on_sanity_check_end(self, trainer, pl_module):
-        super().on_sanity_check_end(trainer, pl_module)
if self.progress is not None:
self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False)
self.refresh()

def on_train_epoch_start(self, trainer, pl_module):
-        super().on_train_epoch_start(trainer, pl_module)
total_train_batches = self.total_train_batches
total_val_batches = self.total_val_batches
if total_train_batches != float("inf"):
@@ -354,7 +347,6 @@ def on_train_epoch_start(self, trainer, pl_module):
self.refresh()

def on_validation_epoch_start(self, trainer, pl_module):
-        super().on_validation_epoch_start(trainer, pl_module)
if self.total_val_batches > 0:
total_val_batches = self.total_val_batches
if self.total_train_batches != float("inf") and hasattr(trainer, "val_check_batch"):
@@ -379,7 +371,6 @@ def _should_update(self, current: int, total: int) -> bool:
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

def on_validation_epoch_end(self, trainer, pl_module):
-        super().on_validation_epoch_end(trainer, pl_module)
if self.val_progress_bar_id is not None:
self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches, visible=False)

@@ -388,18 +379,15 @@ def on_test_epoch_start(self, trainer, pl_module):
self.refresh()

def on_predict_epoch_start(self, trainer, pl_module):
-        super().on_predict_epoch_start(trainer, pl_module)
self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description)
self.refresh()

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
-        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches)
self._update_metrics(trainer, pl_module)
self.refresh()

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if trainer.sanity_checking:
self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches)
elif self.val_progress_bar_id is not None:
Expand All @@ -410,12 +398,10 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
self.refresh()

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches)
self.refresh()

def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
self._update(self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches)
self.refresh()

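All of the `super()` calls removed above existed only to drive the batch counters that `ProgressBarBase` no longer keeps; after this change each hook retains just its rendering logic, and the batch indices it reads (e.g. `self.val_batch_idx`) come from the trainer's progress tracking.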
pytorch_lightning/callbacks/progress/tqdm_progress.py: 25 additions, 42 deletions
@@ -135,6 +135,15 @@ def is_enabled(self) -> bool:
def is_disabled(self) -> bool:
return not self.is_enabled

+    @property
+    def _val_processed(self):
+        if self.trainer is None:
+            return 0
+        if self.trainer.state.fn == "fit":
+            # use total in case validation runs more than once per training epoch
+            return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed
+        return self.trainer.validate_loop.epoch_loop.batch_progress.current.processed

def disable(self) -> None:
self._enabled = False

@@ -216,87 +225,74 @@ def on_sanity_check_end(self, trainer, pl_module):
self.val_progress_bar.close()

def on_train_start(self, trainer, pl_module):
-        super().on_train_start(trainer, pl_module)
self.main_progress_bar = self.init_train_tqdm()

def on_train_epoch_start(self, trainer, pl_module):
-        super().on_train_epoch_start(trainer, pl_module)
total_train_batches = self.total_train_batches
total_val_batches = self.total_val_batches
if total_train_batches != float("inf") and total_val_batches != float("inf"):
# val can be checked multiple times per epoch
val_checks_per_epoch = total_train_batches // trainer.val_check_batch
total_val_batches = total_val_batches * val_checks_per_epoch
total_batches = total_train_batches + total_val_batches
-        reset(self.main_progress_bar, total=total_batches, current=self.train_batch_idx)
+        self.main_progress_bar.total = convert_inf(total_batches)
self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
-        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
if self._should_update(self.train_batch_idx):
-            self._update_bar(self.main_progress_bar)
+            _update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed)
self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self.is_enabled:
-            self._update_bar(self.main_progress_bar)
+            _update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed)
if not self.main_progress_bar.disable:
self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.main_progress_bar.close()

def on_validation_start(self, trainer, pl_module):
-        super().on_validation_start(trainer, pl_module)
if trainer.sanity_checking:
-            reset(self.val_progress_bar, total=sum(trainer.num_sanity_val_batches), current=self.val_batch_idx)
+            self.val_progress_bar.total = sum(trainer.num_sanity_val_batches)
else:
-            if trainer.state.fn == pl.trainer.states.TrainerFn.FITTING:
-                self._update_bar(self.main_progress_bar)  # fill up remaining
            self.val_progress_bar = self.init_validation_tqdm()
-            reset(self.val_progress_bar, total=self.total_val_batches, current=self.val_batch_idx)
+            self.val_progress_bar.total = convert_inf(self.total_val_batches)

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self._should_update(self.val_batch_idx):
-            self._update_bar(self.val_progress_bar)
+            _update_n(self.val_progress_bar, self.val_batch_idx)
if trainer.state.fn == pl.trainer.states.TrainerFn.FITTING:
-                self._update_bar(self.main_progress_bar)
+                _update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed)

def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self.is_enabled:
-            self._update_bar(self.val_progress_bar)
+            _update_n(self.val_progress_bar, self._val_processed)

def on_validation_end(self, trainer, pl_module):
if self.main_progress_bar is not None and trainer.state.fn == pl.trainer.states.TrainerFn.FITTING:
self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
self.val_progress_bar.close()

def on_test_start(self, trainer, pl_module):
-        super().on_test_start(trainer, pl_module)
self.test_progress_bar = self.init_test_tqdm()
self.test_progress_bar.total = convert_inf(self.total_test_batches)

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self._should_update(self.test_batch_idx):
-            self._update_bar(self.test_progress_bar)
+            _update_n(self.test_progress_bar, self.test_batch_idx)

def on_test_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self.is_enabled:
-            self._update_bar(self.test_progress_bar)
+            _update_n(self.test_progress_bar, self.test_batch_idx)

def on_test_end(self, trainer, pl_module):
self.test_progress_bar.close()

def on_predict_epoch_start(self, trainer, pl_module):
-        super().on_predict_epoch_start(trainer, pl_module)
self.predict_progress_bar = self.init_predict_tqdm()
self.predict_progress_bar.total = convert_inf(self.total_predict_batches)

def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self._should_update(self.predict_batch_idx):
-            self._update_bar(self.predict_progress_bar)
+            _update_n(self.predict_progress_bar, self.predict_batch_idx)

def on_predict_end(self, trainer, pl_module):
self.predict_progress_bar.close()
@@ -320,19 +316,7 @@ def print(
active_progress_bar.write(s, end=end, file=file, nolock=nolock)

def _should_update(self, idx: int) -> bool:
-        return self.is_enabled and (idx % self.refresh_rate == 0)
-
-    def _update_bar(self, bar: Optional[Tqdm]) -> None:
-        """Updates the bar by the refresh rate without overshooting."""
-        if bar is None:
-            return
-        if bar.total is not None:
-            delta = min(self.refresh_rate, bar.total - bar.n)
-        else:
-            # infinite / unknown size
-            delta = self.refresh_rate
-        if delta > 0:
-            bar.update(delta)
+        return self.refresh_rate and idx % self.refresh_rate == 0

@staticmethod
def _resolve_refresh_rate(refresh_rate: int) -> int:
@@ -353,8 +337,7 @@ def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
return x


-def reset(bar: Tqdm, total: Optional[int] = None, current: int = 0) -> None:
-    """Resets the tqdm bar to the desired position and sets a new total, unless it is disabled."""
+def _update_n(bar: _tqdm, value: int) -> None:
    if not bar.disable:
-        bar.reset(total=convert_inf(total))
-        bar.n = current
+        bar.n = value
+        bar.refresh()
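The replacement helper sets the bar's absolute position rather than advancing it by a clamped delta, which is what made `_update_bar` and `reset` removable. Below is a standalone sketch of the pattern against plain `tqdm`; the loop and variable names are invented stand-ins for the trainer's tracked `processed` count:

```python
from tqdm import tqdm


def _update_n(bar: tqdm, value: int) -> None:
    # Assigning `n` directly cannot overshoot `total`, so no delta clamping
    # (as previously done in `_update_bar`) is needed.
    if not bar.disable:
        bar.n = value
        bar.refresh()


refresh_rate = 20
bar = tqdm(total=100)
processed = 0  # stands in for batch_progress.current.processed
for _ in range(100):
    processed += 1
    # mirrors the new `_should_update`: refresh every `refresh_rate` batches
    if refresh_rate and processed % refresh_rate == 0:
        _update_n(bar, processed)
_update_n(bar, processed)  # final update, as in the epoch-end hooks
bar.close()
```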