Skip to content

[InstructBlip] Add accelerate support for instructblip#24488

Merged
younesbelkada merged 5 commits into
huggingface:mainfrom
younesbelkada:add-instruct-blip-4bit
Jun 26, 2023
Merged

[InstructBlip] Add accelerate support for instructblip#24488
younesbelkada merged 5 commits into
huggingface:mainfrom
younesbelkada:add-instruct-blip-4bit

Conversation

@younesbelkada

@younesbelkada younesbelkada commented Jun 26, 2023

Copy link
Copy Markdown
Contributor

What does this PR do?

As per title, let's make users benefit from 8bit / 4bit loading of instructblip models

cc @amyeroberts @sgugger @NielsRogge

all accelerate tests pass for this model

As a side note, as instruct blip relies on flan-t5 as backbone for some models, therefore it is important to add

_keep_in_fp32_modules = ["wo"]

To ensure inference stability in fp16 / int8 / fp4

@younesbelkada younesbelkada mentioned this pull request Jun 26, 2023
5 tasks
@younesbelkada younesbelkada requested review from amyeroberts and sgugger and removed request for amyeroberts June 26, 2023 09:54
@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jun 26, 2023

Copy link
Copy Markdown

The documentation is not available anymore as the PR was closed or merged.

@NielsRogge

Copy link
Copy Markdown
Collaborator

Thank you :) could you add an integration test?

@younesbelkada

Copy link
Copy Markdown
Contributor Author

Hey @NielsRogge !
Let's maybe add it together with: https://github.com/huggingface/transformers/pull/24490/files#r1242095062 so we should probably merge this first :D

Comment on lines +285 to +290
_no_split_modules = [
"InstructBlipAttention",
"T5Block",
"OPTDecoderLayer",
"InstructBlipQFormerMultiHeadAttention",
]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be adapted dynamically depending on the models that are actually loaded, to be more future proof.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, makes sense

@younesbelkada younesbelkada Jun 26, 2023

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, from what I can see and as a side note, InstructBlip relies on T5 or Llama only, so I did a mistake actually putting OPTDecoderLayer. This is also fixed

@younesbelkada younesbelkada requested a review from sgugger June 26, 2023 12:25

if config.use_decoder_only_language_model:
language_model = AutoModelForCausalLM.from_config(config.text_config)
self._no_split_modules.append("LlamaDecoderLayer")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, should language_model._no_split_modules here.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think language_model._no_split_modules should already contain LlamaDecoderLayer, the purpose of adding it in self._no_split_modules is that in from_pretrained we never look at all the child modules that contain that attribute: https://github.com/huggingface/transformers/blob/195a9e5bdb1faa58cd58b47a23a47734d2b90d8c/src/transformers/modeling_utils.py#L2807C13-L2807C29 | then gets passed to accelerate this way:

kwargs = {"no_split_module_classes": no_split_modules}
right? So not sure we should add it to language_model. What do you think?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand your comment. The code loads any model from the Hub (granted it should be llama models in the actual checkpoints) but you add the specific Llama no split blocks. This will stop working if someone adds an instructblip-2 model that loads a different language model.

@younesbelkada younesbelkada Jun 26, 2023

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A cleaner implementation maybe would be that in post_init we should add an utility method to look at all child modules that contain _no_split_modules and dynamically append them to self._no_split_modules

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you think that's better, happy to change the PR to add these changes

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see yes, sorry I misunderstood your comment, will update that now

self._no_split_modules.append("LlamaDecoderLayer")
else:
language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
self._no_split_modules.append("T5Block")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same there.

@younesbelkada younesbelkada requested a review from sgugger June 26, 2023 12:42
r"language_model.decoder.embed_tokens.weight",
r"language_model.lm_head.weight",
]
_keep_in_fp32_modules = ["wo"]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Last comment: this is the same issue here (this comes from T5 if I'm not mistaken). This key should also be build dynamically.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, just updated it!

@younesbelkada younesbelkada requested a review from sgugger June 26, 2023 14:22
@younesbelkada younesbelkada merged commit 9895670 into huggingface:main Jun 26, 2023
@younesbelkada younesbelkada deleted the add-instruct-blip-4bit branch June 26, 2023 16:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants