Conversation

@Andrechang
Contributor

What does this PR do?

Allows using decoder_inputs_embeds with model.generate in VisionEncoderDecoderModel.

Who can review?

Vision Model
@amyeroberts

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Feb 16, 2023

The documentation is not available anymore as the PR was closed or merged.

@amyeroberts
Contributor

cc @gante

Comment on lines +2244 to +2247
# add next_tokens to inputs_embeds if using embeds to generate
if model_kwargs.get("decoder_inputs_embeds") is not None:
    next_tokens_embed = self.decoder.get_input_embeddings()(next_tokens) * self.decoder.model.decoder.embed_scale
    model_kwargs["decoder_inputs_embeds"] = torch.cat(
        [model_kwargs["decoder_inputs_embeds"], next_tokens_embed.unsqueeze(1)], dim=-2
    )
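The diff above grows decoder_inputs_embeds by one position per decoding step: it embeds the newly sampled token and concatenates it along the sequence dimension. A minimal pure-Python stand-in for that step (nested lists instead of torch tensors; the function name is invented for illustration, not part of transformers):

```python
def append_next_token_embed(decoder_inputs_embeds, next_token_embeds):
    """Append one new token embedding to each batch item's sequence.

    decoder_inputs_embeds: per batch item, a list of token embedding vectors
                           (stand-in for a [batch, seq_len, hidden] tensor)
    next_token_embeds:     one embedding vector per batch item (stand-in for
                           the unsqueeze(1) + torch.cat(..., dim=-2) above)
    """
    return [seq + [emb] for seq, emb in zip(decoder_inputs_embeds, next_token_embeds)]


# one batch item, two tokens already embedded, hidden size 3
embeds = [[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]]
new_embed = [[0.7, 0.8, 0.9]]
grown = append_next_token_embed(embeds, new_embed)
print(len(grown[0]))  # → 3
```

In the real PR the same growth happens on GPU tensors inside the sampling loop, which is exactly why the reviewer objects below: it adds per-model logic to the shared body of generate.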
Contributor


I will not accept these changes to the body of generate :)

There are multiple reasons for this decision:
a) A similar proposal is being tracked here. It will only be added for consideration after it raises sufficient interest, as described in the link;
b) We want to avoid adding more logic to generate itself unless it is a widely requested feature or it can be added as part of the model itself (e.g. in prepare_inputs_for_generation) / in a self-contained class (like the LogitsProcessors);
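The extension point named in (b), prepare_inputs_for_generation, lets a model decide per step what generate should feed it. A toy sketch of the usual pattern (plain lists and dicts, invented names, not the real transformers signature): use the caller-supplied embeddings only on the first step, and fall back to the newest token id once a cache exists.

```python
def prepare_inputs_for_generation(input_ids, past_key_values=None,
                                  decoder_inputs_embeds=None, **kwargs):
    """Toy stand-in for the per-model hook the reviewer describes.

    First step (no cache yet): feed the caller-supplied embeddings.
    Later steps: feed only the id of the most recently generated token.
    """
    if decoder_inputs_embeds is not None and past_key_values is None:
        return {"decoder_inputs_embeds": decoder_inputs_embeds,
                "decoder_input_ids": None}
    if past_key_values is not None:
        input_ids = input_ids[-1:]  # only the newest token once a cache exists
    return {"decoder_inputs_embeds": None, "decoder_input_ids": input_ids}


first = prepare_inputs_for_generation([5], decoder_inputs_embeds=[[0.1, 0.2]])
later = prepare_inputs_for_generation([5, 7, 9], past_key_values="cache",
                                      decoder_inputs_embeds=[[0.1, 0.2]])
print(first["decoder_inputs_embeds"], later["decoder_input_ids"])
```

Because generate already calls this hook every step, the switch lives in the model and the shared decoding loop stays untouched.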

Contributor Author


Thank you for the review.

if kwargs.get("decoder_inputs_embeds") is not None:
decoder_inputs["input_ids"] = None
decoder_inputs_embeds = self.decoder.prepare_inputs_for_generation(kwargs.get("decoder_inputs_embeds"), past_key_values=past_key_values)
decoder_inputs_embeds = decoder_inputs_embeds["input_ids"]
Contributor


This does not follow the reference implementation; see here for an example.

@Andrechang Andrechang closed this Feb 17, 2023
@YiandLi

YiandLi commented Jun 18, 2023

What about encoder-decoder models like T5?

I tried to override only the prepare_inputs_for_generation method, guided by #6535, but it does not work.



class CustomT5ForConditionalGeneration(T5ForConditionalGeneration):

    def prepare_inputs_for_generation(self,
                                      input_ids,
                                      past_key_values=None,
                                      attention_mask=None,
                                      head_mask=None,
                                      decoder_head_mask=None,
                                      cross_attn_head_mask=None,
                                      use_cache=None,
                                      encoder_outputs=None,
                                      **kwargs):
        # pass everything by keyword so the arguments cannot be
        # misaligned against the parent signature
        res = super().prepare_inputs_for_generation(input_ids,
                                                    past_key_values=past_key_values,
                                                    attention_mask=attention_mask,
                                                    head_mask=head_mask,
                                                    decoder_head_mask=decoder_head_mask,
                                                    cross_attn_head_mask=cross_attn_head_mask,
                                                    use_cache=use_cache,
                                                    encoder_outputs=encoder_outputs,
                                                    **kwargs)
        # maybe another solution: https://github.com/huggingface/transformers/pull/21671

        # forward decoder embeddings and mask if the caller provided them
        if "decoder_inputs_embeds" in kwargs:
            res["decoder_inputs_embeds"] = kwargs["decoder_inputs_embeds"]
        if "decoder_attention_mask" in kwargs:
            res["decoder_attention_mask"] = kwargs["decoder_attention_mask"]

        # if `decoder_inputs_embeds` are passed, we only want to use them
        # in the 1st generation step
        if past_key_values is None:
            res.pop("decoder_input_ids", None)
        else:
            # only the last token id once the cache (`past`) is defined;
            # pop with a default to avoid a KeyError when no embeddings were passed
            res["decoder_input_ids"] = res["decoder_input_ids"][:, -1].unsqueeze(-1)
            res.pop("decoder_inputs_embeds", None)

        return res

@gante
Contributor

gante commented Jun 19, 2023

Hey @YiandLi 👋

My suggestion would be to open a separate issue for the support of a decoder_inputs_embeds input, like #6535, so the issue becomes clear and visible to everyone. Like in #6535, I'd be happy to a) share a temporary solution, or b) push a permanent solution if the issue acquires sufficient traction.

Normally, I would not provide support for custom tasks, as my bandwidth is very limited, but according to this closed PR you are not the first person asking the question :)
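The pattern the thread keeps circling can be summarized in a toy greedy loop (pure Python, every name invented; a real solution would live in prepare_inputs_for_generation, as suggested above): the first decoder step consumes the supplied embeddings, and every later step consumes only the newest token id while the cache stands in for past_key_values.

```python
def toy_decoder(step_input, cache):
    """Pretend decoder: stores its input in the cache and emits the
    token id (len(cache)) as the 'argmax' of its fake logits."""
    cache.append(step_input)
    return len(cache)


def toy_generate(decoder_inputs_embeds, max_new_tokens=3):
    """Greedy loop: embeddings on step 1, last token id afterwards."""
    cache = []           # stand-in for past_key_values
    tokens = []
    # step 1: feed the caller-supplied embeddings
    tokens.append(toy_decoder(decoder_inputs_embeds, cache))
    # later steps: feed only the most recently generated token id
    for _ in range(max_new_tokens - 1):
        tokens.append(toy_decoder(tokens[-1], cache))
    return tokens


print(toy_generate([[0.1, 0.2]]))  # → [1, 2, 3]
```

This is only the control flow; a real implementation also has to embed each new token (as the closed PR's diff did) or let the model's forward accept ids after the first step.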
