Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add model summary when using DeepSpeed Stage 3 #13427

Merged
merged 14 commits into from
Jun 29, 2022
Merged

Conversation

SeanNaren
Copy link
Contributor

@SeanNaren SeanNaren commented Jun 28, 2022

What does this PR do?

Fixes #12130

Introduces a DeepSpeedModelSummary that includes logic to take out the actual size of tensors + show you the size of partitions made by DeepSpeed. Previously the weights were "0" due to DeepSpeed changing the parameters in place. Now you get something like this:

  | Name  | Type                       | Params | Params per Device
-------------------------------------------------------------------------
0 | ptlm  | T5ForConditionalGeneration | 737 M  | 184 M
1 | layer | Linear                     | 66     | 17
-------------------------------------------------------------------------
737 M    Trainable params
0        Non-trainable params
737 M    Total params

Does your PR introduce any breaking changes? If yes, please list them.

None

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

cc @Borda @awaelchli @ananthsub @rohitgr7 @SeanNaren @akihironitta

Copy link
Contributor

@carmocca carmocca left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice addition!

@carmocca carmocca added feature Is an improvement or enhancement callback and removed bug Something isn't working labels Jun 28, 2022
@SeanNaren SeanNaren modified the milestones: pl:1.6.x, pl:1.7 Jun 28, 2022
@mergify mergify bot added the ready PRs ready to be merged label Jun 28, 2022
Copy link
Contributor

@rohitgr7 rohitgr7 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also fixes: #10291?

src/pytorch_lightning/callbacks/deepspeed_model_summary.py Outdated Show resolved Hide resolved
@SeanNaren
Copy link
Contributor Author

Also fixes: #10291?

Doesn't handle FSDP right now, we can add that after!

@SeanNaren SeanNaren enabled auto-merge (squash) June 29, 2022 14:08
@SeanNaren SeanNaren merged commit f145acd into master Jun 29, 2022
@SeanNaren SeanNaren deleted the feat/deepspeed_summary branch June 29, 2022 14:49
Copy link
Contributor

@awaelchli awaelchli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I had the review done yesterday but didn't submit it :(

@@ -280,6 +280,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `estimated_stepping_batches` requiring distributed comms in `configure_optimizers` for the `DeepSpeedStrategy` ([#13350](https://github.com/PyTorchLightning/pytorch-lightning/pull/13350))


- Fixed Model Summary when using DeepSpeed Stage 3 ([#13427](https://github.com/PyTorchLightning/pytorch-lightning/pull/13427))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest either

Suggested change
- Fixed Model Summary when using DeepSpeed Stage 3 ([#13427](https://github.com/PyTorchLightning/pytorch-lightning/pull/13427))
- Fixed model summary when using DeepSpeed Stage 3 ([#13427](https://github.com/PyTorchLightning/pytorch-lightning/pull/13427))

or

Suggested change
- Fixed Model Summary when using DeepSpeed Stage 3 ([#13427](https://github.com/PyTorchLightning/pytorch-lightning/pull/13427))
- Fixed ModelSummary callback when using DeepSpeed Stage 3 ([#13427](https://github.com/PyTorchLightning/pytorch-lightning/pull/13427))

@@ -0,0 +1,94 @@
#!/usr/bin/env python
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for organisation, it would be nice if it was called model_summary_deepspeed or if it was grouped under a folder model_summary



@RunIf(min_cuda_gpus=2, deepspeed=True, standalone=True)
def test_deepspeed_summary(tmpdir):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is very expensive for just testing a model summary.
Is there no way we can just test the summary on a model directly without full training and launching processes?

jerome-habana pushed a commit to jerome-habana/lightning that referenced this pull request Jul 14, 2022
@awaelchli awaelchli mentioned this pull request Jul 27, 2022
11 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
callback feature Is an improvement or enhancement ready PRs ready to be merged strategy: deepspeed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Wrong number of trainable parameters printed with strategy="deepspeed_stage_3"
6 participants