Allows to use decoder_inputs_embeds for model.generate
#21671
Conversation
The documentation is not available anymore as the PR was closed or merged.

cc @gante
```python
# add next_tokens to inputs_embeds if using embeds to generate
if model_kwargs.get("decoder_inputs_embeds") is not None:
    next_tokens_embed = self.decoder.get_input_embeddings()(next_tokens) * self.decoder.model.decoder.embed_scale
    model_kwargs["decoder_inputs_embeds"] = torch.cat([model_kwargs["decoder_inputs_embeds"], next_tokens_embed.unsqueeze(1)], dim=-2)
```
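The diff above embeds each freshly sampled token and appends it to the running `decoder_inputs_embeds` along the sequence dimension. A minimal sketch of that step with plain tensors (the embedding layer and `embed_scale` below are stand-ins for `self.decoder.get_input_embeddings()` and `self.decoder.model.decoder.embed_scale`; the shapes are illustrative assumptions):

```python
import torch

vocab_size, hidden = 32, 8
embed = torch.nn.Embedding(vocab_size, hidden)
embed_scale = hidden ** 0.5  # BART-style sqrt(d_model) scaling; an assumption here

# embeddings accumulated so far: (batch, seq_len, hidden)
decoder_inputs_embeds = torch.randn(2, 5, hidden)

# tokens chosen by the greedy/sampling step, shape (batch,)
next_tokens = torch.tensor([3, 7])

# embed the new tokens and append along the sequence dimension
next_tokens_embed = embed(next_tokens) * embed_scale          # (2, 8)
decoder_inputs_embeds = torch.cat(
    [decoder_inputs_embeds, next_tokens_embed.unsqueeze(1)],  # (2, 1, 8)
    dim=-2,
)
print(decoder_inputs_embeds.shape)  # torch.Size([2, 6, 8])
```

This is the same concatenation the diff performs inside the generation loop, just detached from the model.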
I will not accept these changes to the body of generate :)
There are multiple reasons for this decision:
a) A similar proposal is being tracked here. It will only be added for consideration after it raises sufficient interest, as described in the link;
b) We want to avoid adding more logic to generate itself unless it is a widely requested feature or it can be added as part of the model itself (e.g. in prepare_inputs_for_generation) / in a self-contained class (like the LogitsProcessors);
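For intuition on the "self-contained class" route the reviewer prefers: custom decoding logic can be packaged as a logits processor instead of edits to `generate()` itself. `transformers` calls each processor as `processor(input_ids, scores)`; subclassing `transformers.LogitsProcessor` is the idiomatic way, but a plain callable with the same signature (a hypothetical `BlockTokenProcessor`, shown here without the library import so the sketch stands alone) illustrates the extension point:

```python
import torch

class BlockTokenProcessor:
    """Toy processor: sets the score of one forbidden token id to -inf,
    mirroring the (input_ids, scores) -> scores contract that
    transformers' LogitsProcessorList invokes at every decoding step."""

    def __init__(self, blocked_token_id: int):
        self.blocked_token_id = blocked_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores = scores.clone()  # avoid mutating the caller's tensor
        scores[:, self.blocked_token_id] = float("-inf")
        return scores

processor = BlockTokenProcessor(blocked_token_id=0)
scores = torch.zeros(2, 10)
out = processor(torch.tensor([[1, 2], [3, 4]]), scores)
print(out[:, 0])  # tensor([-inf, -inf])
```

With the real library, such a processor would be passed to `model.generate(..., logits_processor=...)`, keeping the body of `generate` untouched.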
Thank you for the review!
```python
if kwargs.get("decoder_inputs_embeds") is not None:
    decoder_inputs["input_ids"] = None
    decoder_inputs_embeds = self.decoder.prepare_inputs_for_generation(kwargs.get("decoder_inputs_embeds"), past_key_values=past_key_values)
    decoder_inputs_embeds = decoder_inputs_embeds["input_ids"]
```
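For context on what the reviewer means by the reference pattern: in `transformers`, `prepare_inputs_for_generation` typically consumes caller-provided embeddings only on the first forward pass; once `past_key_values` exist, it falls back to feeding the id of the newly sampled token. A hedged sketch (function and argument names mirror the library convention, but this is an illustration, not the library code):

```python
import torch

def prepare_inputs_for_generation(input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
    if past_key_values is not None:
        # later steps: only the last token needs to be fed, via its id
        input_ids = input_ids[:, -1:]
    if inputs_embeds is not None and past_key_values is None:
        # first step: hand the model the caller-provided embeddings
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}
    model_inputs["past_key_values"] = past_key_values
    return model_inputs

# first decoding step: embeddings win
first = prepare_inputs_for_generation(
    torch.tensor([[1, 2, 3]]), inputs_embeds=torch.randn(1, 3, 4)
)
# subsequent step: cache present, so only the last token id is passed
later = prepare_inputs_for_generation(
    torch.tensor([[1, 2, 3, 4]]), past_key_values=("dummy",)
)
```

This keeps the embedding handling inside the model's own `prepare_inputs_for_generation`, which is exactly where the reviewer suggests such logic should live.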
This does not follow the reference implementation; see here for an example.
And how about the EncoderDecoderModel like T5? I tried to replace the
Hey @YiandLi 👋 My suggestion would be to open a separate issue for the support of a

Normally, I would not provide support for custom tasks, as my bandwidth is very limited, but according to this closed PR you are not the first person asking the question :)
What does this PR do?
Allows to use `decoder_inputs_embeds` for `model.generate` in VisionEncoderDecoderModel.

Who can review?
Vision Model
@amyeroberts