diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b2d8491eba9a..aaeb48bbbb2fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()` + + - Deprecated `LightningModule.model_size` ([#8343](https://github.com/PyTorchLightning/pytorch-lightning/pull/8343)) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 7606bce0fc0c6..e4d563338ed76 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -33,7 +33,6 @@ from pytorch_lightning.core.grads import GradInformation from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks -from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ModelIO @@ -44,6 +43,7 @@ from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import get_model_size_mb +from pytorch_lightning.utilities.model_summary import ModelSummary, summarize from pytorch_lightning.utilities.parsing import collect_init_args from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, STEP_OUTPUT @@ -1707,6 +1707,10 @@ def summarize(self, mode: Optional[str] = "top", max_depth: Optional[int] = None """ Summarize this LightningModule. + .. deprecated:: v1.5 + This method was deprecated in v1.5 in favor of `pytorch_lightning.utilities.model_summary.summarize` + and will be removed in v1.7. + Args: mode: Can be either ``'top'`` (summarize only direct submodules) or ``'full'`` (summarize all layers). @@ -1719,24 +1723,13 @@ def summarize(self, mode: Optional[str] = "top", max_depth: Optional[int] = None Return: The model summary object """ - model_summary = None - - # temporary mapping from mode to max_depth - if max_depth is None: - if mode in ModelSummary.MODES: - max_depth = ModelSummary.MODES[mode] - rank_zero_deprecation( - f"Argument `mode` in `LightningModule.summarize` is deprecated in v1.4" - f" and will be removed in v1.6. Use `max_depth={max_depth}` to replicate `mode={mode}` behavior." - ) - model_summary = ModelSummary(self, max_depth=max_depth) - elif mode is not None: - raise MisconfigurationException(f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}") - else: - model_summary = ModelSummary(self, max_depth=max_depth) + warning_cache.deprecation( + "The `LightningModule.summarize` method is deprecated in v1.5 and will be removed in v1.7. " + "Use `pytorch_lightning.utilities.model_summary.summarize` instead.", + stacklevel=6, + ) - log.info("\n" + str(model_summary)) - return model_summary + return summarize(self, mode, max_depth) def freeze(self) -> None: r""" diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index fcc2f54ed61d0..d61164e6d88cb 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -11,478 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +from pytorch_lightning.utilities import rank_zero_deprecation -import os -import shutil -import subprocess -from collections import OrderedDict -from typing import Any, Dict, List, Optional, Tuple, Union +rank_zero_deprecation( + "`pytorch_lightning.core.memory.get_memory_profile` and" + " `pytorch_lightning.core.memory.get_gpu_memory_map` have been moved" + " to `pytorch_lightning.utilities.memory` since v1.5 and will be removed in v1.7." +) -import numpy as np -import torch -import torch.nn as nn -from torch import Tensor -from torch.utils.hooks import RemovableHandle +# To support backward compatibility as get_memory_profile and get_gpu_memory_map have been moved +from pytorch_lightning.utilities.memory import get_gpu_memory_map, get_memory_profile # noqa: E402, F401 # isort: skip -from pytorch_lightning.utilities import AMPType, DeviceType -from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8 -from pytorch_lightning.utilities.warnings import WarningCache +rank_zero_deprecation( + "`pytorch_lightning.core.memory.LayerSummary` and" + " `pytorch_lightning.core.memory.ModelSummary` have been moved" + " to `pytorch_lightning.utilities.model_summary` since v1.5 and will be removed in v1.7." +) -warning_cache = WarningCache() - -PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"] -UNKNOWN_SIZE = "?" - - -class LayerSummary: - """ - Summary class for a single layer in a :class:`~pytorch_lightning.core.lightning.LightningModule`. - It collects the following information: - - - Type of the layer (e.g. Linear, BatchNorm1d, ...) - - Input shape - - Output shape - - Number of parameters - - The input and output shapes are only known after the example input array was - passed through the model. - - Example:: - - >>> model = torch.nn.Conv2d(3, 8, 3) - >>> summary = LayerSummary(model) - >>> summary.num_parameters - 224 - >>> summary.layer_type - 'Conv2d' - >>> output = model(torch.rand(1, 3, 5, 5)) - >>> summary.in_size - [1, 3, 5, 5] - >>> summary.out_size - [1, 8, 3, 3] - - Args: - module: A module to summarize - - """ - - def __init__(self, module: nn.Module): - super().__init__() - self._module = module - self._hook_handle = self._register_hook() - self._in_size = None - self._out_size = None - - def __del__(self): - self.detach_hook() - - def _register_hook(self) -> Optional[RemovableHandle]: - """ - Registers a hook on the module that computes the input- and output size(s) on the first forward pass. - If the hook is called, it will remove itself from the from the module, meaning that - recursive models will only record their input- and output shapes once. - Registering hooks on :class:`~torch.jit.ScriptModule` is not supported. - - Return: - A handle for the installed hook, or ``None`` if registering the hook is not possible. - """ - - def hook(module, inp, out): - if len(inp) == 1: - inp = inp[0] - self._in_size = parse_batch_shape(inp) - self._out_size = parse_batch_shape(out) - self._hook_handle.remove() - - handle = None - if not isinstance(self._module, torch.jit.ScriptModule): - handle = self._module.register_forward_hook(hook) - return handle - - def detach_hook(self): - """ - Removes the forward hook if it was not already removed in the forward pass. - Will be called after the summary is created. 
- """ - if self._hook_handle is not None: - self._hook_handle.remove() - - @property - def in_size(self) -> Union[str, List]: - return self._in_size or UNKNOWN_SIZE - - @property - def out_size(self) -> Union[str, List]: - return self._out_size or UNKNOWN_SIZE - - @property - def layer_type(self) -> str: - """Returns the class name of the module.""" - return str(self._module.__class__.__name__) - - @property - def num_parameters(self) -> int: - """Returns the number of parameters in this module.""" - return sum(np.prod(p.shape) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters()) - - -class ModelSummary: - """ - Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`. - - Args: - model: The model to summarize (also referred to as the root module). - mode: Can be one of - - - `top` (default): only the top-level modules will be recorded (the children of the root module) - - `full`: summarizes all layers and their submodules in the root module - - .. deprecated:: v1.4 - This parameter was deprecated in v1.4 in favor of `max_depth` and will be removed in v1.6. - - max_depth: Maximum depth of modules to show. Use -1 to show all modules or 0 to show no - summary. Defaults to 1. - - The string representation of this summary prints a table with columns containing - the name, type and number of parameters for each layer. - - The root module may also have an attribute ``example_input_array`` as shown in the example below. - If present, the root module will be called with it as input to determine the - intermediate input- and output shapes of all layers. Supported are tensors and - nested lists and tuples of tensors. All other types of inputs will be skipped and show as `?` - in the summary table. The summary will also display `?` for layers not used in the forward pass. - - Example:: - - >>> import pytorch_lightning as pl - >>> class LitModel(pl.LightningModule): - ... - ... def __init__(self): - ... super().__init__() - ... self.net = nn.Sequential(nn.Linear(256, 512), nn.BatchNorm1d(512)) - ... self.example_input_array = torch.zeros(10, 256) # optional - ... - ... def forward(self, x): - ... return self.net(x) - ... 
- >>> model = LitModel() - >>> ModelSummary(model, max_depth=1) # doctest: +NORMALIZE_WHITESPACE - | Name | Type | Params | In sizes | Out sizes - ------------------------------------------------------------ - 0 | net | Sequential | 132 K | [10, 256] | [10, 512] - ------------------------------------------------------------ - 132 K Trainable params - 0 Non-trainable params - 132 K Total params - 0.530 Total estimated model params size (MB) - >>> ModelSummary(model, max_depth=-1) # doctest: +NORMALIZE_WHITESPACE - | Name | Type | Params | In sizes | Out sizes - -------------------------------------------------------------- - 0 | net | Sequential | 132 K | [10, 256] | [10, 512] - 1 | net.0 | Linear | 131 K | [10, 256] | [10, 512] - 2 | net.1 | BatchNorm1d | 1.0 K | [10, 512] | [10, 512] - -------------------------------------------------------------- - 132 K Trainable params - 0 Non-trainable params - 132 K Total params - 0.530 Total estimated model params size (MB) - """ - - MODES = dict(top=1, full=-1) # TODO: remove in v1.6 - - def __init__(self, model, mode: Optional[str] = None, max_depth: Optional[int] = 1): - self._model = model - - # temporary mapping from mode to max_depth - if max_depth is None or mode is not None: - if mode in ModelSummary.MODES: - max_depth = ModelSummary.MODES[mode] - from pytorch_lightning.utilities import rank_zero_deprecation - - rank_zero_deprecation( - f"Argument `mode` in `ModelSummary` is deprecated in v1.4" - f" and will be removed in v1.6. Use `max_depth={max_depth}` to replicate `mode={mode}` behaviour." - ) - else: - from pytorch_lightning.utilities.exceptions import MisconfigurationException - - raise MisconfigurationException(f"`mode` can be {', '.join(ModelSummary.MODES)}, got {mode}.") - - if not isinstance(max_depth, int) or max_depth < -1: - raise ValueError(f"`max_depth` can be -1, 0 or > 0, got {max_depth}.") - - self._max_depth = max_depth - self._layer_summary = self.summarize() - # 1 byte -> 8 bits - # TODO: how do we compute precisin_megabytes in case of mixed precision? 
- precision = self._model.precision if isinstance(self._model.precision, int) else 32 - self._precision_megabytes = (precision / 8.0) * 1e-6 - - @property - def named_modules(self) -> List[Tuple[str, nn.Module]]: - if self._max_depth == 0: - mods = [] - elif self._max_depth == 1: - # the children are the top-level modules - mods = self._model.named_children() - else: - mods = self._model.named_modules() - mods = list(mods)[1:] # do not include root module (LightningModule) - return list(mods) - - @property - def layer_names(self) -> List[str]: - return list(self._layer_summary.keys()) - - @property - def layer_types(self) -> List[str]: - return [layer.layer_type for layer in self._layer_summary.values()] - - @property - def in_sizes(self) -> List: - return [layer.in_size for layer in self._layer_summary.values()] - - @property - def out_sizes(self) -> List: - return [layer.out_size for layer in self._layer_summary.values()] - - @property - def param_nums(self) -> List[int]: - return [layer.num_parameters for layer in self._layer_summary.values()] - - @property - def total_parameters(self) -> int: - return sum(p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters()) - - @property - def trainable_parameters(self) -> int: - return sum( - p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters() if p.requires_grad - ) - - @property - def model_size(self) -> float: - # todo: seems it does not work with quantized models - it returns 0.0 - return self.total_parameters * self._precision_megabytes - - def summarize(self) -> Dict[str, LayerSummary]: - summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules) - if self._model.example_input_array is not None: - self._forward_example_input() - for layer in summary.values(): - layer.detach_hook() - - if self._max_depth >= 1: - # remove summary entries with depth > max_depth - for k in [k for k in summary if k.count(".") >= self._max_depth]: - del summary[k] - - return summary - - def _forward_example_input(self) -> None: - """Run the example input through each layer to get input- and output sizes.""" - model = self._model - trainer = self._model.trainer - - input_ = model.example_input_array - input_ = model._apply_batch_transfer_handler(input_) - - if trainer is not None and trainer.amp_backend == AMPType.NATIVE and trainer._device_type != DeviceType.TPU: - model.forward = torch.cuda.amp.autocast()(model.forward) - - mode = model.training - model.eval() - with torch.no_grad(): - # let the model hooks collect the input- and output shapes - if isinstance(input_, (list, tuple)): - model(*input_) - elif isinstance(input_, dict): - model(**input_) - else: - model(input_) - model.train(mode) # restore mode of module - - def __str__(self): - """ - Makes a summary listing with: - - Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size - """ - arrays = [ - [" ", list(map(str, range(len(self._layer_summary))))], - ["Name", self.layer_names], - ["Type", self.layer_types], - ["Params", list(map(get_human_readable_count, self.param_nums))], - ] - if self._model.example_input_array is not None: - arrays.append(["In sizes", self.in_sizes]) - arrays.append(["Out sizes", self.out_sizes]) - total_parameters = self.total_parameters - trainable_parameters = self.trainable_parameters - model_size = self.model_size - - return _format_summary_table(total_parameters, trainable_parameters, model_size, *arrays) - - def __repr__(self): - return str(self) - - -def 
parse_batch_shape(batch: Any) -> Union[str, List]: - if hasattr(batch, "shape"): - return list(batch.shape) - - if isinstance(batch, (list, tuple)): - shape = [parse_batch_shape(el) for el in batch] - return shape - - return UNKNOWN_SIZE - - -def _format_summary_table(total_parameters: int, trainable_parameters: int, model_size: float, *cols) -> str: - """ - Takes in a number of arrays, each specifying a column in - the summary table, and combines them all into one big - string defining the summary table that are nicely formatted. - """ - n_rows = len(cols[0][1]) - n_cols = 1 + len(cols) - - # Get formatting width of each column - col_widths = [] - for c in cols: - col_width = max(len(str(a)) for a in c[1]) if n_rows else 0 - col_width = max(col_width, len(c[0])) # minimum length is header length - col_widths.append(col_width) - - # Formatting - s = "{:<{}}" - total_width = sum(col_widths) + 3 * n_cols - header = [s.format(c[0], l) for c, l in zip(cols, col_widths)] - - # Summary = header + divider + Rest of table - summary = " | ".join(header) + "\n" + "-" * total_width - for i in range(n_rows): - line = [] - for c, l in zip(cols, col_widths): - line.append(s.format(str(c[1][i]), l)) - summary += "\n" + " | ".join(line) - summary += "\n" + "-" * total_width - - summary += "\n" + s.format(get_human_readable_count(trainable_parameters), 10) - summary += "Trainable params" - summary += "\n" + s.format(get_human_readable_count(total_parameters - trainable_parameters), 10) - summary += "Non-trainable params" - summary += "\n" + s.format(get_human_readable_count(total_parameters), 10) - summary += "Total params" - summary += "\n" + s.format(get_formatted_model_size(model_size), 10) - summary += "Total estimated model params size (MB)" - - return summary - - -def get_memory_profile(mode: str) -> Union[Dict[str, int], Dict[int, int]]: - """Get a profile of the current memory usage. - - Args: - mode: There are two modes: - - - 'all' means return memory for all gpus - - 'min_max' means return memory for max and min - - Return: - A dictionary in which the keys are device ids as integers and - values are memory usage as integers in MB. - If mode is 'min_max', the dictionary will also contain two additional keys: - - - 'min_gpu_mem': the minimum memory usage in MB - - 'max_gpu_mem': the maximum memory usage in MB - """ - memory_map = get_gpu_memory_map() - - if mode == "min_max": - min_index, min_memory = min(memory_map.items(), key=lambda item: item[1]) - max_index, max_memory = max(memory_map.items(), key=lambda item: item[1]) - - memory_map = {"min_gpu_mem": min_memory, "max_gpu_mem": max_memory} - - return memory_map - - -def get_gpu_memory_map() -> Dict[str, int]: - """ - Get the current gpu usage. - - Return: - A dictionary in which the keys are device ids as integers and - values are memory usage as integers in MB. 
- """ - result = subprocess.run( - [shutil.which("nvidia-smi"), "--query-gpu=memory.used", "--format=csv,nounits,noheader"], - encoding="utf-8", - # capture_output=True, # valid for python version >=3.7 - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, # for backward compatibility with python version 3.6 - check=True, - ) - - # Convert lines into a dictionary - gpu_memory = [float(x) for x in result.stdout.strip().split(os.linesep)] - gpu_memory_map = {f"gpu_id: {gpu_id}/memory.used (MB)": memory for gpu_id, memory in enumerate(gpu_memory)} - return gpu_memory_map - - -def get_formatted_model_size(total_model_size: float) -> float: - return f"{total_model_size:,.3f}" - - -def get_human_readable_count(number: int) -> str: - """ - Abbreviates an integer number with K, M, B, T for thousands, millions, - billions and trillions, respectively. - - Examples: - >>> get_human_readable_count(123) - '123 ' - >>> get_human_readable_count(1234) # (one thousand) - '1.2 K' - >>> get_human_readable_count(2e6) # (two million) - '2.0 M' - >>> get_human_readable_count(3e9) # (three billion) - '3.0 B' - >>> get_human_readable_count(4e14) # (four hundred trillion) - '400 T' - >>> get_human_readable_count(5e15) # (more than trillion) - '5,000 T' - - Args: - number: a positive integer number - - Return: - A string formatted according to the pattern described above. - - """ - assert number >= 0 - labels = PARAMETER_NUM_UNITS - num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1) - num_groups = int(np.ceil(num_digits / 3)) - num_groups = min(num_groups, len(labels)) # don't abbreviate beyond trillions - shift = -3 * (num_groups - 1) - number = number * (10 ** shift) - index = num_groups - 1 - if index < 1 or number >= 100: - return f"{int(number):,d} {labels[index]}" - - return f"{number:,.1f} {labels[index]}" - - -def _is_lazy_weight_tensor(p: Tensor) -> bool: - if _TORCH_GREATER_EQUAL_1_8: - from torch.nn.parameter import UninitializedParameter - - if isinstance(p, UninitializedParameter): - warning_cache.warn( - "A layer with UninitializedParameter was found. " - "Thus, the total number of parameters detected may be inaccurate." 
- ) - return True - return False +# To support backward compatibility as LayerSummary and ModelSummary have been moved +from pytorch_lightning.utilities.model_summary import LayerSummary, ModelSummary # noqa: E402, F401 # isort: skip diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index c413f658bd724..50317a36744a6 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -17,11 +17,10 @@ import torch import pytorch_lightning as pl -from pytorch_lightning.core import memory from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger from pytorch_lightning.trainer.connectors.logger_connector.result import _METRIC, MetricSource from pytorch_lightning.trainer.states import RunningStage, TrainerFn -from pytorch_lightning.utilities import DeviceType +from pytorch_lightning.utilities import DeviceType, memory from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.metrics import metrics_to_scalars from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e3a52a09d1bc8..31b933c595c03 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -27,7 +27,6 @@ from pytorch_lightning.accelerators import Accelerator, IPUAccelerator from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.datamodule import LightningDataModule -from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loops import TrainingBatchLoop, TrainingEpochLoop from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop @@ -81,6 +80,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_enabled from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.model_summary import ModelSummary, summarize from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS @@ -1033,7 +1033,7 @@ def _pre_training_routine(self): # print model summary if self.is_global_zero and self.weights_summary is not None and not self.testing: max_depth = ModelSummary.MODES[self.weights_summary] - ref_model.summarize(max_depth=max_depth) + summarize(ref_model, max_depth=max_depth) # on pretrain routine end self.on_pretrain_routine_end() diff --git a/pytorch_lightning/utilities/memory.py b/pytorch_lightning/utilities/memory.py index e6edb6c925ae2..ff8cac116427f 100644 --- a/pytorch_lightning/utilities/memory.py +++ b/pytorch_lightning/utilities/memory.py @@ -14,7 +14,10 @@ import gc import os +import shutil +import subprocess import uuid +from typing import Dict, Union import torch from torch.nn import Module @@ -92,6 +95,57 @@ def garbage_collection_cuda(): raise +def get_memory_profile(mode: str) -> Union[Dict[str, int], Dict[int, int]]: + """Get a profile of the current memory usage. 
+ + Args: + mode: There are two modes: + + - 'all' means return memory for all gpus + - 'min_max' means return memory for max and min + + Return: + A dictionary in which the keys are device ids as integers and + values are memory usage as integers in MB. + If mode is 'min_max', the dictionary will also contain two additional keys: + + - 'min_gpu_mem': the minimum memory usage in MB + - 'max_gpu_mem': the maximum memory usage in MB + """ + memory_map = get_gpu_memory_map() + + if mode == "min_max": + min_index, min_memory = min(memory_map.items(), key=lambda item: item[1]) + max_index, max_memory = max(memory_map.items(), key=lambda item: item[1]) + + memory_map = {"min_gpu_mem": min_memory, "max_gpu_mem": max_memory} + + return memory_map + + +def get_gpu_memory_map() -> Dict[str, int]: + """ + Get the current gpu usage. + + Return: + A dictionary in which the keys are device ids as integers and + values are memory usage as integers in MB. + """ + result = subprocess.run( + [shutil.which("nvidia-smi"), "--query-gpu=memory.used", "--format=csv,nounits,noheader"], + encoding="utf-8", + # capture_output=True, # valid for python version >=3.7 + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, # for backward compatibility with python version 3.6 + check=True, + ) + + # Convert lines into a dictionary + gpu_memory = [float(x) for x in result.stdout.strip().split(os.linesep)] + gpu_memory_map = {f"gpu_id: {gpu_id}/memory.used (MB)": memory for gpu_id, memory in enumerate(gpu_memory)} + return gpu_memory_map + + def get_model_size_mb(model: Module) -> float: """ Calculates the size of a Module in megabytes by saving the model to a temporary file and reading its size. diff --git a/pytorch_lightning/utilities/model_summary.py b/pytorch_lightning/utilities/model_summary.py new file mode 100644 index 0000000000000..4834da54220ef --- /dev/null +++ b/pytorch_lightning/utilities/model_summary.py @@ -0,0 +1,471 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections import OrderedDict +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from torch import Tensor +from torch.utils.hooks import RemovableHandle + +import pytorch_lightning as pl +from pytorch_lightning.utilities import AMPType, DeviceType, rank_zero_deprecation +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8 +from pytorch_lightning.utilities.warnings import WarningCache + +log = logging.getLogger(__name__) +warning_cache = WarningCache() + +PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"] +UNKNOWN_SIZE = "?" + + +class LayerSummary: + """ + Summary class for a single layer in a :class:`~pytorch_lightning.core.lightning.LightningModule`. + It collects the following information: + + - Type of the layer (e.g. Linear, BatchNorm1d, ...) 
+ - Input shape + - Output shape + - Number of parameters + + The input and output shapes are only known after the example input array was + passed through the model. + + Example:: + + >>> model = torch.nn.Conv2d(3, 8, 3) + >>> summary = LayerSummary(model) + >>> summary.num_parameters + 224 + >>> summary.layer_type + 'Conv2d' + >>> output = model(torch.rand(1, 3, 5, 5)) + >>> summary.in_size + [1, 3, 5, 5] + >>> summary.out_size + [1, 8, 3, 3] + + Args: + module: A module to summarize + + """ + + def __init__(self, module: nn.Module): + super().__init__() + self._module = module + self._hook_handle = self._register_hook() + self._in_size = None + self._out_size = None + + def __del__(self): + self.detach_hook() + + def _register_hook(self) -> Optional[RemovableHandle]: + """ + Registers a hook on the module that computes the input- and output size(s) on the first forward pass. + If the hook is called, it will remove itself from the from the module, meaning that + recursive models will only record their input- and output shapes once. + Registering hooks on :class:`~torch.jit.ScriptModule` is not supported. + + Return: + A handle for the installed hook, or ``None`` if registering the hook is not possible. + """ + + def hook(module, inp, out): + if len(inp) == 1: + inp = inp[0] + self._in_size = parse_batch_shape(inp) + self._out_size = parse_batch_shape(out) + self._hook_handle.remove() + + handle = None + if not isinstance(self._module, torch.jit.ScriptModule): + handle = self._module.register_forward_hook(hook) + return handle + + def detach_hook(self): + """ + Removes the forward hook if it was not already removed in the forward pass. + Will be called after the summary is created. + """ + if self._hook_handle is not None: + self._hook_handle.remove() + + @property + def in_size(self) -> Union[str, List]: + return self._in_size or UNKNOWN_SIZE + + @property + def out_size(self) -> Union[str, List]: + return self._out_size or UNKNOWN_SIZE + + @property + def layer_type(self) -> str: + """Returns the class name of the module.""" + return str(self._module.__class__.__name__) + + @property + def num_parameters(self) -> int: + """Returns the number of parameters in this module.""" + return sum(np.prod(p.shape) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters()) + + +class ModelSummary: + """ + Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`. + + Args: + model: The model to summarize (also referred to as the root module). + mode: Can be one of + + - `top` (default): only the top-level modules will be recorded (the children of the root module) + - `full`: summarizes all layers and their submodules in the root module + + .. deprecated:: v1.4 + This parameter was deprecated in v1.4 in favor of `max_depth` and will be removed in v1.6. + + max_depth: Maximum depth of modules to show. Use -1 to show all modules or 0 to show no + summary. Defaults to 1. + + The string representation of this summary prints a table with columns containing + the name, type and number of parameters for each layer. + + The root module may also have an attribute ``example_input_array`` as shown in the example below. + If present, the root module will be called with it as input to determine the + intermediate input- and output shapes of all layers. Supported are tensors and + nested lists and tuples of tensors. All other types of inputs will be skipped and show as `?` + in the summary table. 
The summary will also display `?` for layers not used in the forward pass. + + Example:: + + >>> import pytorch_lightning as pl + >>> class LitModel(pl.LightningModule): + ... + ... def __init__(self): + ... super().__init__() + ... self.net = nn.Sequential(nn.Linear(256, 512), nn.BatchNorm1d(512)) + ... self.example_input_array = torch.zeros(10, 256) # optional + ... + ... def forward(self, x): + ... return self.net(x) + ... + >>> model = LitModel() + >>> ModelSummary(model, max_depth=1) # doctest: +NORMALIZE_WHITESPACE + | Name | Type | Params | In sizes | Out sizes + ------------------------------------------------------------ + 0 | net | Sequential | 132 K | [10, 256] | [10, 512] + ------------------------------------------------------------ + 132 K Trainable params + 0 Non-trainable params + 132 K Total params + 0.530 Total estimated model params size (MB) + >>> ModelSummary(model, max_depth=-1) # doctest: +NORMALIZE_WHITESPACE + | Name | Type | Params | In sizes | Out sizes + -------------------------------------------------------------- + 0 | net | Sequential | 132 K | [10, 256] | [10, 512] + 1 | net.0 | Linear | 131 K | [10, 256] | [10, 512] + 2 | net.1 | BatchNorm1d | 1.0 K | [10, 512] | [10, 512] + -------------------------------------------------------------- + 132 K Trainable params + 0 Non-trainable params + 132 K Total params + 0.530 Total estimated model params size (MB) + """ + + MODES = dict(top=1, full=-1) # TODO: remove in v1.6 + + def __init__(self, model, mode: Optional[str] = None, max_depth: Optional[int] = 1): + self._model = model + + # temporary mapping from mode to max_depth + if max_depth is None or mode is not None: + if mode in ModelSummary.MODES: + max_depth = ModelSummary.MODES[mode] + rank_zero_deprecation( + f"Argument `mode` in `ModelSummary` is deprecated in v1.4" + f" and will be removed in v1.6. Use `max_depth={max_depth}` to replicate `mode={mode}` behaviour." + ) + else: + raise MisconfigurationException(f"`mode` can be {', '.join(ModelSummary.MODES)}, got {mode}.") + + if not isinstance(max_depth, int) or max_depth < -1: + raise ValueError(f"`max_depth` can be -1, 0 or > 0, got {max_depth}.") + + self._max_depth = max_depth + self._layer_summary = self.summarize() + # 1 byte -> 8 bits + # TODO: how do we compute precisin_megabytes in case of mixed precision? 
+ precision = self._model.precision if isinstance(self._model.precision, int) else 32 + self._precision_megabytes = (precision / 8.0) * 1e-6 + + @property + def named_modules(self) -> List[Tuple[str, nn.Module]]: + if self._max_depth == 0: + mods = [] + elif self._max_depth == 1: + # the children are the top-level modules + mods = self._model.named_children() + else: + mods = self._model.named_modules() + mods = list(mods)[1:] # do not include root module (LightningModule) + return list(mods) + + @property + def layer_names(self) -> List[str]: + return list(self._layer_summary.keys()) + + @property + def layer_types(self) -> List[str]: + return [layer.layer_type for layer in self._layer_summary.values()] + + @property + def in_sizes(self) -> List: + return [layer.in_size for layer in self._layer_summary.values()] + + @property + def out_sizes(self) -> List: + return [layer.out_size for layer in self._layer_summary.values()] + + @property + def param_nums(self) -> List[int]: + return [layer.num_parameters for layer in self._layer_summary.values()] + + @property + def total_parameters(self) -> int: + return sum(p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters()) + + @property + def trainable_parameters(self) -> int: + return sum( + p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters() if p.requires_grad + ) + + @property + def model_size(self) -> float: + # todo: seems it does not work with quantized models - it returns 0.0 + return self.total_parameters * self._precision_megabytes + + def summarize(self) -> Dict[str, LayerSummary]: + summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules) + if self._model.example_input_array is not None: + self._forward_example_input() + for layer in summary.values(): + layer.detach_hook() + + if self._max_depth >= 1: + # remove summary entries with depth > max_depth + for k in [k for k in summary if k.count(".") >= self._max_depth]: + del summary[k] + + return summary + + def _forward_example_input(self) -> None: + """Run the example input through each layer to get input- and output sizes.""" + model = self._model + trainer = self._model.trainer + + input_ = model.example_input_array + input_ = model._apply_batch_transfer_handler(input_) + + if trainer is not None and trainer.amp_backend == AMPType.NATIVE and trainer._device_type != DeviceType.TPU: + model.forward = torch.cuda.amp.autocast()(model.forward) + + mode = model.training + model.eval() + with torch.no_grad(): + # let the model hooks collect the input- and output shapes + if isinstance(input_, (list, tuple)): + model(*input_) + elif isinstance(input_, dict): + model(**input_) + else: + model(input_) + model.train(mode) # restore mode of module + + def __str__(self): + """ + Makes a summary listing with: + + Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size + """ + arrays = [ + [" ", list(map(str, range(len(self._layer_summary))))], + ["Name", self.layer_names], + ["Type", self.layer_types], + ["Params", list(map(get_human_readable_count, self.param_nums))], + ] + if self._model.example_input_array is not None: + arrays.append(["In sizes", self.in_sizes]) + arrays.append(["Out sizes", self.out_sizes]) + total_parameters = self.total_parameters + trainable_parameters = self.trainable_parameters + model_size = self.model_size + + return _format_summary_table(total_parameters, trainable_parameters, model_size, *arrays) + + def __repr__(self): + return str(self) + + +def 
parse_batch_shape(batch: Any) -> Union[str, List]: + if hasattr(batch, "shape"): + return list(batch.shape) + + if isinstance(batch, (list, tuple)): + shape = [parse_batch_shape(el) for el in batch] + return shape + + return UNKNOWN_SIZE + + +def _format_summary_table(total_parameters: int, trainable_parameters: int, model_size: float, *cols) -> str: + """ + Takes in a number of arrays, each specifying a column in + the summary table, and combines them all into one big + string defining the summary table that are nicely formatted. + """ + n_rows = len(cols[0][1]) + n_cols = 1 + len(cols) + + # Get formatting width of each column + col_widths = [] + for c in cols: + col_width = max(len(str(a)) for a in c[1]) if n_rows else 0 + col_width = max(col_width, len(c[0])) # minimum length is header length + col_widths.append(col_width) + + # Formatting + s = "{:<{}}" + total_width = sum(col_widths) + 3 * n_cols + header = [s.format(c[0], l) for c, l in zip(cols, col_widths)] + + # Summary = header + divider + Rest of table + summary = " | ".join(header) + "\n" + "-" * total_width + for i in range(n_rows): + line = [] + for c, l in zip(cols, col_widths): + line.append(s.format(str(c[1][i]), l)) + summary += "\n" + " | ".join(line) + summary += "\n" + "-" * total_width + + summary += "\n" + s.format(get_human_readable_count(trainable_parameters), 10) + summary += "Trainable params" + summary += "\n" + s.format(get_human_readable_count(total_parameters - trainable_parameters), 10) + summary += "Non-trainable params" + summary += "\n" + s.format(get_human_readable_count(total_parameters), 10) + summary += "Total params" + summary += "\n" + s.format(get_formatted_model_size(model_size), 10) + summary += "Total estimated model params size (MB)" + + return summary + + +def get_formatted_model_size(total_model_size: float) -> float: + return f"{total_model_size:,.3f}" + + +def get_human_readable_count(number: int) -> str: + """ + Abbreviates an integer number with K, M, B, T for thousands, millions, + billions and trillions, respectively. + + Examples: + >>> get_human_readable_count(123) + '123 ' + >>> get_human_readable_count(1234) # (one thousand) + '1.2 K' + >>> get_human_readable_count(2e6) # (two million) + '2.0 M' + >>> get_human_readable_count(3e9) # (three billion) + '3.0 B' + >>> get_human_readable_count(4e14) # (four hundred trillion) + '400 T' + >>> get_human_readable_count(5e15) # (more than trillion) + '5,000 T' + + Args: + number: a positive integer number + + Return: + A string formatted according to the pattern described above. + + """ + assert number >= 0 + labels = PARAMETER_NUM_UNITS + num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1) + num_groups = int(np.ceil(num_digits / 3)) + num_groups = min(num_groups, len(labels)) # don't abbreviate beyond trillions + shift = -3 * (num_groups - 1) + number = number * (10 ** shift) + index = num_groups - 1 + if index < 1 or number >= 100: + return f"{int(number):,d} {labels[index]}" + + return f"{number:,.1f} {labels[index]}" + + +def _is_lazy_weight_tensor(p: Tensor) -> bool: + if _TORCH_GREATER_EQUAL_1_8: + from torch.nn.parameter import UninitializedParameter + + if isinstance(p, UninitializedParameter): + warning_cache.warn( + "A layer with UninitializedParameter was found. " + "Thus, the total number of parameters detected may be inaccurate." 
+ ) + return True + return False + + +def summarize( + lightning_module: "pl.LightningModule", mode: Optional[str] = "top", max_depth: Optional[int] = None +) -> Optional[ModelSummary]: + """ + Summarize the LightningModule specified by `lightning_module`. + + Args: + lightning_module: `LightningModule` to summarize. + mode: Can be either ``'top'`` (summarize only direct submodules) or ``'full'`` (summarize all layers). + + .. deprecated:: v1.4 + This parameter was deprecated in v1.4 in favor of `max_depth` and will be removed in v1.6. + + max_depth: The maximum depth of layer nesting that the summary will include. A value of 0 turns the + layer summary off. Default: 1. + + Return: + The model summary object + """ + + # temporary mapping from mode to max_depth + if max_depth is None: + if mode in ModelSummary.MODES: + max_depth = ModelSummary.MODES[mode] + rank_zero_deprecation( + f"Argument `mode` in `LightningModule.summarize` is deprecated in v1.4" + f" and will be removed in v1.6. Use `max_depth={max_depth}` to replicate `mode={mode}` behavior." + ) + model_summary = ModelSummary(lightning_module, max_depth=max_depth) + elif mode is not None: + raise MisconfigurationException(f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}") + else: + model_summary = ModelSummary(lightning_module, max_depth=max_depth) + log.info("\n" + str(model_summary)) + return model_summary diff --git a/tests/accelerators/test_ddp_spawn.py b/tests/accelerators/test_ddp_spawn.py index 57c8de2ff14b3..a21078cf55542 100644 --- a/tests/accelerators/test_ddp_spawn.py +++ b/tests/accelerators/test_ddp_spawn.py @@ -14,8 +14,8 @@ import tests.helpers.pipelines as tpipes import tests.helpers.utils as tutils from pytorch_lightning.callbacks import EarlyStopping -from pytorch_lightning.core import memory from pytorch_lightning.trainer import Trainer +from pytorch_lightning.utilities import memory from tests.helpers import BoringModel from tests.helpers.datamodules import ClassifDataModule from tests.helpers.runif import RunIf diff --git a/tests/accelerators/test_dp.py b/tests/accelerators/test_dp.py index 51e396741e42f..efaf761cb7116 100644 --- a/tests/accelerators/test_dp.py +++ b/tests/accelerators/test_dp.py @@ -21,7 +21,7 @@ import tests.helpers.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.callbacks import EarlyStopping -from pytorch_lightning.core import memory +from pytorch_lightning.utilities import memory from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel, RandomDataset from tests.helpers.datamodules import ClassifDataModule diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index a363c29456fdc..6c34f58c65df7 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -17,10 +17,10 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.callbacks.early_stopping import EarlyStopping -from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.model_summary import ModelSummary from tests.deprecated_api import _soft_unimport_module from tests.helpers import BoringDataModule, BoringModel diff --git 
a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index a4d92e8a8d5ed..38b1e50150063 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -12,11 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Test deprecated functionality which will be removed in v1.7.0 """ + import pytest +from tests.deprecated_api import _soft_unimport_module from tests.helpers import BoringModel +def test_v1_7_0_deprecated_lightning_module_summarize(tmpdir): + from pytorch_lightning.core.lightning import warning_cache + + model = BoringModel() + model.summarize(max_depth=1) + assert any("The `LightningModule.summarize` method is deprecated in v1.5" in w for w in warning_cache) + warning_cache.clear() + + +def test_v1_7_0_moved_model_summary_and_layer_summary(tmpdir): + _soft_unimport_module("pytorch_lightning.core.memory") + with pytest.deprecated_call(match="to `pytorch_lightning.utilities.model_summary` since v1.5"): + from pytorch_lightning.core.memory import LayerSummary, ModelSummary # noqa: F401 + + +def test_v1_7_0_moved_get_memory_profile_and_get_gpu_memory_map(tmpdir): + _soft_unimport_module("pytorch_lightning.core.memory") + with pytest.deprecated_call(match="to `pytorch_lightning.utilities.memory` since v1.5"): + from pytorch_lightning.core.memory import get_gpu_memory_map, get_memory_profile # noqa: F401 + + def test_v1_7_0_deprecated_model_size(): model = BoringModel() with pytest.deprecated_call( diff --git a/tests/core/test_memory.py b/tests/utilities/test_model_summary.py similarity index 92% rename from tests/core/test_memory.py rename to tests/utilities/test_model_summary.py index e55a83563e7c9..0d993bee18ff2 100644 --- a/tests/core/test_memory.py +++ b/tests/utilities/test_model_summary.py @@ -16,9 +16,9 @@ import torch.nn as nn from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.core.memory import ModelSummary, UNKNOWN_SIZE from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_9 from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.model_summary import ModelSummary, summarize, UNKNOWN_SIZE from tests.helpers import BoringModel from tests.helpers.advanced_models import ParityModuleRNN from tests.helpers.runif import RunIf @@ -140,7 +140,7 @@ def forward(self, inp): def test_invalid_weights_summmary(): """Test that invalid value for weights_summary raises an error.""" with pytest.raises(MisconfigurationException, match="`mode` can be None, .* got temp"): - UnorderedModel().summarize(mode="temp") + summarize(UnorderedModel, mode="temp") with pytest.raises(MisconfigurationException, match="`weights_summary` can be None, .* got temp"): Trainer(weights_summary="temp") @@ -150,7 +150,7 @@ def test_invalid_weights_summmary(): def test_empty_model_summary_shapes(mode: str): """Test that the summary works for models that have no submodules.""" model = EmptyModule() - summary = model.summarize(mode=mode) + summary = summarize(model, mode=mode) assert summary.in_sizes == [] assert summary.out_sizes == [] assert summary.param_nums == [] @@ -166,7 +166,7 @@ def test_linear_model_summary_shapes(device, mode): """Test that the model summary correctly computes the input- and output shapes.""" model = UnorderedModel().to(device) model.train() - summary = model.summarize(mode=mode) + summary = summarize(model, mode=mode) assert summary.in_sizes == [[2, 10], [2, 7], [2, 3], [2, 
7], UNKNOWN_SIZE] # layer 2 # combine # layer 1 # relu assert summary.out_sizes == [[2, 2], [2, 9], [2, 5], [2, 7], UNKNOWN_SIZE] # layer 2 # combine # layer 1 # relu assert model.training @@ -176,7 +176,7 @@ def test_linear_model_summary_shapes(device, mode): def test_mixed_dtype_model_summary(): """Test that the model summary works with models that have mixed input- and parameter dtypes.""" model = MixedDtypeModel() - summary = model.summarize() + summary = summarize(model) assert summary.in_sizes == [[2, 3], [2, 3, 20]] # embed # reduce assert summary.out_sizes == [[2, 3, 20], [2, 3, 1]] # embed # reduce @@ -205,7 +205,7 @@ def test_rnn_summary_shapes(mode): model.example_input_array = torch.zeros(b, t, 10) - summary = model.summarize(mode=mode) + summary = summarize(model, mode=mode) assert summary.in_sizes == [[b, t, i], [b, t, h]] # rnn # linear assert summary.out_sizes == [[[b, t, h], [[1, b, h], [1, b, h]]], [b, t, o]] # rnn # linear @@ -214,7 +214,7 @@ def test_rnn_summary_shapes(mode): def test_summary_parameter_count(mode): """Test that the summary counts the number of parameters in every submodule.""" model = UnorderedModel() - summary = model.summarize(mode=mode) + summary = summarize(model, mode=mode) assert summary.param_nums == [ model.layer2.weight.numel() + model.layer2.bias.numel(), model.combine.weight.numel() + model.combine.bias.numel(), @@ -228,14 +228,14 @@ def test_summary_parameter_count(mode): def test_summary_layer_types(mode): """Test that the summary displays the layer names correctly.""" model = UnorderedModel() - summary = model.summarize(mode=mode) + summary = summarize(model, mode=mode) assert summary.layer_types == ["Linear", "Linear", "Linear", "ReLU", "Conv2d"] @pytest.mark.parametrize("mode", ["full", "top"]) def test_summary_with_scripted_modules(mode): model = PartialScriptModel() - summary = model.summarize(mode=mode) + summary = summarize(model, mode=mode) assert summary.layer_types == ["RecursiveScriptModule", "Linear"] assert summary.in_sizes == [UNKNOWN_SIZE, [2, 3]] assert summary.out_sizes == [UNKNOWN_SIZE, [2, 2]] @@ -272,7 +272,7 @@ def forward(self, *args, **kwargs): model = DummyLightningModule() model.example_input_array = example_input - summary = model.summarize(mode=mode) + summary = summarize(model, mode=mode) assert summary.in_sizes == [expected_size] @@ -280,7 +280,7 @@ def forward(self, *args, **kwargs): def test_model_size(mode): """Test model size is calculated correctly.""" model = PreCalculatedModel() - summary = model.summarize(mode=mode) + summary = summarize(model, mode=mode) assert model.pre_calculated_model_size == summary.model_size @@ -288,7 +288,7 @@ def test_model_size(mode): def test_empty_model_size(mode): """Test empty model size is zero.""" model = EmptyModule() - summary = model.summarize(mode=mode) + summary = summarize(model, mode=mode) assert 0.0 == summary.model_size @@ -300,7 +300,7 @@ def test_model_size_precision(tmpdir): # fit model trainer = Trainer(default_root_dir=tmpdir, gpus=1, max_steps=1, max_epochs=1, precision=32) trainer.fit(model) - summary = model.summarize() + summary = summarize(model) assert model.pre_calculated_model_size == summary.model_size @@ -326,15 +326,15 @@ def test_lazy_model_summary(): def test_max_depth_equals_mode_interface(): - """Test model.summarize(full/top) interface mapping matches max_depth""" + """Test summarize(model, full/top) interface mapping matches max_depth""" model = DeepNestedModel() - summary_top = model.summarize(mode="top") - summary_0 = 
model.summarize(max_depth=1) + summary_top = summarize(model, mode="top") + summary_0 = summarize(model, max_depth=1) assert str(summary_top) == str(summary_0) - summary_full = model.summarize(mode="full") - summary_minus1 = model.summarize(max_depth=-1) + summary_full = summarize(model, mode="full") + summary_minus1 = summarize(model, max_depth=-1) assert str(summary_full) == str(summary_minus1) @@ -351,4 +351,4 @@ def test_max_depth_param(max_depth): @pytest.mark.parametrize("max_depth", [-99, -2, "invalid"]) def test_raise_invalid_max_depth_value(max_depth): with pytest.raises(ValueError, match=f"`max_depth` can be -1, 0 or > 0, got {max_depth}"): - DeepNestedModel().summarize(max_depth=max_depth) + summarize(DeepNestedModel(), max_depth=max_depth)
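
For reference, a minimal usage sketch of the relocated API introduced by this diff (an illustration, not part of the patch). It assumes pytorch_lightning ~1.5 with this change applied; `DemoModel` is a hypothetical stand-in mirroring the `LitModel` doctest above::

    import torch
    import torch.nn as nn
    import pytorch_lightning as pl
    # new location of the summary utilities
    from pytorch_lightning.utilities.model_summary import ModelSummary, summarize


    class DemoModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(256, 512), nn.BatchNorm1d(512))
            self.example_input_array = torch.zeros(10, 256)  # optional: enables the in/out size columns

        def forward(self, x):
            return self.net(x)


    model = DemoModel()

    # preferred call going forward: the free function instead of the deprecated method
    summary = summarize(model, max_depth=1)  # logs the table and returns a ModelSummary
    print(summary.layer_names, summary.total_parameters, summary.model_size)

    # constructing ModelSummary directly is equivalent; max_depth=-1 expands all submodules
    full_summary = ModelSummary(model, max_depth=-1)
    print(full_summary)

    # both of these still work until v1.7, but now emit deprecation warnings:
    # model.summarize(max_depth=1)
    # from pytorch_lightning.core.memory import ModelSummary, get_gpu_memory_map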