Improved model initialization API for Fabric #17462

Merged on Apr 26, 2023 (45 commits)

Contributor

@awaelchli awaelchli commented Apr 24, 2023

What does this PR do?

Adds a new context manager in Fabric that allows you to init your model:

  • directly on the right device
  • with the right data type (dtype)
  • getting sharded instantly
  • allowing empty-weight initialization (meta device) or deferred init (torchdistx) (Adopt FakeTensorMode for FSDP #16448)

This context manager generalizes the previous sharded_model manager from the LightningLite days, so we don't need both.
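
The idea of one manager subsuming several setup steps can be sketched without any framework. The names below (`step`, `init_module_sketch`) are hypothetical and only record the order of events; how Fabric actually wires its strategy and precision contexts is not shown in this description.

```python
from contextlib import ExitStack, contextmanager

events = []


@contextmanager
def step(name):
    # Hypothetical stand-in for a real setup context (device, dtype, sharding).
    events.append(f"enter {name}")
    try:
        yield
    finally:
        events.append(f"exit {name}")


@contextmanager
def init_module_sketch():
    # Enter every setup context once; ExitStack unwinds them in reverse order.
    with ExitStack() as stack:
        for name in ("device", "dtype", "sharding"):
            stack.enter_context(step(name))
        yield


with init_module_sketch():
    events.append("build model")

print(events)
```

The caller sees a single `with` block, while each concern stays in its own context and is cleaned up even if model construction raises.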

Example 1:

Init the model directly on the GPU with weights in half precision (your model may not fit in float32).

fabric = Fabric(accelerator="cuda", precision="bf16-true")

with fabric.init_module():
    model = MyModel()

See #17287 for half support.
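
For intuition, the device and dtype part of Example 1 can be approximated in plain PyTorch. The sketch below is not Fabric's implementation: `init_on` is a hypothetical helper, and it assumes PyTorch 2.0 or newer, where `torch.device` works as a context manager. It uses float64 instead of bf16 and falls back to the CPU when no GPU is present, so it runs anywhere.

```python
from contextlib import contextmanager

import torch
from torch import nn


@contextmanager
def init_on(device, dtype):
    # Hypothetical helper: temporarily change the default dtype and device
    # that tensor factory functions use, then restore the previous dtype.
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        with torch.device(device):
            yield
    finally:
        torch.set_default_dtype(previous)


device = "cuda" if torch.cuda.is_available() else "cpu"

with init_on(device, torch.float64):
    # Parameters are created on `device` in float64, with no later .to() call.
    layer = nn.Linear(8, 4)

print(layer.weight.device.type, layer.weight.dtype)
```

Creating the parameters in their final dtype and location avoids first materializing a float32 copy on the CPU, which is the peak-memory saving the PR description refers to.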

Example 2:
Init FSDP model on the meta device (or using torchdistx).

fabric = Fabric(accelerator="cuda", strategy="fsdp")

with fabric.init_module():
    model = MyModel()  # model params are now on meta device (no memory allocated)

model = fabric.setup(model)  # params get sharded and put on device
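
What "no memory allocated" means can be seen with PyTorch's meta device alone. This is a small sketch assuming PyTorch 2.0 or newer; it does not involve FSDP or Fabric, and the layer sizes are arbitrary.

```python
import torch
from torch import nn

with torch.device("meta"):
    # Parameters carry shape and dtype only; no storage is allocated.
    model = nn.Sequential(
        nn.Linear(4096, 4096),
        nn.ReLU(),
        nn.Linear(4096, 4096),
    )

num_params = sum(p.numel() for p in model.parameters())
all_meta = all(p.is_meta for p in model.parameters())
print(num_params, all_meta)
```

Because shapes are still known, a strategy can later decide how to shard each parameter before any real memory is allocated.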

Example 3:
(Future) Init the model with empty weights explicitly (no memory allocated) if you later need to overwrite them by loading a checkpoint anyway:

fabric = Fabric(...)

with fabric.init_module(empty_weights=True):
    model = MyModel()
    model.load_state_dict(...)
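
Until an `empty_weights` option exists, a similar effect can be sketched in plain PyTorch: build on the meta device, allocate uninitialized storage with `to_empty`, then load the checkpoint. The in-memory `checkpoint` below stands in for a file loaded from disk, and `build` is a hypothetical factory; the sketch assumes PyTorch 2.0 or newer.

```python
import torch
from torch import nn


def build():
    # Hypothetical model factory used for both the checkpoint and the target.
    return nn.Linear(16, 4)


# Stand-in for a checkpoint that would normally be read from disk.
checkpoint = build().state_dict()

with torch.device("meta"):
    model = build()  # no weight initialization, no memory allocated

# Allocate uninitialized storage on the target device, then fill it.
model.to_empty(device="cpu")
model.load_state_dict(checkpoint)

print(torch.equal(model.weight, checkpoint["weight"]))
```

This skips the random initialization that the checkpoint would overwrite anyway.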

This API can be used for #16448 or to support torchdistx.

cc @Borda @carmocca @justusschock @awaelchli

@github-actions github-actions bot added the fabric lightning.fabric.Fabric label Apr 24, 2023
Contributor

@carmocca carmocca left a comment

Great job. Since you might prefer to address some of my comments in a follow-up, approving to unblock.

We should also add this to the fabric docs

model = MyModel()
@contextmanager
def init_module(self) -> Generator:
    """Instantiate the model and its parameters under this context manager to reduce peak memory usage.
Contributor

This might also be used in the future for things that aren't modules. Maybe we should choose a different name that doesn't explicitly state "module".

Collaborator

init_module is fine IMO, but if we want to discuss names :-) I find that the name doesn't convey what is different from just initializing outside the context manager.

Maybe direct_init would be an alternative? I'm good with init_module though, it probably sounds less cryptic.

Contributor Author

I was thinking mentioning the "module" in the name is safer because it expresses more precisely what we should do under this context manager. This hopefully helps prevent users from accidentally doing weird stuff under this context manager and getting in trouble, but of course, we can never completely prevent that.

Contributor

I would vote for something like efficient_init or similar

Contributor

@carmocca carmocca Apr 26, 2023

For example, this context manager would be useful for non-modules too:

import torch
from lightning import Fabric

fabric = Fabric(accelerator="cuda", precision="64-true")
with fabric.efficient_init():
    x = torch.zeros(1)
    print(x.device)  # cuda:0
    print(x.dtype)  # torch.float64

Contributor Author

It is very important that this context manager be only used for model initialization. For everything else, the user should use fabric.device.

The main motivation for this is to load and shard large models efficiently, or to provide a convenient way to cast the model to the desired dtype for inference without code changes.

Collaborator

@lantiga lantiga left a comment

Looks great

with self.sharded_model(), _replace_dunder_methods(DataLoader, "dataset"), _replace_dunder_methods(
    BatchSampler
):
with _old_sharded_model_context(self._strategy), _replace_dunder_methods(
Collaborator

There wouldn't be any impact for current Fabric end-users (not developers) right? If so, I agree we should remove.

@mergify mergify bot added the ready PRs ready to be merged label Apr 26, 2023

codecov bot commented Apr 26, 2023

Codecov Report

Merging #17462 (758fbd7) into master (d48ec08) will decrease coverage by 24%.
The diff coverage is 89%.

Additional details and impacted files
@@            Coverage Diff            @@
##           master   #17462     +/-   ##
=========================================
- Coverage      83%      59%    -24%     
=========================================
  Files         415      410      -5     
  Lines       31649    31594     -55     
=========================================
- Hits        26405    18708   -7697     
- Misses       5244    12886   +7642     

Labels
fabric lightning.fabric.Fabric feature Is an improvement or enhancement ready PRs ready to be merged

4 participants