
Conversation

@younesbelkada
Contributor

@younesbelkada younesbelkada commented Jan 5, 2023

What does this PR do?

Fixes: https://discuss.huggingface.co/t/finetune-blip-on-customer-dataset-20893/28446
Before this PR, it was not possible to fine-tune BLIP on a custom dataset, mainly because the code did not support "on-the-fly" right shifting of `decoder_input_ids`.
This PR also harmonizes some attributes inside `BlipForQuestionAnswering` --> I replaced `decoder_bos_token_id` by `decoder_start_token_id` to make it consistent with T5 etc.

For all VQA models, at train time we should:
1- make sure `labels` is not None
2- create `decoder_input_ids` from the labels, making sure the padding is always on the right side (see the sketch below)
3- run the forward pass on the text decoder
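A minimal sketch of step 2, modeled on the `shift_tokens_right` helper used by BART/T5-style models (illustrative code, not the exact implementation in this PR):

import torch

def shift_tokens_right(labels: torch.Tensor, pad_token_id: int, decoder_start_token_id: int) -> torch.Tensor:
    # Prepend the decoder start token and drop the last label token.
    shifted = labels.new_zeros(labels.shape)
    shifted[:, 1:] = labels[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    # Replace loss-masking values (-100) with the pad token so the decoder
    # never receives -100 as an input id.
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted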

I feel that we should probably add more tests and create a `VisualQuestionAnsweringMixin` in a follow-up PR to make sure this is done for all VQA models (I'd expect more VQA models to be added this year).

cc @NielsRogge @sgugger

Collaborator

@sgugger sgugger left a comment


Thanks for fixing this!
I'd rather say no to a new mixin, as it would go against the Transformers philosophy. We don't have one for seq2seq models, for instance; we just copy-paste the shifting logic.

@younesbelkada
Contributor Author

Perfect, thanks for clarifying @sgugger !

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jan 5, 2023

The documentation is not available anymore as the PR was closed or merged.

Comment on lines 1212 to 1214
elif decoder_input_ids is None:
    # by default use BOS token as decoder_input_ids
    decoder_input_ids = torch.LongTensor([self.decoder_start_token_id]).repeat((batch_size, 1))
Contributor

@NielsRogge NielsRogge Jan 5, 2023


I'm not sure there's a need for this.

This is handled by the generate method automatically, which will set the decoder_input_ids appropriately.

See also BART and T5, which don't have these lines.
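For example, with T5 (standard Transformers usage, nothing BLIP-specific), `generate` seeds the decoder by itself:

from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: Hello", return_tensors="pt")
# No decoder_input_ids are passed: generate() starts the decoder from
# config.decoder_start_token_id automatically.
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))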

Contributor Author


This can be removed, indeed.
For consistency with the original implementation, I propose adding a safety check that either `labels` or `decoder_input_ids` is always passed: https://github.com/salesforce/BLIP/blob/3a29b7410476bf5f2ba0955827390eb6ea1f4f9d/models/blip_vqa.py#L46
When calling the forward pass, it seems that `labels` (i.e. `answer` in the source code) is always expected.

@StevenTang1998
Contributor

@younesbelkada @sgugger Hi, thanks for contributing this code, but I found two possible bugs:

  1. The code shifts `labels` to build `decoder_input_ids` (here) and also shifts `labels` when computing the loss (here); only one of the two should be kept. I would keep the former and delete the latter (see the sketch below).
  2. The BERT tokenizer already adds a start token before the sequence, and the `_shift_right` function adds another one (the pad token), so generation should use `forced_bos_token_id` like BART.
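To make point 1 concrete, here is a self-contained sketch of the single-shift convention used by BART/T5 (dummy tensors, illustrative names only):

import torch
import torch.nn as nn

vocab_size, pad_id, start_id = 100, 0, 1
labels = torch.tensor([[5, 6, 7, -100]])  # target ids; -100 marks ignored positions

# Shift ONCE to build the decoder inputs:
decoder_input_ids = labels.clone()
decoder_input_ids[:, 1:] = labels[:, :-1]
decoder_input_ids[:, 0] = start_id
decoder_input_ids.masked_fill_(decoder_input_ids == -100, pad_id)

logits = torch.randn(1, labels.size(1), vocab_size)  # stand-in for the decoder output

# No second shift here: the logits at position t already predict labels[:, t].
loss = nn.CrossEntropyLoss(ignore_index=-100)(
    logits.view(-1, vocab_size), labels.view(-1)
)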

@StevenTang1998
Contributor

StevenTang1998 commented Jan 11, 2023

Moreover, I think the reduction of `CrossEntropyLoss` should be set to `'mean'`, otherwise you get a loss in the tens or hundreds, which is uncommon and may hurt optimization.
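To illustrate the scale difference (generic PyTorch, nothing BLIP-specific):

import torch
import torch.nn as nn

logits = torch.randn(8, 30, 30522)        # (batch, seq_len, vocab_size)
targets = torch.randint(0, 30522, (8, 30))

flat_logits = logits.view(-1, 30522)
flat_targets = targets.view(-1)

loss_sum = nn.CrossEntropyLoss(reduction="sum")(flat_logits, flat_targets)
loss_mean = nn.CrossEntropyLoss(reduction="mean")(flat_logits, flat_targets)
# 'sum' grows with the token count (240 here): roughly 2500 vs. ~10 for 'mean'.
print(loss_sum.item(), loss_mean.item())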

@NielsRogge
Contributor

Thanks for your valuable comments @StevenTang1998! @younesbelkada, in any case it would probably be best to verify this branch in a notebook on a toy image captioning dataset. Making the code as similar as possible to our other generative models (like T5, BART or GPT-2) would be great.

Comment on lines 1184 to 1185
if labels is None and decoder_input_ids is None:
    raise ValueError("Either `decoder_input_ids` or `labels` should be passed during inference.")
Contributor


It's weird that "labels" should be passed during inference?

Contributor Author


This is how it's done in the original implementation, apparently; check: https://github.com/salesforce/BLIP/blob/3a29b7410476bf5f2ba0955827390eb6ea1f4f9d/models/blip_vqa.py#L46 --> `answer`

Contributor


Hmm, I don't see it; the line you link to is inside the training branch.


Contributor


The code you link to is in the "training" mode right? So why would we have the warning that "labels should be passed during inference"? Do you mean training?

Contributor Author


Makes sense. I propose a clearer error message in 896bd63

- add colab link to documentation
- reduction = mean for loss
Contributor

@NielsRogge NielsRogge left a comment


Thanks a lot for fixing!

@younesbelkada younesbelkada merged commit 023f51f into huggingface:main Jan 18, 2023
@faiqff94

Hi @younesbelkada, I encountered the same error as mentioned by @dxlong2000.
I cloned this repository but the error is still there.

ValueError: Expected input batch_size (0) to match target batch_size (29).

@younesbelkada
Contributor Author

Hi @faiqff94,
All the issues related to BLIP training should be resolved; if you follow what was done in https://colab.research.google.com/drive/1lbqiSiA0sDF7JDWPeS0tccrM85LloVha?usp=sharing, you should not run into any issues. Can you share a reproducible script?
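For reference, the core training step in that notebook looks roughly like this (a sketch using a toy image and caption, not the notebook verbatim):

import requests
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, text="two cats lying on a couch", return_tensors="pt")
# Passing input_ids as labels: the model shifts them right internally to
# build decoder_input_ids, which is exactly what this PR enables.
outputs = model(**inputs, labels=inputs.input_ids)
outputs.loss.backward()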

@younesbelkada younesbelkada deleted the blip-train-support branch February 20, 2023 17:48
@pribadihcr

> Hi @faiqff94 All the issues related to BLIP training should be resolved, if you follow what has been done in https://colab.research.google.com/drive/1lbqiSiA0sDF7JDWPeS0tccrM85LloVha?usp=sharing you should not get any issue. Can you share a reproducible handy script?

Hi @younesbelkada, I have tried the Colab script on a local PC, but got `loss: nan` in epoch 0.

Any advice? Thanks!
