
Skip DeepSpeed ZeRO Stage 3 model initialization when bnb #34395

Merged
LysandreJik merged 20 commits into huggingface:main from eljandoubi:fix_paligemma_bnb_deepspeed
Nov 5, 2024

Conversation

@eljandoubi
Contributor

What does this PR do?

Skip DeepSpeed ZeRO Stage 3 model initialization when it is intended to be quantized.

Fixes #34378

Models:

Integrations:

Member

@SunMarc SunMarc left a comment


Thanks for the PR! Left a few comments.

Comment on lines +306 to +308
if hasattr(config, "quantization_config"):
vision_config.quantization_config = config.quantization_config
text_config.quantization_config = config.quantization_config
Member


We don't want to have this in the model config; otherwise, we would have to do it for every model. We also shouldn't need to, since we quantize the model from the top level. Maybe we can propagate the quantization_config in from_pretrained?

Contributor Author


quantization_config is already passed to the model class via from_pretrained, but the sub-models are instantiated using from_config, which does not include it. Perhaps we can propagate the quantization information using a context manager.
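The context-manager idea floated above can be sketched roughly as follows. This is an illustrative sketch only; the flag and helper names (`_is_quantized`, `set_quantized_state`, `is_quantized_state`) are placeholders, not necessarily the identifiers merged in the PR.

```python
from contextlib import contextmanager

# Module-level flag recording whether the top-level model is being
# quantized, so nested sub-model construction can consult it.
_is_quantized = False


@contextmanager
def set_quantized_state():
    """Mark that the model currently being built will be quantized, so
    that sub-models instantiated via from_config can skip ZeRO-3 init."""
    global _is_quantized
    _is_quantized = True
    try:
        yield
    finally:
        # Always restore the flag, even if model construction raises.
        _is_quantized = False


def is_quantized_state() -> bool:
    """Read the flag from anywhere in the construction call stack."""
    return _is_quantized
```

Under this sketch, from_pretrained would wrap model instantiation in `with set_quantized_state():`, and the ZeRO-3 guard in from_config would check `is_quantized_state()` instead of looking for `quantization_config` on every sub-model's config.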

Comment on lines +1533 to +1535
is_quantized = hasattr(config, "quantization_config")

if is_deepspeed_zero3_enabled() and not is_quantized:
Member


Not a huge fan of checking the quantization_config here, since we don't really quantize the model with from_config. However, I'm not sure there is an easier solution. Another option would be to pass an arg in kwargs that we then pop.

Contributor Author


If we pass an argument to flag quantization, it would require changes in every composed model.
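The guard being discussed boils down to choosing which context to create weights under. The sketch below illustrates the shape of that check with stand-in helpers (`zero3_enabled`, `zero_init`, `init_context` are placeholders so it runs without DeepSpeed installed; the real code uses `is_deepspeed_zero3_enabled()` and `deepspeed.zero.Init`).

```python
import contextlib


def zero3_enabled() -> bool:
    # Stand-in for transformers' is_deepspeed_zero3_enabled();
    # hardcoded so this sketch runs without DeepSpeed.
    return True


@contextlib.contextmanager
def zero_init():
    # Stand-in for deepspeed.zero.Init, which shards parameters across
    # ranks as they are created under ZeRO Stage 3.
    yield


def init_context(is_quantized: bool):
    """Pick the context for weight creation: skip sharded ZeRO-3 init
    when the model is about to be quantized, since bitsandbytes needs
    materialized weights to quantize."""
    if zero3_enabled() and not is_quantized:
        return zero_init()
    return contextlib.nullcontext()
```

With this shape, a quantized model falls through to a plain `nullcontext()`, so its weights are created normally and bnb can quantize them, while non-quantized models still get sharded initialization.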

@SunMarc SunMarc requested a review from muellerzr October 28, 2024 02:34
@eljandoubi
Contributor Author

eljandoubi commented Oct 29, 2024

@SunMarc @muellerzr What do you think of this solution?

Contributor

@muellerzr muellerzr left a comment


Thanks I think this solution looks quite nice. cc @ArthurZucker

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@SunMarc SunMarc left a comment


Works for me! Thanks for iterating and coming up with this solution!

Member

@LysandreJik LysandreJik left a comment


Thank you!

@LysandreJik LysandreJik merged commit d0b1d8d into huggingface:main Nov 5, 2024
@ArthurZucker ArthurZucker requested review from SunMarc and removed request for ArthurZucker November 5, 2024 10:21
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
…e#34395)

* Skip DeepSpeed ZeRO Stage 3 model initialization when it is intended to be quantized.

* Propagate the quantization state using a context manager

* make fixup

Development

Successfully merging this pull request may close these issues.

deepspeed zero3 and bitsandbytes are not working

5 participants