blip support for training
#21021
Conversation
sgugger
left a comment
Thanks for fixing this!
I'd rather say no to a new mixin, as it would go against the Transformers philosophy. We don't have this for seq2seq models, for instance; we just copy-paste the shifting logic.
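For reference, the shifting logic in question looks roughly like this (a minimal sketch modeled on BART's `shift_tokens_right` helper; exact details vary per model):

```python
import torch

def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int) -> torch.Tensor:
    # Sketch of the right-shift helper that seq2seq models in Transformers
    # copy-paste rather than share via a mixin.
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()  # move every token one slot right
    shifted_input_ids[:, 0] = decoder_start_token_id      # prepend the decoder start token
    # replace any -100 label-ignore positions with the padding token
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
    return shifted_input_ids
```

BLIP can inline the same pattern rather than importing it, matching how BART and T5 each duplicate the helper.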
Perfect, thanks for clarifying @sgugger!
```python
elif decoder_input_ids is None:
    # by default use BOS token as decoder_input_ids
    decoder_input_ids = torch.LongTensor([self.decoder_start_token_id]).repeat((batch_size, 1))
```
I'm not sure there's a need for this.
This is handled by the `generate` method automatically, which will set the `decoder_input_ids` appropriately.
See also BART and T5, which don't have these lines.
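For illustration, a caller relies on `generate` to seed the decoder itself (a sketch; the checkpoint is a public BLIP captioning checkpoint, and the image URL is a placeholder):

```python
import requests
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# placeholder URL: substitute any RGB image
image = Image.open(requests.get("https://example.com/cat.png", stream=True).raw)
inputs = processor(images=image, return_tensors="pt")

# No decoder_input_ids are passed: generate() creates the initial decoder
# input from the configured start token, as it does for BART and T5.
generated_ids = model.generate(**inputs, max_length=20)
print(processor.decode(generated_ids[0], skip_special_tokens=True))
```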
This can be removed indeed.
For consistency with the original implementation, I propose adding a safety check to ensure that either `labels` or `decoder_input_ids` is always passed: https://github.com/salesforce/BLIP/blob/3a29b7410476bf5f2ba0955827390eb6ea1f4f9d/models/blip_vqa.py#L46
When calling the forward pass, it seems that `labels` (i.e. `answer` in the source code) is always expected.
Hi @younesbelkada @sgugger, thanks for contributing this code, but I found two possible bugs:
Moreover, I think the …
Thanks for your valuable comments @StevenTang1998! @younesbelkada, in any case it would probably be best to verify this branch in a notebook on a toy image-captioning dataset. Making the code as similar as possible to our other generative models (like T5, BART or GPT-2) would be great.
```python
if labels is None and decoder_input_ids is None:
    raise ValueError("Either `decoder_input_ids` or `labels` should be passed during inference.")
```
It's weird that `labels` should be passed during inference?
This is how it's done in the original implementation apparently, check: https://github.com/salesforce/BLIP/blob/3a29b7410476bf5f2ba0955827390eb6ea1f4f9d/models/blip_vqa.py#L46 --> `answer`
Hmm, I don't see it; the line you link to is in the training branch.
The code you link to is in the "training" mode, right? So why would we have the warning that "labels should be passed during inference"? Do you mean training?
Makes sense. I propose a clearer error message in 896bd63
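Presumably the reworded check looks something along these lines (a guess at the spirit of that commit, not its literal diff):

```python
labels = decoder_input_ids = None  # stand-ins for the forward() arguments

# Name the forward pass explicitly instead of saying "inference".
if labels is None and decoder_input_ids is None:
    raise ValueError(
        "Either `decoder_input_ids` or `labels` should be passed when calling the forward method."
    )
```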
- add colab link to documentation
- reduction = mean for loss
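For context, `reduction = mean` refers to averaging the token-level cross-entropy instead of summing it; a minimal sketch of that pattern (not BLIP's exact loss code, which may differ in details such as label smoothing):

```python
import torch

# Mean-reduced cross-entropy over right-shifted logits/labels;
# positions marked with -100 (padding) are ignored.
loss_fct = torch.nn.CrossEntropyLoss(reduction="mean", ignore_index=-100)

logits = torch.randn(2, 6, 30524)          # (batch, seq_len, vocab) - dummy values
labels = torch.randint(0, 30524, (2, 6))   # dummy target token ids

shifted_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
shifted_labels = labels[:, 1:].reshape(-1)
loss = loss_fct(shifted_logits, shifted_labels)
```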
NielsRogge
left a comment
Thanks a lot for fixing!
Hi @younesbelkada, I encountered the same error as mentioned by @dxlong2000: `ValueError: Expected input batch_size (0) to match target batch_size (29).`
Hi @faiqff94 |
Hi @younesbelkada, I have tried the Colab script on a local PC, but got `loss: nan` in epoch 0. Any advice? Thanks
What does this PR do?
Fixes: https://discuss.huggingface.co/t/finetune-blip-on-customer-dataset-20893/28446
Before this PR, it was not possible to fine-tune BLIP on a custom dataset for various reasons, mainly because the code did not support "on-the-fly" right-shifting of `decoder_input_ids`.

This PR also harmonizes some attributes inside `BlipForQuestionAnswering` --> I replaced `decoder_bos_token_id` by `decoder_start_token_id` to make it consistent with T5 etc.

For all VQA models we should (at train time; see the sketch below):

1. make sure `labels` is not None
2. create `decoder_input_ids` based on those (make sure the padding is always on the right side)
3. infer on the text decoder

I feel that we should probably add more tests and create a `VisualQuestionAnsweringMixin` in a follow-up PR to make sure this is done for all VQA models (as I'd expect more VQA models to be added this year).

cc @NielsRogge @sgugger
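Put together, a training step after this PR could look roughly like the following (a sketch; the image URL, question, and answer are placeholders):

```python
import requests
from PIL import Image
from transformers import BlipProcessor, BlipForQuestionAnswering

processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

# placeholder URL and question/answer pair: substitute your own dataset
image = Image.open(requests.get("https://example.com/dogs.png", stream=True).raw)
inputs = processor(images=image, text="how many dogs are in the picture?", return_tensors="pt")
labels = processor(text="one", return_tensors="pt").input_ids

# Passing `labels` is enough: the model right-shifts them on the fly to build
# `decoder_input_ids`, mirroring T5/BART, and returns a mean-reduced loss.
outputs = model(input_ids=inputs.input_ids, pixel_values=inputs.pixel_values, labels=labels)
outputs.loss.backward()
```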