Add rich for Model Summary
kaushikb11 committed Aug 17, 2021
1 parent 513700d commit 5a3215f
Showing 3 changed files with 76 additions and 12 deletions.
4 changes: 3 additions & 1 deletion pytorch_lightning/callbacks/progress/rich_progress.py
@@ -80,7 +80,9 @@ def render(self, task) -> Text:
 class RichProgressBar(ProgressBarBase):
     def __init__(self, refresh_rate: int = 1):
         if not _RICH_AVAILABLE:
-            raise MisconfigurationException("Rich progress bar is not available")
+            raise MisconfigurationException(
+                "`RichProgressBar` requires `rich` to be installed. Install it by running `pip install rich`."
+            )
         super().__init__()
         self._refresh_rate = refresh_rate
         self._enabled = True
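Note: the `_RICH_AVAILABLE` guard used above is a module-level flag evaluated at import time. A minimal sketch of the pattern, assuming only the standard library (Lightning computes the real flag with its own helpers in `pytorch_lightning.utilities.imports`):

import importlib.util

# True exactly when `rich` can be imported; the callback checks this flag
# in __init__ and raises MisconfigurationException when it is False.
_RICH_AVAILABLE = importlib.util.find_spec("rich") is not None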
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -25,7 +25,7 @@

 import pytorch_lightning as pl
 from pytorch_lightning.accelerators import Accelerator, IPUAccelerator
-from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.callbacks import Callback, RichProgressBar
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.loops import IteratorBatchProcessor, TrainingBatchLoop, TrainingEpochLoop
@@ -1043,8 +1043,9 @@ def _pre_training_routine(self):

         # print model summary
         if self.is_global_zero and self.weights_summary is not None and not self.testing:
+            use_rich = isinstance(self.progress_bar_callback, RichProgressBar)
             max_depth = ModelSummary.MODES[self.weights_summary]
-            summarize(ref_model, max_depth=max_depth)
+            summarize(ref_model, max_depth=max_depth, use_rich=use_rich)
 
         # on pretrain routine end
         self.on_pretrain_routine_end()
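Note: no new Trainer flag is introduced here; the summary switches to rich rendering whenever a `RichProgressBar` is among the callbacks. A usage sketch (`MyModel` is a placeholder for any `LightningModule`):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import RichProgressBar

# Passing the rich progress bar now also opts the model summary into rich output.
trainer = pl.Trainer(callbacks=[RichProgressBar()])
# trainer.fit(MyModel())  # summary renders as a rich table in the pre-training routine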
79 changes: 70 additions & 9 deletions pytorch_lightning/utilities/model_summary.py
@@ -25,9 +25,13 @@
 import pytorch_lightning as pl
 from pytorch_lightning.utilities import AMPType, DeviceType, rank_zero_deprecation
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
+from pytorch_lightning.utilities.imports import _RICH_AVAILABLE, _TORCH_GREATER_EQUAL_1_8
 from pytorch_lightning.utilities.warnings import WarningCache
 
+if _RICH_AVAILABLE:
+    from rich.console import Console
+    from rich.table import Table
+
 log = logging.getLogger(__name__)
 warning_cache = WarningCache()
 
@@ -299,12 +303,7 @@ def _forward_example_input(self) -> None:
             model(input_)
         model.train(mode)  # restore mode of module
 
-    def __str__(self):
-        """
-        Makes a summary listing with:
-            Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size
-        """
+    def _get_summary_data(self):
         arrays = [
             [" ", list(map(str, range(len(self._layer_summary))))],
             ["Name", self.layer_names],
@@ -314,6 +313,62 @@ def __str__(self):
         if self._model.example_input_array is not None:
             arrays.append(["In sizes", self.in_sizes])
             arrays.append(["Out sizes", self.out_sizes])
+
+        return arrays
+
+    def print_rich_summary(self):
+
+        if not _RICH_AVAILABLE:
+            raise MisconfigurationException(
+                "`print_rich_summary` requires `rich` to be installed." " Install it by running `pip install rich`."
+            )
+
+        arrays = self._get_summary_data()
+        total_parameters = self.total_parameters
+        trainable_parameters = self.trainable_parameters
+        model_size = self.model_size
+
+        console = Console()
+
+        table = Table(title="Model Summary")
+
+        table.add_column(" ")
+        table.add_column("Name", arrays[1][1], justify="left", style="cyan", no_wrap=True)
+        table.add_column("Type", arrays[2][1], style="magenta")
+        table.add_column("Params", arrays[3][1], justify="right", style="green")
+
+        rows = list(zip(*(arr[1] for arr in arrays)))
+        for row in rows:
+            table.add_row(*row)
+
+        console.print(table)
+
+        # Formatting
+        s = "{:<{}}"
+
+        parameters = []
+        for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters, model_size]:
+            parameters.append(s.format(get_human_readable_count(param), 10))
+
+        grid = Table.grid(expand=True)
+        grid.add_column()
+        grid.add_column()
+
+        grid.add_row(f"[bold]Trainable params[/]: {parameters[0]}")
+        grid.add_row(f"[bold]Non-trainable params[/]: {parameters[1]}")
+        grid.add_row(f"[bold]Total params[/]: {parameters[2]}")
+        grid.add_row(f"[bold]Total estimated model params size (MB)[/]: {parameters[3]}")
+
+        console.print(grid)
+
+    def __str__(self):
+        """
+        Makes a summary listing with:
+            Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size
+        """
+        arrays = self._get_summary_data()
+
         total_parameters = self.total_parameters
         trainable_parameters = self.trainable_parameters
         model_size = self.model_size
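The one non-obvious step in `print_rich_summary` above is the transpose: `_get_summary_data` returns column-major `[header, values]` pairs, and `zip(*(arr[1] for arr in arrays))` flips them into per-layer rows. A small illustration with made-up values:

# Column-major summary data, shaped like the output of _get_summary_data
# (layer names and parameter counts are invented for illustration):
arrays = [
    [" ", ["0", "1"]],
    ["Name", ["encoder", "decoder"]],
    ["Type", ["Linear", "Linear"]],
    ["Params", ["330", "231"]],
]
rows = list(zip(*(arr[1] for arr in arrays)))
# rows == [("0", "encoder", "Linear", "330"), ("1", "decoder", "Linear", "231")]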
@@ -435,7 +490,10 @@ def _is_lazy_weight_tensor(p: Tensor) -> bool:


 def summarize(
-    lightning_module: "pl.LightningModule", mode: Optional[str] = "top", max_depth: Optional[int] = None
+    lightning_module: "pl.LightningModule",
+    mode: Optional[str] = "top",
+    max_depth: Optional[int] = None,
+    use_rich: bool = False,
 ) -> Optional[ModelSummary]:
     """
     Summarize the LightningModule specified by `lightning_module`.
@@ -467,5 +525,8 @@
         raise MisconfigurationException(f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}")
     else:
         model_summary = ModelSummary(lightning_module, max_depth=max_depth)
-    log.info("\n" + str(model_summary))
+    if use_rich:
+        model_summary.print_rich_summary()
+    else:
+        log.info("\n" + str(model_summary))
     return model_summary
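The new flag can also be exercised directly, outside the Trainer. A self-contained sketch (`TinyModel` is a hypothetical module used only for demonstration):

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.model_summary import summarize

class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

summarize(TinyModel(), max_depth=1, use_rich=True)  # rich table printed to the console
summarize(TinyModel(), max_depth=1)                 # default path: plain summary via log.info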
