
Conversation

dimitry12 (Contributor)

What does this PR do?

This PR extends #21405 by @gante to GPT-J, making it accept inputs_embeds when generating.

This is generally useful for soft-prompting, but I am specifically using it with https://github.com/jmerullo/limber by @jmerullo

Importantly, I find that dummy input_ids are still required. Sample code using this feature with GPT-J:

import torch
from transformers import GPTJForCausalLM, AutoTokenizer

model_name = "hf-internal-testing/tiny-random-GPTJModel"
revision = "main"
model = GPTJForCausalLM.from_pretrained(model_name, revision=revision)
tokenizer = AutoTokenizer.from_pretrained(model_name)

inputs_embeds = torch.rand((1, 144, 32))  # 144 dummy soft-prompt token embeddings
filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=torch.long).to(model.device)
filler_input_ids += model.config.bos_token_id  # set the dummy input_ids to bos tokens

model.generate(filler_input_ids, inputs_embeds=inputs_embeds, max_length=300)

dimitry12 changed the title from "Dzmitry/gptj input embeds" to "Add inputs_embeds support when generating with GPT-J" on Feb 10, 2023
HuggingFaceDocBuilderDev commented Feb 11, 2023

The documentation is not available anymore as the PR was closed or merged.

gante (Contributor) left a comment

LGTM, thank you for the addition 👍

gante (Contributor) commented Feb 11, 2023

@dimitry12 to clarify, input_ids is not a required argument when inputs_embeds is passed :) Or it should not be; let me know if you're getting errors in that case.

If you don't pass them, the output of .generate() should only contain the newly generated tokens.
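
For illustration only, a minimal sketch of such a call, reusing the tiny GPT-J checkpoint and the hypothetical soft-prompt shape from the PR description (this snippet is not part of the PR's diff):

import torch
from transformers import GPTJForCausalLM, AutoTokenizer

model = GPTJForCausalLM.from_pretrained("hf-internal-testing/tiny-random-GPTJModel")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-GPTJModel")

# hypothetical soft prompt: 144 random embeddings with the tiny model's hidden size of 32
inputs_embeds = torch.rand((1, 144, 32))

# no input_ids passed; the returned ids should then hold only the newly generated tokens
generated_ids = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=20)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))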

gante (Contributor) commented Feb 11, 2023

@dimitry12 oh I see, the case where the batch size is larger than one is not handled; it is the same issue as in #21578! I'll open a PR soon that fixes it.
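
A minimal sketch of the batched call being referred to, reusing the checkpoint and shapes from the PR description (illustrative only; at the time of this comment this case was not yet handled):

import torch
from transformers import GPTJForCausalLM

model = GPTJForCausalLM.from_pretrained("hf-internal-testing/tiny-random-GPTJModel")

# two soft prompts in one batch instead of one, i.e. batch size > 1
batched_inputs_embeds = torch.rand((2, 144, 32))
generated_ids = model.generate(inputs_embeds=batched_inputs_embeds, max_new_tokens=10)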

dimitry12 (Contributor, PR author)

@gante thank you for the prompt review! Is merging done by HF staff?

gante (Contributor) commented Feb 11, 2023

Hey @dimitry12, yes it is, but we still need a review from our transformers master @sgugger :)

gante requested a review from sgugger on February 11, 2023 at 15:23
dimitry12 (Contributor, PR author) commented Feb 11, 2023

@gante

@dimitry12 to clarify, input_ids is not a required argument when inputs_embeds is passed :) Or it should not be; let me know if you're getting errors in that case.

If you don't pass them, the output of .generate() should only contain the newly generated tokens.

Yes, it fails without dummy input_ids, but to be clear, it fails differently from BLIP2 (#21578).

To replicate, a slightly modified test_generate_from_input_embeds_decoder_only is sufficient:

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")

text = "Hello world"
input_ids = tokenizer.encode(text, return_tensors="pt")

# Same thing, but from input embeddings
inputs_embeds = model.transformer.wte(input_ids)
outputs_from_embeds = model.generate(inputs_embeds=inputs_embeds)

results in:

Traceback (most recent call last):
  <module>, line 11
    outputs_from_embeds = model.generate(inputs_embeds=inputs_embeds)
  File "/home/dzmitry/miniconda3/envs/limber.py310/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/dzmitry/miniconda3/envs/limber.py310/lib/python3.10/site-packages/transformers/generation/utils.py", line 1386, in generate
    return self.greedy_search(
  File "/home/dzmitry/miniconda3/envs/limber.py310/lib/python3.10/site-packages/transformers/generation/utils.py", line 2181, in greedy_search
    outputs = self(
  File "/home/dzmitry/miniconda3/envs/limber.py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/dzmitry/miniconda3/envs/limber.py310/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1073, in forward
    transformer_outputs = self.transformer(
  File "/home/dzmitry/miniconda3/envs/limber.py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/dzmitry/miniconda3/envs/limber.py310/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 793, in forward
    position_ids = position_ids.view(-1, input_shape[-1])
RuntimeError: shape '[-1, 5]' is invalid for input of size 1

@gante I can open a draft PR with an updated failing test and see whether I can figure it out, or whether your planned fix for #21578 will also fix it. Please advise on the best process; I'm definitely willing to help here.

sgugger (Collaborator) left a comment

Thanks for the fix!

gante (Contributor) commented Feb 13, 2023

(For context, the fix to the issue above is in #21580. Merging :D )

gante merged commit 93ed89b into huggingface:main on Feb 13, 2023
dimitry12 deleted the dzmitry/gptj_input_embeds branch on February 13, 2023 at 15:31