Add support for FSDP+QLoRA and DeepSpeed ZeRO3+QLoRA #29587

Merged: 27 commits, Mar 13, 2024

Conversation

@pacman100 (Contributor) commented Mar 11, 2024:

What does this PR do?

  1. bitsandbytes now provides a quant_storage option for 4-bit parameters, which is required for FSDP+QLoRA support. This PR exposes that option and fixes the corresponding parameter-count calculation (see the sketch below).
  2. Enables loading the model layer by layer, quantizing each layer on GPU and then moving it to CPU, so that FSDP or DeepSpeed can later shard it and place the shards on the GPUs. This reduces the GPU memory required: a 70B model loaded in 4-bit on each GPU without sharding needs about 35 GB per GPU, but if the quantized model is loaded on CPU and sharded across 2 GPUs, each GPU only needs 35/2 = 17.5 GB, which fits on 24 GB GPUs.
  3. Dispatch needs to be disabled, otherwise it would try to move the quantized weights, which are on CPU, to GPU before sharding.
  4. Disables zero.Init when using DeepSpeed ZeRO-3 with QLoRA.
  5. When using FSDP with PEFT LoRA, the auto-wrap policy needs to be updated to additionally wrap the trainable LoRA layers separately. When using FSDP with QLoRA, the mixed-precision policy needs to be updated to use the quantization storage data type.

This PR should be merged after Accelerate PR huggingface/accelerate#2544.
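
To make items 1, 2 and 5 above concrete, here is a minimal sketch assuming the `bnb_4bit_quant_storage` option name from this PR and a 2-GPU FSDP run; the checkpoint name is a placeholder, and the `MixedPrecision` policy at the end is plain PyTorch shown only to illustrate matching the storage dtype, not code added by this PR.

```python
import torch
from torch.distributed.fsdp import MixedPrecision

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Item 1: keep the packed 4-bit weights in a storage dtype that FSDP can
# flatten and shard uniformly alongside the other bf16 parameters.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,  # option surfaced by this PR
)

# Item 2, back-of-the-envelope memory: 70e9 params * 0.5 bytes ~= 35 GB.
# Unsharded, that is ~35 GB on every GPU; sharded across 2 GPUs, ~17.5 GB each.
gb_unsharded = 70e9 * 0.5 / 1e9       # ~35.0
gb_sharded_2gpus = gb_unsharded / 2   # ~17.5

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",   # placeholder checkpoint, any causal LM works
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,    # keep non-quantized weights in the same dtype
)

# Item 5: the FSDP mixed-precision policy should use the quantization storage
# dtype so the flattened 4-bit storage is not cast during sharding.
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,    # matches bnb_4bit_quant_storage
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)
```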


Commit: refactor, update min accelerate version and add tests

1. Update minimum accelerate version to 0.26.0
2. Clean the trainer wrt accelerate version checks
3. FSDP refactor and test for fsdp config
4. use `itemsize` instead of `dtype2bytes` dict (see the sketch below)
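
As a small illustration of point 4, assuming a PyTorch version where `torch.dtype` exposes `itemsize`; the `dtype2bytes` dict named above is the kind of hand-maintained mapping being replaced.

```python
import torch

# Previously: a hand-maintained mapping such as
#   dtype2bytes = {torch.float32: 4, torch.bfloat16: 2, torch.int8: 1, ...}
# Now: ask the dtype for its own size in bytes.
assert torch.float32.itemsize == 4
assert torch.bfloat16.itemsize == 2
assert torch.int8.itemsize == 1
```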
@younesbelkada (Contributor) left a comment:

Huge work @pacman100! 🚀
Overall it looks great on the quantization end. I left a few nits - I've also left an open question about the accelerate min version; perhaps just a warning for users should suffice, as upgrading accelerate to 0.26.0 might be too brutal for users.

Inline review threads (resolved): setup.py, src/transformers/modeling_utils.py (several), src/transformers/utils/import_utils.py
Commit: Address comments (Co-Authored-By: Younes Belkada <[email protected]>)
@pacman100 (Contributor, Author) commented:

@younesbelkada, addressed the comments in the latest commit.

@younesbelkada (Contributor) left a comment:

Huge work @pacman100! Thanks very much for this! 🚀

@muellerzr (Contributor) left a comment:

Thanks for adding this! Looks great! I added a small suggestion to simplify things a bit.

Inline review threads (resolved): src/transformers/training_args.py, tests/fsdp/test_fsdp.py
@amyeroberts (Collaborator) left a comment:

Awesome work - thanks for adding this!

Just two small comments / questions below.

(inline thread on src/transformers/modeling_utils.py)

            model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
        )
    else:
        hf_quantizer.create_quantized_param(model, param, key, "cpu", state_dict)
Collaborator comment:

Just to make sure I've understood - we don't need this anymore because creating the quantized params has moved to when the state dict is loaded?

@pacman100 (Contributor, Author) replied:

`hf_quantizer is None` (i.e., not quantized) will always be False here, because the conditional `if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized` already carries this check, so this inner conditional is no longer required. The quantized parameters are initialized in the `else` branch on line 3950.
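
For readability, a condensed paraphrase of the branch being described, wrapped as a standalone function. The helper names follow the snippets quoted in this thread (with `set_module_tensor_to_device` assumed to be the accelerate utility); the function signature is purely illustrative, not the actual source.

```python
import torch
from accelerate.utils import set_module_tensor_to_device


def _load_one_param_sketch(model, model_to_load, key, param, dtype, state_dict,
                           hf_quantizer, is_quantized,
                           is_fsdp_enabled, is_local_dist_rank_0):
    """Paraphrased sketch of the per-parameter branch discussed above."""
    if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
        # Non-quantized FSDP path: ranks other than local rank 0 only allocate
        # empty CPU tensors; FSDP syncs the real weights from rank 0 afterwards.
        set_module_tensor_to_device(
            model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
        )
    else:
        # Quantized path: the quantized parameter is created on CPU here so
        # that FSDP or DeepSpeed can shard it afterwards.
        hf_quantizer.create_quantized_param(model, param, key, "cpu", state_dict)
```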

(inline thread on src/transformers/modeling_utils.py)

@@ -1958,7 +1973,8 @@ def _get_resized_lm_head(
         if new_num_tokens is None:
             return old_lm_head

-        if is_deepspeed_zero3_enabled():
+        is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
Collaborator comment:

Is there any reason we couldn't have an is_quantized property on the model, which defaults to False? Having to constantly define is_quantized within the methods isn't ideal, as it requires updating many different places if the criteria change.

@pacman100 (Contributor, Author) commented Mar 13, 2024:

We already have the `hf_quantizer` attribute on the model; storing `is_quantized` as well would duplicate the same information. Earlier I was directly using `self.hf_quantizer is None` in the checks, but there were suggestions above to improve readability by using `is_quantized`, so I made the changes accordingly.
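
For context, a minimal sketch of the alternative the reviewer floats; this is hypothetical and not part of the PR, which keeps `hf_quantizer` as the single source of truth with local `is_quantized` checks.

```python
class QuantizationAwareModelSketch:
    """Hypothetical illustration of the suggested property, not code from the PR."""

    def __init__(self, hf_quantizer=None):
        # Set by from_pretrained when a quantization config is supplied.
        self.hf_quantizer = hf_quantizer

    @property
    def is_quantized(self) -> bool:
        # Derived from hf_quantizer so the two can never disagree.
        return self.hf_quantizer is not None
```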

Commit: address comments (Co-Authored-By: Zach Mueller <[email protected]>)
@pacman100 merged commit 350c5d1 into main on Mar 13, 2024. 21 checks passed.
@pacman100 deleted the smangrul/fsdp-qlora-support branch on March 13, 2024 at 16:33.
@Titus-von-Koeller (Contributor):

Really amazing work @pacman100! Thanks so much for this! ❤️

@itazap pushed a commit that referenced this pull request on May 14, 2024:
* fsdp+qlora related changes

* fixes

* Update quantization_config.py

* support fsdp+qlora and dsz3+qlora

* Update quantization_config.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* handle fsdp+qlora and dsz3+qlora correctly while model loading

* fix param count

* quality

* fsdp related changes

* fsdp changes only when using LoRA/QLoRA

* add accelerate version check

* refactor, update min accelerate version and add tests

1. Update minimum accelerate version to 0.26.0
2. Clean the trainer wrt accelerate version checks
3. FSDP refactor and test for fsdp config
4. use `itemsize` instead of `dtype2bytes` dict

* fix test

* Address comments

Co-Authored-By: Younes Belkada <[email protected]>

* fix the conditional flag

* fix conditional flag

* address comments

Co-Authored-By: Zach Mueller <[email protected]>

---------

Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: Zach Mueller <[email protected]>