Support empty weight initialization in Fabric.init_module()
#17627
Conversation
e915926 to 4910cb6 (compare)
Codecov Report
Additional details and impacted files

```
@@            Coverage Diff             @@
##           master   #17627      +/-   ##
==========================================
- Coverage      84%      61%      -23%
==========================================
  Files         419      415        -4
  Lines       31721    31662       -59
==========================================
- Hits        26634    19382     -7252
- Misses       5087    12280     +7193
```
Implementation looks good. I would personally go for empty_init.
Co-authored-by: Carlos Mocholí <[email protected]>
```python
# TODO: Use the meta device and reset parameters after https://github.com/pytorch/pytorch/issues/90465
# is resolved. For now, the module will get moved to the device in `setup_module`.
empty_init_context = (
    _EmptyInit(enabled=(empty_init is not False)) if _TORCH_GREATER_EQUAL_1_13 else nullcontext()
)
with empty_init_context, self.precision.init_context(), self.module_sharded_context():
    ...  # rest of the context manager elided
```
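For readers unfamiliar with the mechanism, here is a hypothetical, minimal sketch of the general "skip initialization" technique (not Lightning's actual _EmptyInit implementation): a TorchFunctionMode that turns common in-place initializer calls into no-ops, so newly created parameters keep whatever memory was allocated for them.

```python
import torch
from torch.overrides import TorchFunctionMode  # available in torch >= 1.13


class SkipInit(TorchFunctionMode):
    """Hypothetical sketch: intercept in-place initializers and skip them."""

    _SKIPPED = {"uniform_", "normal_", "kaiming_uniform_", "kaiming_normal_", "zero_"}

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if getattr(func, "__name__", "") in self._SKIPPED:
            # Return the tensor untouched instead of filling it with values.
            return args[0]
        return func(*args, **kwargs)


with SkipInit():
    # Memory is allocated, but `reset_parameters()` effectively becomes a no-op.
    layer = torch.nn.Linear(4096, 4096)
```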
@awaelchli We might want to reconsider doing empty_init=True by default on FSDP. I just went down a rabbit hole chasing NaNs on FSDP when all I needed to change was empty_init=False. Since the default was to skip initialization and I wasn't loading a full checkpoint, the weights were initialized to garbage.
Since there's no way for us to know if the user will load a full checkpoint after initialization, I find it safer to not do this by default.
Context: Lightning-AI/litgpt#193
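To make the failure mode concrete, a small illustrative sketch (not code from this PR): with initialization skipped, a parameter contains whatever bytes the allocator hands back, which is only safe if every value is later overwritten, for example by loading a full checkpoint.

```python
import torch

# What "empty init" effectively gives you: allocated but uninitialized memory.
weight = torch.empty(4096, 4096)

# The contents are arbitrary; extreme values or NaNs produce garbage outputs downstream.
print(weight.abs().max())

# Only safe if the tensor is fully overwritten afterwards, e.g.:
# weight.copy_(checkpoint["weight"])  # `checkpoint` is a hypothetical loaded state dict
```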
That wasn't the intention, sorry. The plan was to do this automatically if we have fake tensor / support for materialization (torchdistx-style). In the meantime, it would be safer to do:
```diff
-_EmptyInit(enabled=(empty_init is not False)) if _TORCH_GREATER_EQUAL_1_13 else nullcontext()
+_EmptyInit(enabled=bool(empty_init)) if _TORCH_GREATER_EQUAL_1_13 else nullcontext()
```
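The difference only matters when empty_init is left as None (i.e. the user passed nothing); a quick sketch of the two expressions' behavior:

```python
# Behavior of the two expressions when `empty_init` is not passed (None):
for empty_init in (None, True, False):
    old = empty_init is not False  # None -> True: skipping init is the default
    new = bool(empty_init)         # None -> False: regular init is the default
    print(f"empty_init={empty_init!r}: old={old}, suggested={new}")
```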
The comment above there, "# TODO: Use the meta device and reset parameters after https://github.com/pytorch/pytorch/issues/90465", is what I ultimately wanted to achieve.
Same. While it's a cool feature, having this default to true is very disruptive and hard to debug.
What does this PR do?
Fixes #17616
Adds a toggle on Fabric.init_module that speeds up initialization and memory allocation for a large model. Useful for finetuning / loading weights into a large model.
See how Fabric.init_module(empty_weights=True) can be applied in lit-llama to minimize boilerplate logic: Lightning-AI/lit-llama#360
cc @Borda @carmocca @justusschock @awaelchli
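For context, a minimal usage sketch of the feature, assuming the parameter ends up being called empty_init (as suggested in the review; the description above still uses empty_weights) and using a stand-in model and a hypothetical checkpoint path:

```python
import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cuda", devices=1)
fabric.launch()

# Skip the expensive random initialization since the weights get overwritten anyway.
with fabric.init_module(empty_init=True):
    model = torch.nn.Linear(8192, 8192)  # stand-in for a large model

# Only safe because a full checkpoint is loaded right after:
state_dict = torch.load("path/to/checkpoint.pt")  # hypothetical checkpoint path
model.load_state_dict(state_dict)
model = fabric.setup(model)
```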