Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding non-layer param count to summary #17005

Merged
merged 32 commits into from
May 9, 2023

Conversation

rhiga2
Copy link
Contributor

@rhiga2 rhiga2 commented Mar 9, 2023

What does this PR do?

This PR adds the count of parameters not associated to any layer to the model summary. The new model summary should add an extra row to the summary table that specifies the count of params that have not been reported by any layer summary.

Fixes #12736

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

@github-actions github-actions bot added the pl Generic label for PyTorch Lightning package label Mar 9, 2023
@rhiga2 rhiga2 marked this pull request as ready for review March 9, 2023 20:31
@awaelchli awaelchli added feature Is an improvement or enhancement community This PR is from the community labels Mar 10, 2023
@awaelchli awaelchli added this to the 2.1 milestone Mar 13, 2023
@awaelchli awaelchli self-assigned this Mar 17, 2023
@mergify mergify bot added the ready PRs ready to be merged label May 5, 2023
@awaelchli
Copy link
Contributor

@rhiga2 Could you take a look at your implementation again for this case:

import torch
from torch.utils.data import DataLoader, Dataset
from lightning.pytorch import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.param = torch.nn.Parameter(torch.rand(3))  # this parameter is unused

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        return self(batch).sum()

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    model = BoringModel()
    trainer = Trainer(max_steps=1)
    trainer.fit(model)


if __name__ == "__main__":
    run()

which errors out with

  File "/Users/adrian/repositories/lightning/src/lightning/pytorch/utilities/model_summary/model_summary.py", line 317, in _add_leftover_params_to_summary
    layer_summaries["In sizes"].append(NOT_APPLICABLE)
KeyError: 'In sizes'

Copy link
Contributor

@awaelchli awaelchli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall approach LGTM, but we are missing the regular use case when no example input array is provided. Let's extend the test case for that and handle the optional columns correctly.

@mergify mergify bot removed the ready PRs ready to be merged label May 6, 2023
@rhiga2 rhiga2 requested a review from awaelchli May 6, 2023 01:26
@rhiga2 rhiga2 requested a review from awaelchli May 9, 2023 02:21
@mergify mergify bot added the ready PRs ready to be merged label May 9, 2023
@awaelchli awaelchli enabled auto-merge (squash) May 9, 2023 07:52
@awaelchli awaelchli merged commit 300abb3 into Lightning-AI:master May 9, 2023
@rhiga2 rhiga2 deleted the feature/non_layer_model_summary branch May 25, 2023 00:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
callback: model summary community This PR is from the community feature Is an improvement or enhancement pl Generic label for PyTorch Lightning package ready PRs ready to be merged
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ModelSummary ignores nn.Parameter
4 participants