-
Notifications
You must be signed in to change notification settings - Fork 33.6k
[InstructBlip] Add accelerate support for instructblip
#24488
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
7ee03df
838db9b
88ccd18
802b536
1f4fd81
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -281,6 +281,8 @@ class InstructBlipPreTrainedModel(PreTrainedModel): | |||
| r"language_model.decoder.embed_tokens.weight", | ||||
| r"language_model.lm_head.weight", | ||||
| ] | ||||
| _keep_in_fp32_modules = ["wo"] | ||||
| _no_split_modules = ["InstructBlipAttention", "InstructBlipQFormerMultiHeadAttention"] | ||||
|
|
||||
| # Copied from transformers.models.blip_2.modeling_blip_2.Blip2PreTrainedModel._init_weights with Blip2->InstructBlip | ||||
| def _init_weights(self, module): | ||||
|
|
@@ -1264,10 +1266,13 @@ def __init__(self, config: InstructBlipConfig): | |||
| self.qformer = InstructBlipQFormerModel(config.qformer_config) | ||||
|
|
||||
| self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) | ||||
|
|
||||
| if config.use_decoder_only_language_model: | ||||
| language_model = AutoModelForCausalLM.from_config(config.text_config) | ||||
| self._no_split_modules.append("LlamaDecoderLayer") | ||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, should
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think transformers/src/transformers/modeling_utils.py Line 2814 in 195a9e5
language_model. What do you think?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't understand your comment. The code loads any model from the Hub (granted it should be llama models in the actual checkpoints) but you add the specific Llama no split blocks. This will stop working if someone adds an instructblip-2 model that loads a different language model.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A cleaner implementation maybe would be that in
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you think that's better, happy to change the PR to add these changes
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah I see yes, sorry I misunderstood your comment, will update that now |
||||
| else: | ||||
| language_model = AutoModelForSeq2SeqLM.from_config(config.text_config) | ||||
| self._no_split_modules.append("T5Block") | ||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same there. |
||||
|
|
||||
| self.language_model = language_model | ||||
|
|
||||
|
|
@@ -1422,7 +1427,7 @@ def forward( | |||
|
|
||||
| if attention_mask is None: | ||||
| attention_mask = torch.ones_like(input_ids) | ||||
| attention_mask = torch.cat([language_model_attention_mask, attention_mask], dim=1) | ||||
| attention_mask = torch.cat([language_model_attention_mask.to(attention_mask.device), attention_mask], dim=1) | ||||
|
|
||||
| if self.config.use_decoder_only_language_model: | ||||
| outputs = self.language_model( | ||||
|
|
||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Last comment: this is the same issue here (this comes from T5 if I'm not mistaken). This key should also be build dynamically.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, just updated it!