Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Count number of modules in train/eval mode in ModelSummary #20159

Merged
merged 12 commits into from
Aug 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source-pytorch/advanced/transfer_learning.rst
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ Here's a model that uses `Huggingface transformers <https://github.com/huggingfa
super().__init__()

self.bert = BertModel.from_pretrained("bert-base-cased", output_attentions=True)
self.bert.train()
self.W = nn.Linear(self.bert.config.hidden_size, 3)
self.num_classes = 3

Expand Down
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- The `TQDMProgressBar` now provides an option to retain prior training epoch bars ([#19578](https://github.com/Lightning-AI/pytorch-lightning/pull/19578))

- Added the count of modules in train and eval mode to the printed `ModelSummary` table ([#20159](https://github.com/Lightning-AI/pytorch-lightning/pull/20159))

### Changed

- Triggering KeyboardInterrupt (Ctrl+C) during `.fit()`, `.evaluate()`, `.test()` or `.predict()` now terminates all processes launched by the Trainer and exits the program ([#19976](https://github.com/Lightning-AI/pytorch-lightning/pull/19976))
Expand Down
12 changes: 11 additions & 1 deletion src/lightning/pytorch/callbacks/model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,17 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
total_parameters = model_summary.total_parameters
trainable_parameters = model_summary.trainable_parameters
model_size = model_summary.model_size
total_training_modes = model_summary.total_training_modes

if trainer.is_global_zero:
self.summarize(summary_data, total_parameters, trainable_parameters, model_size, **self._summarize_kwargs)
self.summarize(
summary_data,
total_parameters,
trainable_parameters,
model_size,
total_training_modes,
**self._summarize_kwargs,
)

def _summary(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Union[DeepSpeedSummary, Summary]:
from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy
Expand All @@ -83,12 +91,14 @@ def summarize(
total_parameters: int,
trainable_parameters: int,
model_size: float,
total_training_modes: Dict[str, int],
**summarize_kwargs: Any,
) -> None:
summary_table = _format_summary_table(
total_parameters,
trainable_parameters,
model_size,
total_training_modes,
*summary_data,
)
log.info("\n" + summary_table)
5 changes: 4 additions & 1 deletion src/lightning/pytorch/callbacks/rich_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple
from typing import Any, Dict, List, Tuple

from typing_extensions import override

Expand Down Expand Up @@ -71,6 +71,7 @@ def summarize(
total_parameters: int,
trainable_parameters: int,
model_size: float,
total_training_modes: Dict[str, int],
**summarize_kwargs: Any,
) -> None:
from rich import get_console
Expand Down Expand Up @@ -110,5 +111,7 @@ def summarize(
grid.add_row(f"[bold]Non-trainable params[/]: {parameters[1]}")
grid.add_row(f"[bold]Total params[/]: {parameters[2]}")
grid.add_row(f"[bold]Total estimated model params size (MB)[/]: {parameters[3]}")
grid.add_row(f"[bold]Modules in train mode[/]: {total_training_modes['train']}")
grid.add_row(f"[bold]Modules in eval mode[/]: {total_training_modes['eval']}")

console.print(grid)
18 changes: 17 additions & 1 deletion src/lightning/pytorch/utilities/model_summary/model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ class ModelSummary:
0 Non-trainable params
132 K Total params
0.530 Total estimated model params size (MB)
3 Modules in train mode
0 Modules in eval mode
>>> ModelSummary(model, max_depth=-1) # doctest: +NORMALIZE_WHITESPACE
| Name | Type | Params | Mode | In sizes | Out sizes
----------------------------------------------------------------------
Expand All @@ -198,6 +200,8 @@ class ModelSummary:
0 Non-trainable params
132 K Total params
0.530 Total estimated model params size (MB)
3 Modules in train mode
0 Modules in eval mode

"""

Expand Down Expand Up @@ -252,6 +256,12 @@ def param_nums(self) -> List[int]:
def training_modes(self) -> List[bool]:
return [layer.training for layer in self._layer_summary.values()]

@property
def total_training_modes(self) -> Dict[str, int]:
modes = [layer.training for layer in self._model.modules()]
modes = modes[1:] # exclude the root module
return {"train": modes.count(True), "eval": modes.count(False)}

@property
def total_parameters(self) -> int:
return sum(p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters())
Expand Down Expand Up @@ -351,8 +361,9 @@ def __str__(self) -> str:
total_parameters = self.total_parameters
trainable_parameters = self.trainable_parameters
model_size = self.model_size
total_training_modes = self.total_training_modes

return _format_summary_table(total_parameters, trainable_parameters, model_size, *arrays)
return _format_summary_table(total_parameters, trainable_parameters, model_size, total_training_modes, *arrays)

def __repr__(self) -> str:
return str(self)
Expand All @@ -372,6 +383,7 @@ def _format_summary_table(
total_parameters: int,
trainable_parameters: int,
model_size: float,
total_training_modes: Dict[str, int],
*cols: Tuple[str, List[str]],
) -> str:
"""Takes in a number of arrays, each specifying a column in the summary table, and combines them all into one big
Expand Down Expand Up @@ -408,6 +420,10 @@ def _format_summary_table(
summary += "Total params"
summary += "\n" + s.format(get_formatted_model_size(model_size), 10)
summary += "Total estimated model params size (MB)"
summary += "\n" + s.format(total_training_modes["train"], 10)
summary += "Modules in train mode"
summary += "\n" + s.format(total_training_modes["eval"], 10)
summary += "Modules in eval mode"

return summary

Expand Down
2 changes: 1 addition & 1 deletion tests/tests_pytorch/callbacks/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def on_train_epoch_end(self, trainer, pl_module):
self.saved_states.append(self.state_dict().copy())


@RunIf(sklearn=True)
@RunIf(sklearn=True, skip_windows=True) # Flaky test on Windows for unknown reasons
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_resume_early_stopping_from_checkpoint(tmp_path):
"""Prevent regressions to bugs:
Expand Down
3 changes: 3 additions & 0 deletions tests/tests_pytorch/callbacks/test_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def summarize(
total_parameters: int,
trainable_parameters: int,
model_size: float,
total_training_modes,
**summarize_kwargs: Any,
) -> None:
assert summary_data[1][0] == "Name"
Expand All @@ -64,6 +65,8 @@ def summarize(
assert summary_data[4][0] == "Mode"
assert summary_data[4][1][0] == "train"

assert total_training_modes == {"train": 1, "eval": 0}

model = BoringModel()
trainer = Trainer(default_root_dir=tmp_path, callbacks=CustomModelSummary(), max_steps=1)

Expand Down
8 changes: 7 additions & 1 deletion tests/tests_pytorch/callbacks/test_rich_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@ def example_input_array(self) -> Any:
summary = summarize(model)
summary_data = summary._get_summary_data()

model_summary.summarize(summary_data=summary_data, total_parameters=1, trainable_parameters=1, model_size=1)
model_summary.summarize(
summary_data=summary_data,
total_parameters=1,
trainable_parameters=1,
model_size=1,
total_training_modes=summary.total_training_modes,
)

# ensure that summary was logged + the breakdown of model parameters
assert mock_console.call_count == 2
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_pytorch/core/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
assert dm.my_state_dict == {"my": "state_dict"}


@RunIf(sklearn=True)
@RunIf(sklearn=True, skip_windows=True) # Flaky test on Windows for unknown reasons
def test_full_loop(tmp_path):
seed_everything(7)

Expand Down
28 changes: 27 additions & 1 deletion tests/tests_pytorch/utilities/test_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,29 @@ def forward(self, x):
assert not model.layer2.training


def test_total_training_modes():
"""Test that the `total_training_modes` counts the modules in 'train' and 'eval' mode, excluding the root
module."""

class ModelWithoutChildren(LightningModule):
pass

summary = ModelSummary(ModelWithoutChildren())
assert summary.total_training_modes == {"train": 0, "eval": 0}

model = DeepNestedModel()
summary = ModelSummary(model)
assert summary.total_training_modes == {"train": 19, "eval": 0}
assert sum(summary.total_training_modes.values()) == len(list(model.modules())) - 1

model = DeepNestedModel()
summary = ModelSummary(model)
model.branch1[1][0].eval()
model.branch2.eval()
assert summary.total_training_modes == {"train": 17, "eval": 2}
assert sum(summary.total_training_modes.values()) == len(list(model.modules())) - 1


def test_summary_training_mode():
"""Test that the model summary captures the training mode on all submodules."""
model = DeepNestedModel()
Expand All @@ -436,6 +459,7 @@ def test_summary_training_mode():
"eval", # branch2
"train", # head
]
assert summary.total_training_modes == {"train": 17, "eval": 2}

summary = summarize(model, max_depth=-1)
expected_eval = {"branch1.1.0", "branch2"}
Expand All @@ -445,5 +469,7 @@ def test_summary_training_mode():
# A model with params not belonging to a layer
model = NonLayerParamsModel()
model.layer.eval()
summary_data = OrderedDict(summarize(model)._get_summary_data())
summary = summarize(model)
summary_data = OrderedDict(summary._get_summary_data())
assert summary_data["Mode"] == ["eval", "n/a"]
assert summary.total_training_modes == {"train": 0, "eval": 1}
Loading