
Conversation

@ValeKnappich
Contributor

What does this PR do?

@gante @sgugger

Fixes past_key_values in GPTNeoXForCausalLM.prepare_inputs_for_generation. Passing past_key_values to model.generate had no effect whatsoever, since the argument was swallowed. This is described in Issue #20347 (note that the validation bug was fixed in PR #20353, but the argument was still not passed along to the forward method).

The attached commit fixes the issue on my end, i.e. I now get different results when passing past_key_values to generate, which was not the case before.
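
For reference, a minimal sketch of the kind of change involved (an illustration of the idea, not the exact diff in this PR): prepare_inputs_for_generation has to forward the cache it receives in the dict it returns, otherwise generate() drops a user-supplied past_key_values.

# Sketch of the method on GPTNeoXForCausalLM (assumed shape of the fix, not the exact diff):
# the essential point is that the returned dict includes past_key_values instead of dropping it.
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
    if past_key_values is not None:
        # With a cache present, only the last token needs to be re-encoded
        input_ids = input_ids[:, -1:]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "past_key_values": past_key_values,  # previously omitted, so the argument was swallowed
    }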

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Dec 6, 2022

The documentation is not available anymore as the PR was closed or merged.

Collaborator

@sgugger left a comment


Will let @gante comment here as he's the specialist for generate. Make sure to run make style on your branch to fix the code quality issue.

@ValeKnappich
Contributor Author

After doing some more testing, I noticed another issue that might or might not be a bug. Currently, it's not possible to use anything other than 1 for num_return_sequences. Here is an MWE:

import torch
from transformers import GPTNeoXForCausalLM, AutoTokenizer

# Load model
s = "NinedayWang/PolyCoder-160M"
model = GPTNeoXForCausalLM.from_pretrained(s)
tokenizer = AutoTokenizer.from_pretrained(s, pad_token="<|PAD|>")

# Create a random past_key_values cache
N_TOKENS = 100
BATCH_SIZE=1
NUM_RETURN_SEQUENCES=8
pkv = torch.rand(
    (
        BATCH_SIZE,      # batch size      
        N_TOKENS,    # number of tokens
        2 * model.config.num_hidden_layers, 
        model.config.num_attention_heads, 
        model.config.hidden_size // model.config.num_attention_heads
    )
).permute([2, 0, 3, 1, 4]).split(2)

# Tokenize
enc = tokenizer("Hello world", return_tensors="pt")
enc["attention_mask"] = torch.cat((torch.ones((1, N_TOKENS)), enc["attention_mask"]), dim=1)

# Generate
print(
    tokenizer.decode(
        model.generate( 
            **enc,
            past_key_values=pkv,
            max_new_tokens=100,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            num_return_sequences=NUM_RETURN_SEQUENCES
        )[0],
        skip_special_tokens=True
    )
)

Leads to

Traceback (most recent call last):
  File "stuff/test.py", line 32, in <module>
    num_return_sequences=2
  File "/home/st/st_us-052400/st_st175337/conda/envs/thesis/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/pfs/data5/home/st/st_us-052400/st_st175337/thesis/transformers/src/transformers/generation/utils.py", line 1581, in generate
    **model_kwargs,
  File "/pfs/data5/home/st/st_us-052400/st_st175337/thesis/transformers/src/transformers/generation/utils.py", line 2538, in sample
    output_hidden_states=output_hidden_states,
  File "/home/st/st_us-052400/st_st175337/conda/envs/thesis/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/pfs/data5/home/st/st_us-052400/st_st175337/thesis/transformers/src/transformers/models/gpt_neox/modeling_gpt_neox.py", line 663, in forward
    return_dict=return_dict,
  File "/home/st/st_us-052400/st_st175337/conda/envs/thesis/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/pfs/data5/home/st/st_us-052400/st_st175337/thesis/transformers/src/transformers/models/gpt_neox/modeling_gpt_neox.py", line 552, in forward
    output_attentions=output_attentions,
  File "/home/st/st_us-052400/st_st175337/conda/envs/thesis/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/pfs/data5/home/st/st_us-052400/st_st175337/thesis/transformers/src/transformers/models/gpt_neox/modeling_gpt_neox.py", line 325, in forward
    output_attentions=output_attentions,
  File "/home/st/st_us-052400/st_st175337/conda/envs/thesis/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/pfs/data5/home/st/st_us-052400/st_st175337/thesis/transformers/src/transformers/models/gpt_neox/modeling_gpt_neox.py", line 148, in forward
    key = torch.cat((past_key, key), dim=-2)
RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 1 but got size 2 for tensor number 1 in the list.

Is that expected behavior? I can work around it by expanding past_key_values to a batch size of BATCH_SIZE * NUM_RETURN_SEQUENCES (see below), but it seems unintuitive, and I don't see anything about it in the docs. Perhaps the docs should simply mention it.

pkv = torch.rand(
    (
        BATCH_SIZE * NUM_RETURN_SEQUENCES,      # <--- expand the batch size 
        N_TOKENS,    # number of tokens
        2 * model.config.num_hidden_layers, 
        model.config.num_attention_heads, 
        model.config.hidden_size // model.config.num_attention_heads
    )
).permute([2, 0, 3, 1, 4]).split(2)
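
Equivalently (untested sketch, assuming the tensor layout produced by the snippet above), an existing cache could be expanded along its batch dimension instead of being re-created:

# Untested sketch: expand an existing cache to match num_return_sequences.
# Assumes the layout from the snippet above, i.e. a tuple of chunks of shape
# (2, batch, num_heads, n_tokens, head_dim), so the batch dimension is dim 1.
pkv_expanded = tuple(
    chunk.repeat_interleave(NUM_RETURN_SEQUENCES, dim=1) for chunk in pkv
)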

@gante
Contributor

gante commented Dec 21, 2022

Hey @ValeKnappich 👋

Thank you for the addition; I really think we should do this for all models for a better interface. In fact, the argument should be past_key_values and not past, as mentioned in the original issue, but that's a deeper change. This PR is a quick fix for the problem, so I approve it.

As for num_return_sequences, let's open a new issue for it to avoid mixing too many things here :D

@gante merged commit 2da82bb into huggingface:main Dec 21, 2022
MKhalusova pushed a commit to MKhalusova/transformers that referenced this pull request Dec 28, 2022
Fix past_key_values in GPTNeoXForCausalLM.prepare_inputs_for_generation (huggingface#20621)

* fix past_key_values in GPTNeoXForCausalLM.prepare_inputs_for_generation

* fix formatting
silverriver pushed a commit to silverriver/transformers that referenced this pull request Jan 6, 2023
Fix past_key_values in GPTNeoXForCausalLM.prepare_inputs_for_generation (huggingface#20621)

* fix past_key_values in GPTNeoXForCausalLM.prepare_inputs_for_generation

* fix formatting
@ardywibowo

Hi, has this issue been resolved? I tried running the code snippet above:

import torch
from transformers import GPTNeoXForCausalLM, AutoTokenizer

# Load model
s = "NinedayWang/PolyCoder-160M"
model = GPTNeoXForCausalLM.from_pretrained(s)
tokenizer = AutoTokenizer.from_pretrained(s, pad_token="<|PAD|>")

# Create a random past_key_values cache
N_TOKENS = 100
BATCH_SIZE=1
NUM_RETURN_SEQUENCES=8
pkv = torch.rand(
    (
        BATCH_SIZE,      # batch size      
        N_TOKENS,    # number of tokens
        2 * model.config.num_hidden_layers, 
        model.config.num_attention_heads, 
        model.config.hidden_size // model.config.num_attention_heads
    )
).permute([2, 0, 3, 1, 4]).split(2)

# Tokenize
enc = tokenizer("Hello world", return_tensors="pt")
enc["attention_mask"] = torch.cat((torch.ones((1, N_TOKENS)), enc["attention_mask"]), dim=1)

# Generate
print(
    tokenizer.decode(
        model.generate( 
            **enc,
            past_key_values=pkv,
            max_new_tokens=100,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            num_return_sequences=NUM_RETURN_SEQUENCES
        )[0],
        skip_special_tokens=True
    )
)

and it failed with

RuntimeError: The size of tensor a (101) must match the size of tensor b (102) at non-singleton dimension 3

Is this a different error?

@gante
Contributor

gante commented Jun 8, 2023

@ardywibowo the script I pasted below works. But keep in mind that it is probably not doing what you expect: when past_key_values is passed, only the latest input token is considered (all the other previous tokens are supposed to be encoded in past_key_values) -- in other words, "Hello" in "Hello world" is ignored when generating the next token, despite being present in the output text.

To understand why, you would have to dive into this blog post and into our generate code :)


import torch
from transformers import GPTNeoXForCausalLM, AutoTokenizer

# Load model
s = "NinedayWang/PolyCoder-160M"
model = GPTNeoXForCausalLM.from_pretrained(s)
tokenizer = AutoTokenizer.from_pretrained(s, pad_token="<|PAD|>")

# Create a random past_key_values cache
N_TOKENS = 100
BATCH_SIZE=1
pkv = torch.rand(
    (
        BATCH_SIZE,      # batch size
        N_TOKENS,    # number of tokens
        2 * model.config.num_hidden_layers,
        model.config.num_attention_heads,
        model.config.hidden_size // model.config.num_attention_heads
    )
).permute([2, 0, 3, 1, 4]).split(2)

# Tokenize
enc = tokenizer("Hello world", return_tensors="pt")
enc["attention_mask"] = torch.ones((1, N_TOKENS+1))

# Generate
print(
    tokenizer.decode(
        model.generate(
            **enc,
            past_key_values=pkv,
            max_new_tokens=100,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
        )[0],
        skip_special_tokens=True
    )
)
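
To illustrate the point above, a minimal check (a sketch assuming the cache layout from the script, appended after it so it reuses model, enc, pkv and N_TOKENS): feeding only the last prompt token together with the 100-token cache shows why the attention mask needs N_TOKENS + 1 positions.

# Sketch, appended to the script above (reuses model, enc, pkv, N_TOKENS):
# with a 100-token cache, only the last prompt token is fed to the model,
# which is why the attention mask covers N_TOKENS + 1 positions.
with torch.no_grad():
    out = model(
        input_ids=enc["input_ids"][:, -1:],            # just the last prompt token
        attention_mask=torch.ones((1, N_TOKENS + 1)),  # past length + 1 new token
        past_key_values=pkv,
        use_cache=True,
    )
print(out.past_key_values[0][0].shape[-2])  # 101, i.e. N_TOKENS + 1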
