Improved model initialization API for Fabric #17462
Conversation
Great job. Since you might prefer to address some of my comments in a follow-up, I'm approving to unblock.
We should also add this to the Fabric docs.
model = MyModel()

@contextmanager
def init_module(self) -> Generator:
    """Instantiate the model and its parameters under this context manager to reduce peak memory usage.
This might also be used in the future for things that aren't modules; maybe we should choose a different name that doesn't explicitly say "module".
`init_module` is fine IMO, but if we want to discuss names :-) I find that the name doesn't convey what is different from just initializing outside the context manager. Maybe `direct_init` would be an alternative? I'm good with `init_module` though, it probably sounds less cryptic.
I was thinking mentioning the "module" in the name is safer because it expresses more precisely what we should do under this context manager. This hopefully helps prevent users from accidentally doing weird stuff under this context manager and getting in trouble, but of course, we can never completely prevent that.
I would vote for something like `efficient_init` or similar.
For example, this context manager would be useful for non-modules too:

import torch
from lightning import Fabric

fabric = Fabric(accelerator="cuda", precision="64-true")
with fabric.efficient_init():
    x = torch.zeros(1)
print(x.device)  # cuda
print(x.dtype)   # torch.float64
It is very important that this context manager be used only for model initialization. For everything else, the user should use `fabric.device`.
The main motivation for this is to load and shard large models efficiently, or to provide a convenient way to cast the model to the desired dtype for inference without code changes.
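To make the motivation concrete, here is a minimal sketch of what such a context manager could do under the hood. This is an illustration only, not the actual Fabric implementation (which also handles strategies and sharding); it assumes PyTorch >= 2.0 for the `torch.device` context manager:

```python
import torch
from contextlib import contextmanager


@contextmanager
def init_module(device, dtype):
    """Create parameters directly on `device` with `dtype` as the default.

    Hypothetical stand-in for the Fabric API discussed above.
    """
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)  # new floating-point tensors use `dtype`
    try:
        with torch.device(device):  # torch >= 2.0: new tensors land on `device`
            yield
    finally:
        torch.set_default_dtype(previous)  # restore the global default


# The layer is allocated once, on the target device and dtype, instead of
# being created in float32 on CPU and then copied with `.to(...)`.
with init_module("cpu", torch.float64):
    layer = torch.nn.Linear(4, 4)
print(layer.weight.dtype)  # torch.float64
```

This avoids the usual peak-memory spike where the model briefly exists twice (the initial float32 copy plus the converted one).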
Looks great
with self.sharded_model(), _replace_dunder_methods(DataLoader, "dataset"), _replace_dunder_methods(
    BatchSampler
):
with _old_sharded_model_context(self._strategy), _replace_dunder_methods(
There wouldn't be any impact for current Fabric end-users (not developers), right? If so, I agree we should remove it.
Codecov Report
Additional details and impacted files:

@@           Coverage Diff            @@
##           master   #17462    +/-  ##
=========================================
- Coverage      83%      59%     -24%
=========================================
  Files         415      410       -5
  Lines       31649    31594      -55
=========================================
- Hits        26405    18708    -7697
- Misses       5244    12886    +7642
What does this PR do?
Adds a new context manager in Fabric that allows you to init your model (e.g. with `FakeTensorMode` for FSDP, #16448). This context manager generalizes the previous `sharded_model` manager from the LightningLite days (we don't need both).
Example 1: Init the model on the GPU instantly, with weights in half precision (your model may not fit in float32). See #17287 for half support.
Example 2: Init an FSDP model on the meta device (or using torchdistx).
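As a rough sketch of what meta-device initialization looks like in plain PyTorch (assuming torch >= 2.0 for the `torch.device` context manager; the FSDP/torchdistx wiring is omitted):

```python
import torch

# Parameters created under the meta device carry only shape/dtype metadata,
# so even a very large layer allocates no real memory here. FSDP can later
# materialize each shard on its own rank.
with torch.device("meta"):
    layer = torch.nn.Linear(4096, 4096)

print(layer.weight.is_meta)  # True
```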
Example 3:
(Future) Init the model with empty weights explicitly (no memory allocated) if you need to overwrite them by loading a checkpoint later anyway.
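A sketch of this "empty init, then load" workflow in plain PyTorch (assuming torch >= 2.0; the checkpoint dict below is a stand-in for a real checkpoint file):

```python
import torch

# Build the module on the meta device: no storage is allocated yet.
with torch.device("meta"):
    model = torch.nn.Linear(8, 8)

# Materialize uninitialized storage on the target device...
model = model.to_empty(device="cpu")

# ...then overwrite it from a checkpoint, skipping the wasted random init.
checkpoint = {"weight": torch.ones(8, 8), "bias": torch.zeros(8)}
model.load_state_dict(checkpoint)
print(model.weight.sum().item())  # 64.0
```

The point of the proposed API is to do this without the user touching the meta device directly.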
This API can be used for #16448 or to support torchdistx.
cc @Borda @carmocca @justusschock @awaelchli