From ef152e34f04c7a273d2fd5f824ffee76240a3b13 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Sat, 11 Sep 2021 21:55:46 +0530 Subject: [PATCH 1/8] feat: Add RichModelSummary callback --- pytorch_lightning/callbacks/__init__.py | 2 ++ .../callbacks/rich_model_summary.py | 27 +++++++++++++++++++ .../trainer/connectors/callback_connector.py | 22 ++++++++++++--- 3 files changed, 47 insertions(+), 4 deletions(-) create mode 100644 pytorch_lightning/callbacks/rich_model_summary.py diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py index f02518c14bbe6..514aee1b130db 100644 --- a/pytorch_lightning/callbacks/__init__.py +++ b/pytorch_lightning/callbacks/__init__.py @@ -24,6 +24,7 @@ from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar from pytorch_lightning.callbacks.pruning import ModelPruning from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining +from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging from pytorch_lightning.callbacks.timer import Timer from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor @@ -47,5 +48,6 @@ "QuantizationAwareTraining", "StochasticWeightAveraging", "Timer", + "RichModelSummary", "RichProgressBar", ] diff --git a/pytorch_lightning/callbacks/rich_model_summary.py b/pytorch_lightning/callbacks/rich_model_summary.py new file mode 100644 index 0000000000000..312116d6c7cd4 --- /dev/null +++ b/pytorch_lightning/callbacks/rich_model_summary.py @@ -0,0 +1,27 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Union + +from pytorch_lightning.callbacks import ModelSummary + + +class RichModelSummary(ModelSummary): + @staticmethod + def summarize( + summary_data: List[List[Union[str, List[str]]]], + total_parameters: int, + trainable_parameters: int, + model_size: float, + ) -> None: + pass diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 2a2b5d15de0fd..ea532250533a8 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -15,7 +15,15 @@ from datetime import timedelta from typing import Dict, List, Optional, Union -from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ModelSummary, ProgressBar, ProgressBarBase +from pytorch_lightning.callbacks import ( + Callback, + ModelCheckpoint, + ModelSummary, + ProgressBar, + ProgressBarBase, + RichModelSummary, + RichProgressBar, +) from pytorch_lightning.callbacks.timer import Timer from pytorch_lightning.utilities import ModelSummaryMode, rank_zero_info from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -59,8 +67,6 @@ def on_trainer_init( # responsible to stop the training when max_time is reached. self._configure_timer_callback(max_time) - self._configure_model_summary_callback(weights_summary) - # init progress bar if process_position != 0: rank_zero_deprecation( @@ -70,6 +76,9 @@ def on_trainer_init( ) self.trainer._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position) + # configure the ModelSummary callback + self._configure_model_summary_callback(weights_summary) + # push all checkpoint callbacks to the end # it is important that these are the last callbacks to run self.trainer.callbacks = self._reorder_callbacks(self.trainer.callbacks) @@ -102,7 +111,12 @@ def _configure_model_summary_callback(self, weights_summary: Optional[str] = Non f" but got {weights_summary}", ) max_depth = ModelSummaryMode.get_max_depth(weights_summary) - model_summary = ModelSummary(max_depth=max_depth) + if self.trainer._progress_bar_callback is not None and isinstance( + self.trainer._progress_bar_callback, RichProgressBar + ): + model_summary = RichModelSummary(max_depth=max_depth) + else: + model_summary = ModelSummary(max_depth=max_depth) self.trainer.callbacks.append(model_summary) def _configure_swa_callbacks(self): From 0883d670da4edb79a108e360761b6a432eaacb6b Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Thu, 16 Sep 2021 00:29:12 +0530 Subject: [PATCH 2/8] Update callback --- .../callbacks/rich_model_summary.py | 45 ++++++++++++++++++- .../trainer/connectors/callback_connector.py | 2 +- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/rich_model_summary.py b/pytorch_lightning/callbacks/rich_model_summary.py index 312116d6c7cd4..ac925456b007a 100644 --- a/pytorch_lightning/callbacks/rich_model_summary.py +++ b/pytorch_lightning/callbacks/rich_model_summary.py @@ -14,6 +14,12 @@ from typing import List, Union from pytorch_lightning.callbacks import ModelSummary +from pytorch_lightning.utilities.imports import _RICH_AVAILABLE +from pytorch_lightning.utilities.model_summary import get_human_readable_count + +if _RICH_AVAILABLE: + from rich.console import Console + from rich.table import Table class RichModelSummary(ModelSummary): @@ -24,4 +30,41 @@ def summarize( trainable_parameters: int, model_size: float, ) -> None: - pass + + console = Console() + + table = Table(title="Model Summary") + + table.add_column(" ") + table.add_column("Name", justify="left", style="cyan", no_wrap=True) + table.add_column("Type", style="magenta") + table.add_column("Params", justify="right", style="green") + + # print(summary_data) + # if self._model.example_input_array is not None: + # table.add_column("In sizes", justify="right", style="green") + # table.add_column("Out sizes", justify="right", style="green") + + rows = list(zip(*(arr[1] for arr in summary_data))) + for row in rows: + table.add_row(*row) + + console.print(table) + + # Formatting + s = "{:<{}}" + + parameters = [] + for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters, model_size]: + parameters.append(s.format(get_human_readable_count(param), 10)) + + grid = Table.grid(expand=True) + grid.add_column() + grid.add_column() + + grid.add_row(f"[bold]Trainable params[/]: {parameters[0]}") + grid.add_row(f"[bold]Non-trainable params[/]: {parameters[1]}") + grid.add_row(f"[bold]Total params[/]: {parameters[2]}") + grid.add_row(f"[bold]Total estimated model params size (MB)[/]: {parameters[3]}") + + console.print(grid) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index ea532250533a8..cafad831cbb30 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -21,9 +21,9 @@ ModelSummary, ProgressBar, ProgressBarBase, - RichModelSummary, RichProgressBar, ) +from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary from pytorch_lightning.callbacks.timer import Timer from pytorch_lightning.utilities import ModelSummaryMode, rank_zero_info from pytorch_lightning.utilities.exceptions import MisconfigurationException From ae1f9c852c2a78bd4a863dfeeac6f209944670e4 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Thu, 16 Sep 2021 00:51:42 +0530 Subject: [PATCH 3/8] Add tests & update changelog --- CHANGELOG.md | 3 ++ pytorch_lightning/callbacks/__init__.py | 4 +-- .../callbacks/rich_model_summary.py | 17 ++++++--- tests/callbacks/test_rich_model_summary.py | 35 +++++++++++++++++++ 4 files changed, 53 insertions(+), 6 deletions(-) create mode 100644 tests/callbacks/test_rich_model_summary.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b6884971bf4e..378bd834c924c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -133,6 +133,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `PL_RECONCILE_PROCESS` environment variable to enable process reconciliation regardless of cluster environment settings ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389)) +- Added `RichModelSummary` callback ([#9546](https://github.com/PyTorchLightning/pytorch-lightning/pull/9546)) + + ### Changed - `pytorch_lightning.loggers.neptune.NeptuneLogger` is now consistent with new [neptune-client](https://github.com/neptune-ai/neptune-client) API ([#6867](https://github.com/PyTorchLightning/pytorch-lightning/pull/6867)). diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py index 514aee1b130db..98cf5df7cafda 100644 --- a/pytorch_lightning/callbacks/__init__.py +++ b/pytorch_lightning/callbacks/__init__.py @@ -46,8 +46,8 @@ "ProgressBar", "ProgressBarBase", "QuantizationAwareTraining", - "StochasticWeightAveraging", - "Timer", "RichModelSummary", "RichProgressBar", + "StochasticWeightAveraging", + "Timer", ] diff --git a/pytorch_lightning/callbacks/rich_model_summary.py b/pytorch_lightning/callbacks/rich_model_summary.py index ac925456b007a..11dfbc47a0ce9 100644 --- a/pytorch_lightning/callbacks/rich_model_summary.py +++ b/pytorch_lightning/callbacks/rich_model_summary.py @@ -14,6 +14,7 @@ from typing import List, Union from pytorch_lightning.callbacks import ModelSummary +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _RICH_AVAILABLE from pytorch_lightning.utilities.model_summary import get_human_readable_count @@ -23,6 +24,13 @@ class RichModelSummary(ModelSummary): + def __init__(self, max_depth: int = 1) -> None: + if not _RICH_AVAILABLE: + raise MisconfigurationException( + "`RichModelSummary` requires `rich` to be installed. Install it by running `pip install rich`." + ) + super().__init__(max_depth) + @staticmethod def summarize( summary_data: List[List[Union[str, List[str]]]], @@ -40,10 +48,11 @@ def summarize( table.add_column("Type", style="magenta") table.add_column("Params", justify="right", style="green") - # print(summary_data) - # if self._model.example_input_array is not None: - # table.add_column("In sizes", justify="right", style="green") - # table.add_column("Out sizes", justify="right", style="green") + column_names = list(zip(*summary_data))[0] + + for column_name in ["In sizes", "Out sizes"]: + if column_name in column_names: + table.add_column(column_name, justify="right", style="green") rows = list(zip(*(arr[1] for arr in summary_data))) for row in rows: diff --git a/tests/callbacks/test_rich_model_summary.py b/tests/callbacks/test_rich_model_summary.py new file mode 100644 index 0000000000000..99d557251fdfb --- /dev/null +++ b/tests/callbacks/test_rich_model_summary.py @@ -0,0 +1,35 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import RichModelSummary, RichProgressBar +from pytorch_lightning.utilities.imports import _RICH_AVAILABLE +from tests.helpers.runif import RunIf + + +@RunIf(rich=True) +def test_rich_model_summary_callback(): + + trainer = Trainer(callbacks=RichProgressBar()) + + assert any(isinstance(cb, RichModelSummary) for cb in trainer.callbacks) + assert isinstance(trainer.progress_bar_callback, RichProgressBar) + + +def test_rich_progress_bar_import_error(): + + if not _RICH_AVAILABLE: + with pytest.raises(ImportError, match="`RichModelSummary` requires `rich` to be installed."): + Trainer(callbacks=RichModelSummary()) From cc4bd35bda70ae0a605f4bc8f861b3c551f57cc2 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Thu, 16 Sep 2021 00:55:36 +0530 Subject: [PATCH 4/8] Update callbacks.rst --- docs/source/extensions/callbacks.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/extensions/callbacks.rst b/docs/source/extensions/callbacks.rst index d5cfdfc8111c0..ad61c10a7bd3b 100644 --- a/docs/source/extensions/callbacks.rst +++ b/docs/source/extensions/callbacks.rst @@ -106,8 +106,10 @@ Lightning has a few built-in callbacks. LearningRateMonitor ModelCheckpoint ModelPruning + ModelSummary ProgressBar ProgressBarBase + RichModelSummary RichProgressBar QuantizationAwareTraining StochasticWeightAveraging From 2f8e67c725a053fb911d534be02034b045f5e7b5 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Thu, 16 Sep 2021 21:46:51 +0530 Subject: [PATCH 5/8] Address reviews --- pyproject.toml | 1 + .../callbacks/rich_model_summary.py | 43 ++++++++++++++++--- 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d205af7f0a1f3..9981d6827e33d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,7 @@ ignore_errors = "True" module = [ "pytorch_lightning.callbacks.model_summary", "pytorch_lightning.callbacks.pruning", + "pytorch_lightning.callbacks.rich_model_summary", "pytorch_lightning.loops.optimization.*", "pytorch_lightning.loops.evaluation_loop", "pytorch_lightning.trainer.connectors.checkpoint_connector", diff --git a/pytorch_lightning/callbacks/rich_model_summary.py b/pytorch_lightning/callbacks/rich_model_summary.py index 11dfbc47a0ce9..5eeff8817ee2e 100644 --- a/pytorch_lightning/callbacks/rich_model_summary.py +++ b/pytorch_lightning/callbacks/rich_model_summary.py @@ -14,7 +14,6 @@ from typing import List, Union from pytorch_lightning.callbacks import ModelSummary -from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _RICH_AVAILABLE from pytorch_lightning.utilities.model_summary import get_human_readable_count @@ -24,9 +23,44 @@ class RichModelSummary(ModelSummary): + r""" + Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule` + with `rich text formatting `_. + + Install it with pip: + + .. code-block:: bash + + pip install rich + + .. code-block:: python + + from pytorch_lightning import Trainer + from pytorch_lightning.callbacks import RichModelSummary + + trainer = Trainer(callbacks=RichModelSummary()) + + You could also enable it using the :class:`~pytorch_lightning.callbacks.RichProgressBar` + + .. code-block:: python + + from pytorch_lightning import Trainer + from pytorch_lightning.callbacks import RichProgressBar + + trainer = Trainer(callbacks=RichProgressBar()) + + Args: + max_depth: The maximum depth of layer nesting that the summary will include. A value of 0 turns the + layer summary off. + + Raises: + ImportError: + If required `rich` package is not installed on the device. + """ + def __init__(self, max_depth: int = 1) -> None: if not _RICH_AVAILABLE: - raise MisconfigurationException( + raise ImportError( "`RichModelSummary` requires `rich` to be installed. Install it by running `pip install rich`." ) super().__init__(max_depth) @@ -60,12 +94,9 @@ def summarize( console.print(table) - # Formatting - s = "{:<{}}" - parameters = [] for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters, model_size]: - parameters.append(s.format(get_human_readable_count(param), 10)) + parameters.append("{:<{}}".format(get_human_readable_count(param), 10)) grid = Table.grid(expand=True) grid.add_column() From d801654851d28385bb50d5095fca759be848edac Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 17 Sep 2021 04:42:57 +0530 Subject: [PATCH 6/8] Fix mypy --- pytorch_lightning/callbacks/rich_model_summary.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/rich_model_summary.py b/pytorch_lightning/callbacks/rich_model_summary.py index 5eeff8817ee2e..a2981829e04bc 100644 --- a/pytorch_lightning/callbacks/rich_model_summary.py +++ b/pytorch_lightning/callbacks/rich_model_summary.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import List, Tuple from pytorch_lightning.callbacks import ModelSummary from pytorch_lightning.utilities.imports import _RICH_AVAILABLE @@ -67,7 +67,7 @@ def __init__(self, max_depth: int = 1) -> None: @staticmethod def summarize( - summary_data: List[List[Union[str, List[str]]]], + summary_data: List[Tuple[str, List[str]]], total_parameters: int, trainable_parameters: int, model_size: float, @@ -96,7 +96,7 @@ def summarize( parameters = [] for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters, model_size]: - parameters.append("{:<{}}".format(get_human_readable_count(param), 10)) + parameters.append("{:<{}}".format(get_human_readable_count(int(param)), 10)) grid = Table.grid(expand=True) grid.add_column() From 6dd81bcc424c00204a608b4b1f88cbb4a8d5de1b Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 17 Sep 2021 10:46:57 +0100 Subject: [PATCH 7/8] Better styling --- pytorch_lightning/callbacks/rich_model_summary.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/callbacks/rich_model_summary.py b/pytorch_lightning/callbacks/rich_model_summary.py index a2981829e04bc..0751bca92cf8f 100644 --- a/pytorch_lightning/callbacks/rich_model_summary.py +++ b/pytorch_lightning/callbacks/rich_model_summary.py @@ -75,18 +75,17 @@ def summarize( console = Console() - table = Table(title="Model Summary") - - table.add_column(" ") - table.add_column("Name", justify="left", style="cyan", no_wrap=True) - table.add_column("Type", style="magenta") - table.add_column("Params", justify="right", style="green") + table = Table(header_style="bold magenta") + table.add_column(" ", style="dim") + table.add_column("Name", justify="left", no_wrap=True) + table.add_column("Type") + table.add_column("Params", justify="right") column_names = list(zip(*summary_data))[0] for column_name in ["In sizes", "Out sizes"]: if column_name in column_names: - table.add_column(column_name, justify="right", style="green") + table.add_column(column_name, justify="right", style="white") rows = list(zip(*(arr[1] for arr in summary_data))) for row in rows: From 17a3d769a461f377d3bf43b56f73e13cf0e699eb Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 17 Sep 2021 11:23:25 +0100 Subject: [PATCH 8/8] Updated wording --- pytorch_lightning/callbacks/rich_model_summary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/rich_model_summary.py b/pytorch_lightning/callbacks/rich_model_summary.py index 0751bca92cf8f..2e55665c4433e 100644 --- a/pytorch_lightning/callbacks/rich_model_summary.py +++ b/pytorch_lightning/callbacks/rich_model_summary.py @@ -40,7 +40,7 @@ class RichModelSummary(ModelSummary): trainer = Trainer(callbacks=RichModelSummary()) - You could also enable it using the :class:`~pytorch_lightning.callbacks.RichProgressBar` + You could also enable ``RichModelSummary`` using the :class:`~pytorch_lightning.callbacks.RichProgressBar` .. code-block:: python