diff --git a/src/transformers/generation_flax_utils.py b/src/transformers/generation_flax_utils.py
index fd26a605c48b..1c052aae7baf 100644
--- a/src/transformers/generation_flax_utils.py
+++ b/src/transformers/generation_flax_utils.py
@@ -15,9 +15,10 @@
 # limitations under the License.
 
+import inspect
 import warnings
 from functools import partial
-from typing import Dict, Optional
+from typing import Any, Dict, Optional
 
 import numpy as np
@@ -160,6 +161,24 @@ def _adapt_logits_for_beam_search(self, logits):
         """
         return logits
 
+    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
+        """Validates model kwargs for generation. Generate argument typos will also be caught here."""
+        unused_model_args = []
+        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
+        # `kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
+        # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
+        if "kwargs" in model_args:
+            model_args |= set(inspect.signature(self.__call__).parameters)
+        for key, value in model_kwargs.items():
+            if value is not None and key not in model_args:
+                unused_model_args.append(key)
+
+        if unused_model_args:
+            raise ValueError(
+                f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
+                " generate arguments will also show up in this list)"
+            )
+
     def generate(
         self,
         input_ids: jnp.ndarray,
@@ -262,6 +281,9 @@ def generate(
         >>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
         ```"""
+        # Validate model kwargs
+        self._validate_model_kwargs(model_kwargs.copy())
+
         # set init values
         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
diff --git a/tests/generation/test_generation_flax_utils.py b/tests/generation/test_generation_flax_utils.py
index b7b84d8db725..aabab559853b 100644
--- a/tests/generation/test_generation_flax_utils.py
+++ b/tests/generation/test_generation_flax_utils.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import random
+import unittest
 
 import numpy as np
@@ -26,6 +27,7 @@
     import jax.numpy as jnp
     from jax import jit
 
+    from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
     from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
 
     os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12"  # assumed parallelism: 8
@@ -273,3 +275,22 @@ def test_beam_search_generate_attn_mask(self):
         jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
 
         self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
+
+
+@require_flax
+class FlaxGenerationIntegrationTests(unittest.TestCase):
+    def test_validate_generation_inputs(self):
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-bert")
+        model = FlaxAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
+
+        encoder_input_str = "Hello world"
+        input_ids = tokenizer(encoder_input_str, return_tensors="np").input_ids
+
+        # typos are quickly detected (the correct argument is `do_sample`)
+        with self.assertRaisesRegex(ValueError, "do_samples"):
+            model.generate(input_ids, do_samples=True)
+
+        # arbitrary arguments that will not be used anywhere are also not accepted
+        with self.assertRaisesRegex(ValueError, "foo"):
+            fake_model_kwargs = {"foo": "bar"}
+            model.generate(input_ids, **fake_model_kwargs)
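
For context on how the new check works: `_validate_model_kwargs` collects the parameter names that `inspect.signature` reports for `prepare_inputs_for_generation` (and, when that method accepts `**kwargs`, for `__call__` as well), then flags any leftover `model_kwargs` key. The snippet below is a minimal, self-contained sketch of that idea; `ToyModel` and its methods are hypothetical stand-ins for illustration, not the actual transformers classes.

```python
import inspect
from typing import Any, Dict


class ToyModel:
    """Hypothetical stand-in for a Flax generation model (illustration only)."""

    def __call__(self, input_ids, attention_mask=None):
        return input_ids

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        unused_model_args = []
        # Parameter names accepted by `prepare_inputs_for_generation`
        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
        # If it takes `**kwargs`, forward-pass arguments may also be valid, so widen the set
        if "kwargs" in model_args:
            model_args |= set(inspect.signature(self.__call__).parameters)
        for key, value in model_kwargs.items():
            if value is not None and key not in model_args:
                unused_model_args.append(key)
        if unused_model_args:
            raise ValueError(
                f"The following `model_kwargs` are not used by the model: {unused_model_args}"
            )


model = ToyModel()
model._validate_model_kwargs({"attention_mask": [[1, 1]]})  # accepted: a known forward-pass input

try:
    # A typo like `do_samples` is not a `generate` argument, so it would land in
    # `model_kwargs` and be rejected here
    model._validate_model_kwargs({"do_samples": True})
except ValueError as err:
    print(err)  # The following `model_kwargs` are not used by the model: ['do_samples']
```

Note that `generate` runs the validation on `model_kwargs.copy()`, presumably so the check can never mutate the kwargs that the decoding loop consumes afterwards.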