
Conversation

@gante (Contributor) commented Feb 11, 2023

What does this PR do?

Fixes #21578 (RuntimeError when running batched inference for Salesforce/blip2-opt-2.7b) and addresses the concerns raised in #21575

Context

Support for .generate() from inputs_embeds with selected decoder-only models was added recently (#21405). This feature enables a .generate(inputs_embeds=inputs_embeds) call, i.e. without passing input_ids.

Under the hood, this call strategy implies that .generate() is in charge of creating a) input_ids for later use (to concatenate the generated tokens to) and b) the corresponding attention_mask. This automated creation was not working properly for batch_size > 1.
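For illustration, a minimal sketch of the kind of batched call this PR fixes; the GPT-2 checkpoint and the random embeddings are placeholders, not taken from the PR's tests:

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

# Dummy embeddings purely to illustrate the shapes: (batch_size, seq_len, hidden_size)
inputs_embeds = torch.randn(2, 5, model.config.hidden_size)

# Neither input_ids nor attention_mask is passed, so .generate() has to create both
# internally; this is the step that previously failed for batch_size > 1
generated = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=10)
```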

Changes

The changes in the PR respect the following desiderata (which required moving a few things around):

  1. The attention_mask can be automatically inferred (as all ones) regardless of the shape of inputs_embeds;
  2. When inputs_embeds is passed and input_ids is not, the automatically created input_ids has a sequence length of 1. This is particularly relevant for BLIP, as we don't want input_ids to start with the embeddings' sequence length (see the sketch after this list).
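To make these two points concrete, here is a simplified sketch of the initialization behavior they describe; initialize_generate_inputs is a hypothetical helper for illustration, not the actual .generate() internals:

```python
import torch

def initialize_generate_inputs(inputs_embeds: torch.Tensor, bos_token_id: int):
    """Hypothetical helper mirroring the two desiderata (not the real transformers code)."""
    batch_size, seq_len, _ = inputs_embeds.shape
    # 1. attention_mask inferred as all ones, matching the embeddings' batch/sequence dims
    attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
    # 2. the automatically created input_ids has sequence length 1 (not seq_len), so
    #    generated tokens get concatenated to a single start token instead of to a
    #    sequence as long as the embeddings (the BLIP concern above)
    input_ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long)
    return input_ids, attention_mask
```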

This PR also adds/enhances tests, to ensure we don't regress on this capability.

⚠️ If approved, I will make the corresponding TF changes before merging.

@gante requested a review from @sgugger on February 11, 2023 at 16:51
@gante (Contributor, Author) commented Feb 11, 2023

cc @dimitry12 this fixes the error you reported in #21575 (GPT2 was throwing the same error as GPTJ, and it gets fixed here)

cc @NielsRogge please don't forget to add batched tests on models with generation capabilities; generate + batching is surprisingly tricky 🙏

@HuggingFaceDocBuilderDev commented Feb 11, 2023

The documentation is not available anymore as the PR was closed or merged.

@gante mentioned this pull request Feb 13, 2023
@sgugger (Collaborator) left a comment

Thanks for your work on this!

@sgugger (Collaborator) commented Feb 13, 2023

But it does seem like it breaks a lot of tests 😅

@gante (Contributor, Author) commented Feb 13, 2023

(Merging -- the failing CI test is a known failure)

@gante merged commit fa4bdb0 into huggingface:main on Feb 13, 2023
@gante deleted the batched_default_input_ids_decoder_only branch on February 13, 2023 at 17:04
@gante (Contributor, Author) commented Feb 13, 2023

@dimitry12 lmk if you see the error when using GPT-J :)

@dimitry12 (Contributor) commented

> @dimitry12 lmk if you see the error when using GPT-J :)

@gante GPT-J generation from inputs_embeds alone, without dummy input_ids, now works without errors. Thank you!
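For reference, a minimal sketch of the call that now succeeds; the checkpoint name and shapes below are illustrative only:

```python
import torch
from transformers import AutoModelForCausalLM

# Any GPT-J checkpoint behaves the same; this one is just for illustration
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
inputs_embeds = torch.randn(2, 5, model.config.hidden_size)  # batched, no input_ids
out = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=5)
```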

@ArthurZucker mentioned this pull request Feb 24, 2023