
Support empty weight initialization in Fabric.init_module() #17627

Merged

awaelchli merged 77 commits into master from fabric/empty-init on Jun 7, 2023

Conversation

awaelchli (Contributor) commented on May 14, 2023

What does this PR do?

Fixes #17616

Adds a toggle to Fabric.init_module() that speeds up initialization and memory allocation for large models.

with fabric.init_module(empty_weights=True):
    # Initialization is very fast and, depending on the strategy,
    # allocates either no memory or uninitialized memory
    model = MyModel()

# The weights then get loaded into the model
model.load_state_dict(checkpoint["state_dict"])

Useful for fine-tuning, or any time the weights of a large model are loaded from a checkpoint right after instantiation.

See how Fabric.init_module(empty_weights=True) can be applied in lit-llama to minimize boilerplate logic: Lightning-AI/lit-llama#360
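For intuition, here is a minimal sketch of how such an empty-init toggle can work, assuming PyTorch >= 1.13, where torch.overrides.TorchFunctionMode is available. The class name EmptyInitMode is illustrative, not the PR's _EmptyInit, although the utility added in src/lightning/fabric/utilities/init.py follows the same idea:

    import torch
    from torch.overrides import TorchFunctionMode

    class EmptyInitMode(TorchFunctionMode):
        """Skip torch.nn.init.* kernels so new parameters keep their uninitialized storage."""

        def __torch_function__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            if getattr(func, "__module__", None) == "torch.nn.init":
                # Return the tensor untouched instead of running the init kernel
                return kwargs.get("tensor", args[0])
            return func(*args, **kwargs)

    with EmptyInitMode():
        layer = torch.nn.Linear(4096, 4096)  # storage is allocated, init kernels are skipped

Under such a mode, creating a module still allocates memory, but the expensive kaiming/uniform initialization becomes a no-op; the values only become meaningful once a checkpoint is loaded on top.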

cc @Borda @carmocca @justusschock @awaelchli

@awaelchli force-pushed the fabric/empty-init branch from e915926 to 4910cb6 on May 14, 2023
Review threads (resolved):
- src/lightning/fabric/fabric.py (outdated)
- src/lightning/fabric/strategies/deepspeed.py
- src/lightning/fabric/utilities/init.py
- src/lightning/fabric/strategies/fsdp.py
- tests/tests_fabric/utilities/test_init.py (outdated, two threads)
codecov bot commented on Jun 6, 2023

Codecov Report

Merging #17627 (0c7fda9) into master (420eb6f) will decrease coverage by 23%.
The diff coverage is 79%.

❗ Current head 0c7fda9 differs from pull request most recent head 4d71cdd. Consider uploading reports for the commit 4d71cdd to get more accurate results

Additional details and impacted files
@@            Coverage Diff            @@
##           master   #17627     +/-   ##
=========================================
- Coverage      84%      61%    -23%     
=========================================
  Files         419      415      -4     
  Lines       31721    31662     -59     
=========================================
- Hits        26634    19382   -7252     
- Misses       5087    12280   +7193     

carmocca (Contributor) left a comment


Implementation looks good.

I would personally go for empty_init.

Review thread (resolved): src/lightning/fabric/utilities/init.py
@awaelchli awaelchli enabled auto-merge (squash) June 7, 2023 10:41
Review threads (outdated, resolved):
- src/lightning/fabric/CHANGELOG.md
- docs/source-fabric/api/fabric_methods.rst
@awaelchli awaelchli merged commit 24a3115 into master Jun 7, 2023
@awaelchli awaelchli deleted the fabric/empty-init branch June 7, 2023 18:33
Code under discussion (src/lightning/fabric/strategies/fsdp.py):

    # TODO: Use the meta device and reset parameters after https://github.com/pytorch/pytorch/issues/90465
    # is resolved. For now, the module will get moved to the device in `setup_module`.
    empty_init_context = (
        _EmptyInit(enabled=(empty_init is not False)) if _TORCH_GREATER_EQUAL_1_13 else nullcontext()
    )
    with empty_init_context, self.precision.init_context(), self.module_sharded_context():
carmocca (Contributor) commented on Jun 23, 2023

@awaelchli We might want to reconsider doing empty_init=True by default on FSDP. I just went down a rabbit hole chasing NaNs on FSDP when all I needed to change was empty_init=False. Since the default was to skip initialization and I wasn't loading a full checkpoint, the weights were initialized to garbage.

Since there's no way for us to know if the user will load a full checkpoint after initialization, I find it safer to not do this by default.

context: Lightning-AI/litgpt#193
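For reference, until the default changes, the safe pattern whenever no full checkpoint is loaded afterwards is to opt out explicitly (a sketch; MyModel is a placeholder):

    with fabric.init_module(empty_init=False):
        model = MyModel()  # parameters go through the modules' regular init

    # No load_state_dict(...) required afterwards: the weights are already valid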

awaelchli (Contributor, Author) replied:

That wasn't the intention, sorry. The plan was to do this automatically only once we have fake-tensor support for deferred materialization (torchdistx-style). In the meantime, it would be safer to do:

Suggested change:
-    _EmptyInit(enabled=(empty_init is not False)) if _TORCH_GREATER_EQUAL_1_13 else nullcontext()
+    _EmptyInit(enabled=bool(empty_init)) if _TORCH_GREATER_EQUAL_1_13 else nullcontext()
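The two guards only disagree for the default empty_init=None, which is exactly the implicit case that produced the garbage weights above; a quick check in plain Python:

    for empty_init in (None, True, False):
        print(empty_init, empty_init is not False, bool(empty_init))
    # None  True   False  <- old guard enables empty init by default
    # True  True   True
    # False False  False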

awaelchli (Contributor, Author) replied:

The comment above,

    # TODO: Use the meta device and reset parameters after https://github.com/pytorch/pytorch/issues/90465

is what I ultimately wanted to achieve.
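As a rough sketch of what that TODO points at (plain PyTorch, assuming a version >= 2.0 where torch.device works as a context manager; this is not the PR's code):

    import torch

    # Create the module on the meta device: shapes and dtypes only, no storage
    with torch.device("meta"):
        layer = torch.nn.Linear(4096, 4096)

    # Later, materialize real (uninitialized) storage on the target device
    # and run the module's own init kernels
    layer.to_empty(device="cpu")
    layer.reset_parameters()

Once pytorch/pytorch#90465 is resolved, the idea is that parameters either get materialized and properly reset, or get filled from a checkpoint, so an unsafe default disappears.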

A contributor replied:

Same. While it's a cool feature, having this default to True is very disruptive and hard to debug.

Labels: fabric (lightning.fabric.Fabric), feature (Is an improvement or enhancement), ready (PRs ready to be merged)
Projects: none yet
Development: Successfully merging this pull request may close these issues: Empty weight initialization through Fabric.init_module() (#17616)
5 participants