
Count number of modules in train/eval mode in ModelSummary #20159

Merged
12 commits merged into master on Aug 4, 2024

Conversation

awaelchli
Contributor

@awaelchli awaelchli commented Aug 4, 2024

What does this PR do?

Fixes #19820
Fixes #20128

This PR adds two rows to the model summary that count how many modules are in train mode and how many are in eval mode. The issues linked above raised the concern that it is not visible enough when modules are (accidentally) left in eval mode. Printing an explicit count for each mode should raise awareness:

  1. For pretraining, you want the summary to show 0 modules in eval mode and the rest in train mode.
  2. For finetuning, you expect either 0 eval modules or a mix of train and eval modules (some modules frozen).

To see which modules are in train/eval mode, the user can look at the Mode column of the summary table, or expand the table to show all submodules using ModelSummary(max_depth=-1).

Example:

import torch
from transformers import BertModel

from lightning.pytorch import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2).train()
        self.bert = BertModel.from_pretrained("bert-base-cased", output_attentions=True)
        # Note: from_pretrained() returns the model in eval mode by default
        # self.bert.train()

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


model = BoringModel()
trainer = Trainer(max_epochs=1)
trainer.fit(model, train_dataloaders=DataLoader(RandomDataset(32, 64)))

Output:

  | Name  | Type      | Params | Mode 
--------------------------------------------
0 | layer | Linear    | 66     | train
1 | bert  | BertModel | 108 M  | eval 
--------------------------------------------
108 M     Trainable params
0         Non-trainable params
108 M     Total params
433.241   Total estimated model params size (MB)
1         Modules in train mode                               <--- New
228       Modules in eval mode                                <--- New

New are the two rows at the bottom that count the modules in train/eval mode.
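The counts in the two new rows can be reproduced with plain PyTorch, since every submodule exposes a public `training` flag. The following is a hedged sketch of that counting logic (`count_module_modes` is a hypothetical helper for illustration, not Lightning's actual implementation):

```python
import torch
from torch import nn


def count_module_modes(model: nn.Module) -> tuple[int, int]:
    """Return (num_train, num_eval) over all submodules, including the root."""
    n_train = sum(1 for m in model.modules() if m.training)
    n_eval = sum(1 for m in model.modules() if not m.training)
    return n_train, n_eval


model = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 2))
model[0].eval()  # freeze-style setup: put one submodule in eval mode
print(count_module_modes(model))  # (3, 1): root + ReLU + Linear in train, one Linear in eval
```

Note that `modules()` includes the root module itself and every nested submodule, which is why large pretrained models like BertModel contribute hundreds of entries to the eval count.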

Alternatives

#19820 initially proposed printing a warning. However, a warning would produce false positives during finetuning, so I opted for showing the information in the model summary instead (which is on by default in the Trainer).


📚 Documentation preview 📚: https://pytorch-lightning--20159.org.readthedocs.build/en/20159/

cc @Borda @awaelchli

@github-actions github-actions bot added the "docs" and "pl" labels Aug 4, 2024
@awaelchli awaelchli marked this pull request as ready for review August 4, 2024 09:40
@awaelchli awaelchli added the "feature" and "callback: model summary" labels Aug 4, 2024
@awaelchli awaelchli added this to the 2.4 milestone Aug 4, 2024
@awaelchli awaelchli added the "fun" label Aug 4, 2024

github-actions bot commented Aug 4, 2024

⚡ Required checks status: All passing 🟢

Groups summary

🟢 pytorch_lightning: Tests workflow
Check ID Status
pl-cpu (macOS-13, lightning, 3.9, 2.1, oldest) success
pl-cpu (macOS-14, lightning, 3.10, 2.1) success
pl-cpu (macOS-14, lightning, 3.11, 2.2) success
pl-cpu (macOS-14, lightning, 3.11, 2.3) success
pl-cpu (macOS-14, lightning, 3.12, 2.4) success
pl-cpu (ubuntu-20.04, lightning, 3.9, 2.1, oldest) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.1) success
pl-cpu (ubuntu-20.04, lightning, 3.11, 2.2) success
pl-cpu (ubuntu-20.04, lightning, 3.11, 2.3) success
pl-cpu (ubuntu-20.04, lightning, 3.12, 2.4) success
pl-cpu (windows-2022, lightning, 3.9, 2.1, oldest) success
pl-cpu (windows-2022, lightning, 3.10, 2.1) success
pl-cpu (windows-2022, lightning, 3.11, 2.2) success
pl-cpu (windows-2022, lightning, 3.11, 2.3) success
pl-cpu (windows-2022, lightning, 3.12, 2.4) success
pl-cpu (macOS-14, pytorch, 3.9, 2.1) success
pl-cpu (ubuntu-20.04, pytorch, 3.9, 2.1) success
pl-cpu (windows-2022, pytorch, 3.9, 2.1) success
pl-cpu (macOS-12, pytorch, 3.10, 2.1) success
pl-cpu (ubuntu-22.04, pytorch, 3.10, 2.1) success
pl-cpu (windows-2022, pytorch, 3.10, 2.1) success

These checks are required after the changes to src/lightning/pytorch/callbacks/model_summary.py, src/lightning/pytorch/callbacks/rich_model_summary.py, src/lightning/pytorch/utilities/model_summary/model_summary.py, tests/tests_pytorch/callbacks/test_early_stopping.py, tests/tests_pytorch/callbacks/test_model_summary.py, tests/tests_pytorch/callbacks/test_rich_model_summary.py, tests/tests_pytorch/core/test_datamodules.py, tests/tests_pytorch/utilities/test_model_summary.py.

🟢 pytorch_lightning: Azure GPU
Check ID Status
pytorch-lightning (GPUs) (testing Lightning | latest) success
pytorch-lightning (GPUs) (testing PyTorch | latest) success

These checks are required after the changes to src/lightning/pytorch/callbacks/model_summary.py, src/lightning/pytorch/callbacks/rich_model_summary.py, src/lightning/pytorch/utilities/model_summary/model_summary.py, tests/tests_pytorch/callbacks/test_early_stopping.py, tests/tests_pytorch/callbacks/test_model_summary.py, tests/tests_pytorch/callbacks/test_rich_model_summary.py, tests/tests_pytorch/core/test_datamodules.py, tests/tests_pytorch/utilities/test_model_summary.py.

🟢 pytorch_lightning: Benchmarks
Check ID Status
lightning.Benchmarks success

These checks are required after the changes to src/lightning/pytorch/callbacks/model_summary.py, src/lightning/pytorch/callbacks/rich_model_summary.py, src/lightning/pytorch/utilities/model_summary/model_summary.py.

🟢 pytorch_lightning: Docs
Check ID Status
docs-make (pytorch, doctest) success
docs-make (pytorch, html) success

These checks are required after the changes to src/lightning/pytorch/callbacks/model_summary.py, src/lightning/pytorch/callbacks/rich_model_summary.py, src/lightning/pytorch/utilities/model_summary/model_summary.py, docs/source-pytorch/advanced/transfer_learning.rst.

🟢 mypy
Check ID Status
mypy success

These checks are required after the changes to src/lightning/pytorch/callbacks/model_summary.py, src/lightning/pytorch/callbacks/rich_model_summary.py, src/lightning/pytorch/utilities/model_summary/model_summary.py.

🟢 install
Check ID Status
install-pkg (ubuntu-22.04, fabric, 3.9) success
install-pkg (ubuntu-22.04, fabric, 3.11) success
install-pkg (ubuntu-22.04, pytorch, 3.9) success
install-pkg (ubuntu-22.04, pytorch, 3.11) success
install-pkg (ubuntu-22.04, lightning, 3.9) success
install-pkg (ubuntu-22.04, lightning, 3.11) success
install-pkg (ubuntu-22.04, notset, 3.9) success
install-pkg (ubuntu-22.04, notset, 3.11) success
install-pkg (macOS-12, fabric, 3.9) success
install-pkg (macOS-12, fabric, 3.11) success
install-pkg (macOS-12, pytorch, 3.9) success
install-pkg (macOS-12, pytorch, 3.11) success
install-pkg (macOS-12, lightning, 3.9) success
install-pkg (macOS-12, lightning, 3.11) success
install-pkg (macOS-12, notset, 3.9) success
install-pkg (macOS-12, notset, 3.11) success
install-pkg (windows-2022, fabric, 3.9) success
install-pkg (windows-2022, fabric, 3.11) success
install-pkg (windows-2022, pytorch, 3.9) success
install-pkg (windows-2022, pytorch, 3.11) success
install-pkg (windows-2022, lightning, 3.9) success
install-pkg (windows-2022, lightning, 3.11) success
install-pkg (windows-2022, notset, 3.9) success
install-pkg (windows-2022, notset, 3.11) success

These checks are required after the changes to src/lightning/pytorch/callbacks/model_summary.py, src/lightning/pytorch/callbacks/rich_model_summary.py, src/lightning/pytorch/utilities/model_summary/model_summary.py.


Thank you for your contribution! 💜

Note
This comment is automatically generated and updates for 60 minutes every 180 seconds. If you have any other questions, contact carmocca for help.


codecov bot commented Aug 4, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 81%. Comparing base (2638d82) to head (67bf3dc).
Report is 37 commits behind head on master.

❗ There is a different number of reports uploaded between BASE (2638d82) and HEAD (67bf3dc). Click for more details.

HEAD has 90 uploads less than BASE
Flag BASE (2638d82) HEAD (67bf3dc)
gpu 4 2
pytest 25 2
lightning_fabric 7 0
lightning 32 16
cpu 42 21
python3.9 12 6
python3.10 12 6
python3.11 12 6
python3.12 6 3
Additional details and impacted files
@@            Coverage Diff            @@
##           master   #20159     +/-   ##
=========================================
- Coverage      89%      81%     -8%     
=========================================
  Files         267      264      -3     
  Lines       23050    23008     -42     
=========================================
- Hits        20552    18601   -1951     
- Misses       2498     4407   +1909     

@awaelchli awaelchli merged commit d4de8e2 into master Aug 4, 2024
74 of 75 checks passed
@awaelchli awaelchli deleted the feature/summarize-train-mode-count branch August 4, 2024 19:28
ammyk9 pushed a commit to ammyk9/pytorch-lightning that referenced this pull request Aug 6, 2024