31 changes: 31 additions & 0 deletions src/transformers/generation_utils.py
@@ -841,6 +841,34 @@ def compute_transition_beam_scores(

return transition_scores

def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
# Excludes arguments that are handled before calling any model function
if self.config.is_encoder_decoder:
for key in ["decoder_input_ids"]:
model_kwargs.pop(key, None)

# Transfo_XL does not use "attention_mask" as an argument, and it is harmless (it is being passed in the
# tests, though, hence this ad hoc exception)
if "transfoxl" in str(self).lower():
model_kwargs.pop("attention_mask", None)

unused_model_args = []
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
# `kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
if "kwargs" in model_args:
model_args |= set(inspect.signature(self.forward).parameters)
for key, value in model_kwargs.items():
if value is not None and key not in model_args:
unused_model_args.append(key)

if unused_model_args:
raise ValueError(
f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
" generate arguments will also show up in this list)"
)

@torch.no_grad()
def generate(
self,
@@ -1119,6 +1147,9 @@ def generate(
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
```"""
# 0. Validate model kwargs
self._validate_model_kwargs(model_kwargs.copy())

# 1. Set generation parameters if not already defined
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
num_beams = num_beams if num_beams is not None else self.config.num_beams
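The validation above boils down to one mechanism: collect every parameter name that `prepare_inputs_for_generation` accepts via `inspect.signature`, widen that set with `forward`'s parameters when `**kwargs` is forwarded, and reject any remaining keyword argument. A minimal, self-contained sketch of that mechanism (the `TinyModel` class and its methods are illustrative stand-ins, not the real transformers model classes):

import inspect
from typing import Any, Dict


class TinyModel:
    """Stand-in exposing the two methods the validation inspects."""

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # Forwards optional inputs through `**kwargs`, like many real models
        return {"input_ids": input_ids, **kwargs}

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        return input_ids

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        # Names accepted directly by `prepare_inputs_for_generation`
        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
        # It takes `**kwargs`, so anything `forward` accepts is also usable
        if "kwargs" in model_args:
            model_args |= set(inspect.signature(self.forward).parameters)
        unused = [k for k, v in model_kwargs.items() if v is not None and k not in model_args]
        if unused:
            raise ValueError(f"The following `model_kwargs` are not used by the model: {unused}")


model = TinyModel()
model._validate_model_kwargs({"attention_mask": [1, 1, 1]})  # consumed by `forward` -> no error
try:
    model._validate_model_kwargs({"do_samples": True})  # typo of `do_sample` -> rejected
except ValueError as err:
    print(err)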
21 changes: 20 additions & 1 deletion tests/generation/test_generation_utils.py
@@ -2052,7 +2052,7 @@ def test_max_new_tokens_decoder_only(self):

# max_new_tokens and max_length serve the same purpose and should not be used together.
with self.assertWarns(UserWarning):
- gpt2_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
+ gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20)

def test_encoder_decoder_generate_with_inputs_embeds(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
@@ -2699,3 +2699,22 @@ def test_constrained_beam_search_mixin_type_checks(self):

with self.assertRaises(ValueError):
model.generate(input_ids, force_words_ids=[[[-1]]])

def test_validate_generation_inputs(self):
tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")

encoder_input_str = "Hello world"
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

# typos are quickly detected (the correct argument is `do_sample`)
with self.assertRaises(ValueError):
model.generate(
input_ids,
do_samples=True,
)

# arbitrary arguments that will not be used anywhere are also not accepted
with self.assertRaises(ValueError):
fake_model_kwargs = {"foo": "bar"}
model.generate(input_ids, **fake_model_kwargs)
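
For reference, the typo in the first call above would surface through the error message added in this PR, roughly as:

ValueError: The following `model_kwargs` are not used by the model: ['do_samples'] (note: typos in the generate arguments will also show up in this list)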