From 242fb06dd696ac8c8ee370a750987c2f74b63780 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Fri, 25 Jul 2025 21:59:03 +0530 Subject: [PATCH 1/6] fix: rich progress bar error when resume training --- src/lightning/pytorch/callbacks/progress/rich_progress.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 7bb98e8a9058c..6063cdc360e9d 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -357,6 +357,13 @@ def refresh(self) -> None: def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._init_progress(trainer) + # Initialize the training progress bar here because + # `on_train_epoch_start` is not called when resuming from a mid-epoch restart + total_batches = self.total_train_batches + train_description = self._get_train_description(trainer.current_epoch) + assert self.progress is not None + self.train_progress_bar_id = self._add_task(total_batches, train_description) + @override def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._init_progress(trainer) From 1dcfbbf82a4f27da446031ea2ec70b536e74b212 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Fri, 25 Jul 2025 23:40:20 +0530 Subject: [PATCH 2/6] update --- .../callbacks/progress/rich_progress.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 6063cdc360e9d..bde8e7a3d98fe 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -357,13 +357,6 @@ def refresh(self) -> None: def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._init_progress(trainer) - # Initialize the training progress bar here because - # `on_train_epoch_start` is not called when resuming from a mid-epoch restart - total_batches = self.total_train_batches - train_description = self._get_train_description(trainer.current_epoch) - assert self.progress is not None - self.train_progress_bar_id = self._add_task(total_batches, train_description) - @override def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._init_progress(trainer) @@ -454,6 +447,14 @@ def _add_task(self, total_batches: Union[int, float], description: str, visible: visible=visible, ) + def _initialize_progress_bar_id(self) -> None: + # Initialize the training progress bar here because + # `on_train_epoch_start` is not called when resuming from a mid-epoch restart + total_batches = self.total_train_batches + train_description = self._get_train_description(self.trainer.current_epoch) + assert self.progress is not None + self.train_progress_bar_id = self._add_task(total_batches, train_description) + def _update(self, progress_bar_id: Optional["TaskID"], current: int, visible: bool = True) -> None: if self.progress is not None and self.is_enabled: assert progress_bar_id is not None @@ -538,6 +539,9 @@ def on_train_batch_end( batch: Any, batch_idx: int, ) -> None: + if self.train_progress_bar_id is None: + # can happen when resuming from a mid-epoch restart + self._initialize_progress_bar_id() self._update(self.train_progress_bar_id, batch_idx + 1) self._update_metrics(trainer, pl_module) self.refresh() From 340e1d8969129574e64a2c6f0295c6fb827aa039 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Fri, 25 Jul 2025 23:42:04 +0530 Subject: [PATCH 3/6] update --- src/lightning/pytorch/callbacks/progress/rich_progress.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index bde8e7a3d98fe..c584852acd4b3 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -448,8 +448,6 @@ def _add_task(self, total_batches: Union[int, float], description: str, visible: ) def _initialize_progress_bar_id(self) -> None: - # Initialize the training progress bar here because - # `on_train_epoch_start` is not called when resuming from a mid-epoch restart total_batches = self.total_train_batches train_description = self._get_train_description(self.trainer.current_epoch) assert self.progress is not None From edc88bd0512df48bc1107c27e39d6c3cb8b65c4b Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Sat, 26 Jul 2025 00:03:52 +0530 Subject: [PATCH 4/6] update --- src/lightning/pytorch/callbacks/progress/rich_progress.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index c584852acd4b3..e7268cd0efa93 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -448,9 +448,10 @@ def _add_task(self, total_batches: Union[int, float], description: str, visible: ) def _initialize_progress_bar_id(self) -> None: + if self.is_disabled: + return total_batches = self.total_train_batches train_description = self._get_train_description(self.trainer.current_epoch) - assert self.progress is not None self.train_progress_bar_id = self._add_task(total_batches, train_description) def _update(self, progress_bar_id: Optional["TaskID"], current: int, visible: bool = True) -> None: From 20fa9efaf58b6b979c8046b33575c9820c0cd577 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Sat, 26 Jul 2025 00:38:08 +0530 Subject: [PATCH 5/6] update --- src/lightning/pytorch/callbacks/progress/rich_progress.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index e7268cd0efa93..70f6166d5cc94 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -448,8 +448,6 @@ def _add_task(self, total_batches: Union[int, float], description: str, visible: ) def _initialize_progress_bar_id(self) -> None: - if self.is_disabled: - return total_batches = self.total_train_batches train_description = self._get_train_description(self.trainer.current_epoch) self.train_progress_bar_id = self._add_task(total_batches, train_description) @@ -538,7 +536,7 @@ def on_train_batch_end( batch: Any, batch_idx: int, ) -> None: - if self.train_progress_bar_id is None: + if self.train_progress_bar_id is None and not self.is_disabled: # can happen when resuming from a mid-epoch restart self._initialize_progress_bar_id() self._update(self.train_progress_bar_id, batch_idx + 1) From 67f12d3321990b0badcbfe492485fedacdbdcc1a Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Wed, 30 Jul 2025 12:08:14 +0530 Subject: [PATCH 6/6] change as reviewed --- src/lightning/pytorch/callbacks/progress/rich_progress.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 70f6166d5cc94..6aec230316d43 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -447,7 +447,7 @@ def _add_task(self, total_batches: Union[int, float], description: str, visible: visible=visible, ) - def _initialize_progress_bar_id(self) -> None: + def _initialize_train_progress_bar_id(self) -> None: total_batches = self.total_train_batches train_description = self._get_train_description(self.trainer.current_epoch) self.train_progress_bar_id = self._add_task(total_batches, train_description) @@ -536,9 +536,9 @@ def on_train_batch_end( batch: Any, batch_idx: int, ) -> None: - if self.train_progress_bar_id is None and not self.is_disabled: + if not self.is_disabled and self.train_progress_bar_id is None: # can happen when resuming from a mid-epoch restart - self._initialize_progress_bar_id() + self._initialize_train_progress_bar_id() self._update(self.train_progress_bar_id, batch_idx + 1) self._update_metrics(trainer, pl_module) self.refresh()