Support 4bit BNB layers meta-device materialization #19150

Merged
merged 17 commits into from
Dec 20, 2023

Conversation

carmocca

@carmocca carmocca commented Dec 13, 2023

What does this PR do?

Adds support for converting Linear layers on meta-device to Bitsandbytes Linear layers.

It also adds support for materializing and quantizing Bitsandbytes Linear layers on meta-device.

Scenario 1

fabric = Fabric(plugins=BitsandbytesPrecision(), devices=1)
# empty_init=True with devices=1 doesn't use meta device at the moment
with fabric.init_module(empty_init=False), torch.device("meta"):
    model = LLM()
# this looks stupid, but in reality this uses a custom materialization scheme
materialize_meta_tensors(model, fabric.device)  # quantize here
maybe_install_hooks(model)
load_checkpoint(model)  # quantize here

Cons:

  • to_empty and reset_parameters (called from materialize_meta_tensors) will create empty tensors and quantize. There's no way to avoid this unless we make materialize_meta_tensors aware of bitsandbytes.
  • load_checkpoint will again quantize weights that were already quantized during materialization.
  • We materialize model parts that were going to be loaded anyway.

Scenario 2

fabric = Fabric(plugins=BitsandbytesPrecision(), devices=1)
with fabric.init_module(empty_init=False), torch.device("meta"):
    model = LLM()
maybe_install_hooks(model)
load_checkpoint(model)  # quantize here
materialize_meta_tensors(model, fabric.device)  # quantize here
model = fabric.setup(model, move_to_device=False)
compile(model)

I believe this is the ideal scenario.
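
The difference between Scenarios 1 and 2 can be sketched with a toy model that only counts quantization calls. This is a minimal illustration, not the PR's implementation: `ToyLayer`, `run`, and the simplified `load_checkpoint` and `materialize_meta_tensors` below are hypothetical stand-ins, the layer names are made up, and the sketch assumes that materialization only touches layers still on the meta device while loading a weight always quantizes it.

```python
from dataclasses import dataclass


@dataclass
class ToyLayer:
    """Hypothetical stand-in for a 4-bit linear layer; it only counts quantize calls."""

    name: str
    on_meta: bool = True
    quantize_calls: int = 0

    def quantize(self) -> None:
        self.on_meta = False
        self.quantize_calls += 1


def materialize_meta_tensors(layers):
    # Assumption: only layers that are still on the meta device get materialized.
    for layer in layers:
        if layer.on_meta:
            layer.quantize()


def load_checkpoint(layers, checkpoint_keys):
    # Assumption: loading a weight always triggers quantization of that layer.
    for layer in layers:
        if layer.name in checkpoint_keys:
            layer.quantize()


def run(order, checkpoint_keys=("attn", "mlp")):
    # "lora" is absent from the checkpoint, mimicking an incomplete checkpoint.
    layers = [ToyLayer("attn"), ToyLayer("mlp"), ToyLayer("lora")]
    for step in order:
        if step == "materialize":
            materialize_meta_tensors(layers)
        else:
            load_checkpoint(layers, checkpoint_keys)
    return {layer.name: layer.quantize_calls for layer in layers}


scenario_1 = run(["materialize", "load"])  # materialize first, then load
scenario_2 = run(["load", "materialize"])  # load first, then materialize the rest
print(scenario_1)
print(scenario_2)
```

Under these assumptions, the Scenario 1 ordering quantizes every checkpointed layer twice, while the Scenario 2 ordering quantizes each layer once.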

Scenario 3

fabric = Fabric(plugins=BitsandbytesPrecision(), devices=1)
with torch.device("meta"):
    model = LLM()
maybe_install_hooks(model)
load_checkpoint(model)
model = fabric.setup(model, move_to_device=False)  # do not quantize, just replace
materialize_meta_tensors(model, fabric.device)  # quantize here

Cons:

  • If the checkpoint is complete, loading will OOM.
  • fabric.setup will need to recreate layers so the model hooks are lost.
  • to_empty and reset_parameters (called from materialize_meta_tensors) will both create empty tensors and quantize. There's no way to avoid this unless we make materialize_meta_tensors aware of bitsandbytes.

Scenario 4

fabric = Fabric(plugins=BitsandbytesPrecision(), devices=1)
with torch.device("meta"):
    model = LLM()
materialize_meta_tensors(model, fabric.device)
model = fabric.setup(model, move_to_device=False)  # do not quantize, just replace
maybe_install_hooks(model)
load_checkpoint(model)  # quantize here
quantize_unloaded_layers(model)  # quantize here

Cons:

  • If the checkpoint is incomplete, an extra quantization step would be required. (unimplemented)

8-bit layer materialization is not implemented. I only made the minimal changes required for it.


📚 Documentation preview 📚: https://pytorch-lightning--19150.org.readthedocs.build/en/19150/

cc @Borda @carmocca @justusschock @awaelchli

@carmocca carmocca added feature Is an improvement or enhancement precision: bnb Bitsandbytes quantization labels Dec 13, 2023
@carmocca carmocca added this to the 2.2 milestone Dec 13, 2023
@carmocca carmocca self-assigned this Dec 13, 2023
@github-actions github-actions bot added fabric lightning.fabric.Fabric pl Generic label for PyTorch Lightning package dependencies Pull requests that update a dependency file labels Dec 13, 2023
@carmocca carmocca marked this pull request as ready for review December 14, 2023 16:51
@carmocca carmocca requested a review from tchaton as a code owner December 14, 2023 16:51

github-actions bot commented Dec 14, 2023

⚡ Required checks status: All passing 🟢

Groups summary

🟢 pytorch_lightning: Tests workflow
Check ID Status
pl-cpu (macOS-11, lightning, 3.8, 1.12, oldest) success
pl-cpu (macOS-11, lightning, 3.9, 1.12) success
pl-cpu (macOS-11, lightning, 3.10, 1.13) success
pl-cpu (macOS-11, lightning, 3.10, 2.0) success
pl-cpu (macOS-11, lightning, 3.10, 2.1) success
pl-cpu (ubuntu-20.04, lightning, 3.8, 1.12, oldest) success
pl-cpu (ubuntu-20.04, lightning, 3.9, 1.12) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 1.13) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.0) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.1) success
pl-cpu (windows-2022, lightning, 3.8, 1.12, oldest) success
pl-cpu (windows-2022, lightning, 3.9, 1.12) success
pl-cpu (windows-2022, lightning, 3.10, 1.13) success
pl-cpu (windows-2022, lightning, 3.10, 2.0) success
pl-cpu (windows-2022, lightning, 3.10, 2.1) success
pl-cpu (macOS-11, pytorch, 3.8, 1.13) success
pl-cpu (ubuntu-20.04, pytorch, 3.8, 1.13) success
pl-cpu (windows-2022, pytorch, 3.8, 1.13) success
pl-cpu (macOS-12, pytorch, 3.11, 2.0) success
pl-cpu (macOS-12, pytorch, 3.11, 2.1) success
pl-cpu (ubuntu-22.04, pytorch, 3.11, 2.0) success
pl-cpu (ubuntu-22.04, pytorch, 3.11, 2.1) success
pl-cpu (windows-2022, pytorch, 3.11, 2.0) success
pl-cpu (windows-2022, pytorch, 3.11, 2.1) success

These checks are required after the changes to src/lightning/fabric/plugins/precision/bitsandbytes.py, src/lightning/fabric/utilities/init.py, requirements/pytorch/extra.txt.

🟢 pytorch_lightning: Azure GPU
Check ID Status
pytorch-lightning (GPUs) (testing Lightning | latest) success
pytorch-lightning (GPUs) (testing PyTorch | latest) success

These checks are required after the changes to requirements/pytorch/extra.txt, src/lightning/fabric/plugins/precision/bitsandbytes.py, src/lightning/fabric/utilities/init.py.

🟢 pytorch_lightning: Benchmarks
Check ID Status
lightning.Benchmarks success

These checks are required after the changes to requirements/pytorch/extra.txt, src/lightning/fabric/plugins/precision/bitsandbytes.py, src/lightning/fabric/utilities/init.py.

🟢 fabric: Docs
Check ID Status
docs-make (fabric, doctest) success
docs-make (fabric, html) success

These checks are required after the changes to src/lightning/fabric/plugins/precision/bitsandbytes.py, src/lightning/fabric/utilities/init.py.

🟢 pytorch_lightning: Docs
Check ID Status
docs-make (pytorch, doctest) success
docs-make (pytorch, html) success

These checks are required after the changes to requirements/pytorch/extra.txt.

🟢 pytorch_lightning: Docker
Check ID Status
build-cuda (3.9, 1.12, 11.7.1) success
build-cuda (3.9, 1.13, 11.8.0) success
build-cuda (3.9, 1.13, 12.0.1) success
build-cuda (3.10, 2.0, 11.8.0) success
build-cuda (3.10, 2.1, 12.1.0) success
build-pl (3.9, 1.12, 11.7.1) success
build-pl (3.9, 1.13, 11.8.0) success
build-pl (3.9, 1.13, 12.0.1) success
build-pl (3.10, 2.0, 11.8.0) success
build-pl (3.10, 2.1, 12.1.0) success

These checks are required after the changes to requirements/pytorch/extra.txt.

🟢 lightning_fabric: CPU workflow
Check ID Status
fabric-cpu (macOS-11, lightning, 3.8, 1.12, oldest) success
fabric-cpu (macOS-11, lightning, 3.9, 1.12) success
fabric-cpu (macOS-11, lightning, 3.10, 1.13) success
fabric-cpu (macOS-11, lightning, 3.10, 2.0) success
fabric-cpu (macOS-11, lightning, 3.11, 2.1) success
fabric-cpu (ubuntu-20.04, lightning, 3.8, 1.12, oldest) success
fabric-cpu (ubuntu-20.04, lightning, 3.9, 1.12) success
fabric-cpu (ubuntu-20.04, lightning, 3.10, 1.13) success
fabric-cpu (ubuntu-20.04, lightning, 3.10, 2.0) success
fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.1) success
fabric-cpu (windows-2022, lightning, 3.8, 1.12, oldest) success
fabric-cpu (windows-2022, lightning, 3.9, 1.12) success
fabric-cpu (windows-2022, lightning, 3.10, 1.13) success
fabric-cpu (windows-2022, lightning, 3.10, 2.0) success
fabric-cpu (windows-2022, lightning, 3.11, 2.1) success
fabric-cpu (macOS-11, fabric, 3.8, 1.13) success
fabric-cpu (ubuntu-20.04, fabric, 3.8, 1.13) success
fabric-cpu (windows-2022, fabric, 3.8, 1.13) success
fabric-cpu (macOS-12, fabric, 3.11, 2.0) success
fabric-cpu (macOS-12, fabric, 3.11, 2.1) success
fabric-cpu (ubuntu-22.04, fabric, 3.11, 2.0) success
fabric-cpu (ubuntu-22.04, fabric, 3.11, 2.1) success
fabric-cpu (windows-2022, fabric, 3.11, 2.0) success
fabric-cpu (windows-2022, fabric, 3.11, 2.1) success

These checks are required after the changes to src/lightning/fabric/plugins/precision/bitsandbytes.py, src/lightning/fabric/utilities/init.py, tests/tests_fabric/plugins/precision/test_bitsandbytes.py, tests/tests_fabric/utilities/test_init.py.

🟢 lightning_fabric: Azure GPU
Check ID Status
lightning-fabric (GPUs) (testing Fabric | latest) success
lightning-fabric (GPUs) (testing Lightning | latest) success

These checks are required after the changes to src/lightning/fabric/plugins/precision/bitsandbytes.py, src/lightning/fabric/utilities/init.py, tests/tests_fabric/plugins/precision/test_bitsandbytes.py, tests/tests_fabric/utilities/test_init.py.

🟢 mypy
Check ID Status
mypy success

These checks are required after the changes to requirements/pytorch/extra.txt, src/lightning/fabric/plugins/precision/bitsandbytes.py, src/lightning/fabric/utilities/init.py.

🟢 install
Check ID Status
install-pkg (ubuntu-22.04, app, 3.8) success
install-pkg (ubuntu-22.04, app, 3.11) success
install-pkg (ubuntu-22.04, fabric, 3.8) success
install-pkg (ubuntu-22.04, fabric, 3.11) success
install-pkg (ubuntu-22.04, pytorch, 3.8) success
install-pkg (ubuntu-22.04, pytorch, 3.11) success
install-pkg (ubuntu-22.04, lightning, 3.8) success
install-pkg (ubuntu-22.04, lightning, 3.11) success
install-pkg (ubuntu-22.04, notset, 3.8) success
install-pkg (ubuntu-22.04, notset, 3.11) success
install-pkg (macOS-12, app, 3.8) success
install-pkg (macOS-12, app, 3.11) success
install-pkg (macOS-12, fabric, 3.8) success
install-pkg (macOS-12, fabric, 3.11) success
install-pkg (macOS-12, pytorch, 3.8) success
install-pkg (macOS-12, pytorch, 3.11) success
install-pkg (macOS-12, lightning, 3.8) success
install-pkg (macOS-12, lightning, 3.11) success
install-pkg (macOS-12, notset, 3.8) success
install-pkg (macOS-12, notset, 3.11) success
install-pkg (windows-2022, app, 3.8) success
install-pkg (windows-2022, app, 3.11) success
install-pkg (windows-2022, fabric, 3.8) success
install-pkg (windows-2022, fabric, 3.11) success
install-pkg (windows-2022, pytorch, 3.8) success
install-pkg (windows-2022, pytorch, 3.11) success
install-pkg (windows-2022, lightning, 3.8) success
install-pkg (windows-2022, lightning, 3.11) success
install-pkg (windows-2022, notset, 3.8) success
install-pkg (windows-2022, notset, 3.11) success

These checks are required after the changes to src/lightning/fabric/plugins/precision/bitsandbytes.py, src/lightning/fabric/utilities/init.py, requirements/pytorch/extra.txt.


Thank you for your contribution! 💜

Note
This comment is automatically generated and updates for 60 minutes every 180 seconds. If you have any other questions, contact carmocca for help.


codecov bot commented Dec 14, 2023

Codecov Report

Merging #19150 (4b67966) into master (3b1643c) will decrease coverage by 29%.
Report is 6 commits behind head on master.
The diff coverage is 42%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19150      +/-   ##
==========================================
- Coverage      83%      54%     -29%     
==========================================
  Files         443      439       -4     
  Lines       36859    36984     +125     
==========================================
- Hits        30539    19940   -10599     
- Misses       6320    17044   +10724     

@mergify mergify bot removed the has conflicts label Dec 14, 2023
@mergify mergify bot added the ready PRs ready to be merged label Dec 20, 2023
@carmocca carmocca merged commit 6dfa5cc into master Dec 20, 2023
129 checks passed
@carmocca carmocca deleted the carmocca/bnb-metad branch December 20, 2023 21:13
@@ -37,7 +39,8 @@

 log = logging.getLogger(__name__)

-_BITSANDBYTES_AVAILABLE = RequirementCache("bitsandbytes>=0.41.0")
+# TODO: unpin after resolving the `quant_state` format breaking changes
+_BITSANDBYTES_AVAILABLE = RequirementCache("bitsandbytes==0.41.0")

@Andrei-Aksionov Andrei-Aksionov Dec 27, 2023

Still new to this repo, so maybe I misunderstood something.

Not sure how it can work.
In requirements/pytorch/extra.txt, BNB is pinned to 0.41.1, which does not satisfy this ==0.41.0 check.
The same goes for the tests: the skipif condition will always be True, so these tests are skipped.

Screenshot 2023-12-27 at 7 02 15 PM

These are GPU tests
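
The mismatch described in this comment can be sketched without Lightning. The snippet below is a simplified, hypothetical stand-in for a requirement check (it is not `RequirementCache` itself, and `parse_version` and `satisfies` are names made up for illustration). It only handles plain `X.Y.Z` versions with `==` and `>=` specifiers.

```python
def parse_version(version: str) -> tuple:
    """Turn a plain 'X.Y.Z' release string into a tuple of ints."""
    return tuple(int(part) for part in version.split("."))


def satisfies(installed: str, requirement: str) -> bool:
    """Hypothetical, simplified check supporting only '==' and '>=' specifiers."""
    for op in ("==", ">="):
        if requirement.startswith(op):
            required = parse_version(requirement[len(op):])
            have = parse_version(installed)
            return have == required if op == "==" else have >= required
    raise ValueError(f"unsupported requirement: {requirement}")


# Version that requirements/pytorch/extra.txt installs, according to the comment.
installed = "0.41.1"

pinned_ok = satisfies(installed, "==0.41.0")  # the check introduced in this diff
ranged_ok = satisfies(installed, ">=0.41.0")  # the check it replaced
print(pinned_ok, ranged_ok)
```

With the exact pin, the installed version no longer counts as available, so any test guarded by a `skipif` on that availability flag is skipped rather than run.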

Contributor Author

Oh, great catch. I messed up the requirements; I'll open a PR.

Labels
dependencies Pull requests that update a dependency file fabric lightning.fabric.Fabric feature Is an improvement or enhancement pl Generic label for PyTorch Lightning package precision: bnb Bitsandbytes quantization ready PRs ready to be merged
4 participants