Skip to content

Conversation

@qgallouedec
Copy link
Member

@qgallouedec qgallouedec commented Oct 24, 2024

What does this PR do?

Starting from huggingface/transformers#34026, the generation fails for only 3 models BlenderbotSmall, MarianMT and ProphetNet:

model_names = [
    "trl-internal-testing/tiny-random-BlenderbotSmallForConditionalGeneration",
    "trl-internal-testing/tiny-random-MarianMTModel",
    "trl-internal-testing/tiny-random-ProphetNetForConditionalGeneration",
]
import torch
from trl import AutoModelForSeq2SeqLMWithValueHead

for model_name in model_names:
    model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(model_name)
    input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
    decoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
    model.generate(input_ids, decoder_input_ids=decoder_input_ids)
This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (20). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.
Traceback (most recent call last):
  File "/fsx/qgallouedec/transformers/../trl/sherlock.py", line 19, in <module>
    _ = model.generate(input_ids, decoder_input_ids=decoder_input_ids)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/trl/trl/models/modeling_value_head.py", line 445, in generate
    return self.pretrained_model.generate(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/transformers/src/transformers/generation/utils.py", line 2217, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/fsx/qgallouedec/transformers/src/transformers/generation/utils.py", line 3208, in _sample
    outputs = self(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/transformers/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py", line 1248, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/transformers/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py", line 1133, in forward
    decoder_outputs = self.decoder(
                      ^^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/transformers/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py", line 937, in forward
    positions = self.embed_positions(input_shape, past_key_values_length)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/transformers/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py", line 84, in forward
    return super().forward(positions)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.11/site-packages/torch/nn/modules/sparse.py", line 164, in forward
    return F.embedding(
           ^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.11/site-packages/torch/nn/functional.py", line 2267, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: index out of range in self

In the new logic:

if generation_config.max_new_tokens is None:
        generation_config.max_length = generation_config.max_length + input_ids_length

when max_new_token is not specified, the total number of tokens takes the value max_length (default 20) + input_ids_length (10 in the test) = 30.

But, for some models, the max len is 20. So the total length exceed the target length.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec qgallouedec changed the title Fix ci ♾️ Fix test generation max_new_tokens Oct 24, 2024
@qgallouedec qgallouedec requested review from kashif and lewtun October 24, 2024 17:58
@qgallouedec qgallouedec merged commit e615974 into main Oct 24, 2024
@qgallouedec qgallouedec deleted the fix-ci branch October 24, 2024 18:20
yxliu-TAMU pushed a commit to mincheolseong/ECEN743-GRPO-Project-Proposal that referenced this pull request Apr 20, 2025
* `eval_strategy="steps" if eval_dataset else "no"`

* tmp skip test

* drop `eval_strategy` in `test_sft_trainer_uncorrect_data`

* remove eval strategy

* Add parameterized test for generate method

* Revert "`eval_strategy="steps" if eval_dataset else "no"`"

This reverts commit 1e8b331.

* Revert "tmp skip test"

This reverts commit 44558f8.

* Revert "drop `eval_strategy` in `test_sft_trainer_uncorrect_data`"

This reverts commit a1ef701.

* Revert "remove eval strategy"

This reverts commit cb7fafa.

* style

* Refactor test_generate method in test_modeling_value_head.py

* `max_new_tokens=9`
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants