Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,8 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
r"language_model.decoder.embed_tokens.weight",
r"language_model.lm_head.weight",
]
_keep_in_fp32_modules = ["wo"]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Last comment: this is the same issue here (this comes from T5 if I'm not mistaken). This key should also be build dynamically.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, just updated it!

_no_split_modules = ["InstructBlipAttention", "InstructBlipQFormerMultiHeadAttention"]

# Copied from transformers.models.blip_2.modeling_blip_2.Blip2PreTrainedModel._init_weights with Blip2->InstructBlip
def _init_weights(self, module):
Expand Down Expand Up @@ -1264,10 +1266,13 @@ def __init__(self, config: InstructBlipConfig):
self.qformer = InstructBlipQFormerModel(config.qformer_config)

self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)

if config.use_decoder_only_language_model:
language_model = AutoModelForCausalLM.from_config(config.text_config)
self._no_split_modules.append("LlamaDecoderLayer")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, should language_model._no_split_modules here.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think language_model._no_split_modules should already contain LlamaDecoderLayer, the purpose of adding it in self._no_split_modules is that in from_pretrained we never look at all the child modules that contain that attribute: https://github.com/huggingface/transformers/blob/195a9e5bdb1faa58cd58b47a23a47734d2b90d8c/src/transformers/modeling_utils.py#L2807C13-L2807C29 | then gets passed to accelerate this way:

kwargs = {"no_split_module_classes": no_split_modules}
right? So not sure we should add it to language_model. What do you think?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand your comment. The code loads any model from the Hub (granted it should be llama models in the actual checkpoints) but you add the specific Llama no split blocks. This will stop working if someone adds an instructblip-2 model that loads a different language model.

@younesbelkada younesbelkada Jun 26, 2023

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A cleaner implementation maybe would be that in post_init we should add an utility method to look at all child modules that contain _no_split_modules and dynamically append them to self._no_split_modules

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you think that's better, happy to change the PR to add these changes

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see yes, sorry I misunderstood your comment, will update that now

else:
language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
self._no_split_modules.append("T5Block")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same there.


self.language_model = language_model

Expand Down Expand Up @@ -1422,7 +1427,7 @@ def forward(

if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
attention_mask = torch.cat([language_model_attention_mask, attention_mask], dim=1)
attention_mask = torch.cat([language_model_attention_mask.to(attention_mask.device), attention_mask], dim=1)

if self.config.use_decoder_only_language_model:
outputs = self.language_model(
Expand Down