24 changes: 23 additions & 1 deletion src/transformers/generation_flax_utils.py
@@ -15,9 +15,10 @@
# limitations under the License.


import inspect
import warnings
from functools import partial
from typing import Dict, Optional
from typing import Any, Dict, Optional

import numpy as np

@@ -160,6 +161,24 @@ def _adapt_logits_for_beam_search(self, logits):
"""
return logits

def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
unused_model_args = []
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
# `kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
if "kwargs" in model_args:
model_args |= set(inspect.signature(self.__call__).parameters)
for key, value in model_kwargs.items():
if value is not None and key not in model_args:
unused_model_args.append(key)

if unused_model_args:
raise ValueError(
f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
" generate arguments will also show up in this list)"
)

def generate(
self,
input_ids: jnp.ndarray,
@@ -262,6 +281,9 @@ def generate(
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
```"""
# Validate model kwargs
self._validate_model_kwargs(model_kwargs.copy())

# set init values
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
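For context, the new `_validate_model_kwargs` relies on `inspect.signature` to collect the parameter names accepted by `prepare_inputs_for_generation` (and, when that method takes `**kwargs`, by `__call__` as well), then flags any leftover keyword. Below is a minimal, self-contained sketch of the same pattern; `FakeModel` and its method signatures are illustrative stand-ins, not transformers classes:

```python
import inspect
from typing import Any, Dict


class FakeModel:
    """Toy stand-in for a Flax generation model (illustrative only)."""

    def __call__(self, input_ids, attention_mask=None):
        pass

    def prepare_inputs_for_generation(self, input_ids, max_length, **kwargs):
        pass

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        unused_model_args = []
        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
        # If `prepare_inputs_for_generation` accepts `**kwargs`, the forward pass
        # (`__call__`) signature is also treated as a legitimate sink for extra keywords.
        if "kwargs" in model_args:
            model_args |= set(inspect.signature(self.__call__).parameters)
        for key, value in model_kwargs.items():
            if value is not None and key not in model_args:
                unused_model_args.append(key)
        if unused_model_args:
            raise ValueError(f"The following `model_kwargs` are not used by the model: {unused_model_args}")


try:
    FakeModel()._validate_model_kwargs({"attention_mask": [[1, 1]], "do_samples": True})
except ValueError as err:
    print(err)  # The following `model_kwargs` are not used by the model: ['do_samples']
```

Because `attention_mask` appears in the `__call__` signature it passes validation, while the misspelled `do_samples` does not match any accepted parameter and is reported.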
21 changes: 21 additions & 0 deletions tests/generation/test_generation_flax_utils.py
@@ -13,6 +13,7 @@
# limitations under the License.

import random
import unittest

import numpy as np

@@ -26,6 +27,7 @@

import jax.numpy as jnp
from jax import jit
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
@@ -273,3 +275,22 @@ def test_beam_search_generate_attn_mask(self):
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences

self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())


@require_flax
class FlaxGenerationIntegrationTests(unittest.TestCase):
def test_validate_generation_inputs(self):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-bert")
model = FlaxAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-bert-flax-only")

encoder_input_str = "Hello world"
input_ids = tokenizer(encoder_input_str, return_tensors="np").input_ids

# typos are quickly detected (the correct argument is `do_sample`)
with self.assertRaisesRegex(ValueError, "do_samples"):
model.generate(input_ids, do_samples=True)

# arbitrary arguments that will not be used anywhere are also not accepted
with self.assertRaisesRegex(ValueError, "foo"):
fake_model_kwargs = {"foo": "bar"}
model.generate(input_ids, **fake_model_kwargs)
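As a usage note, the test above implies the following end-user behavior (a sketch only; the checkpoint names and the error text are taken from this PR's diff, and the exact wrapping of the message may differ):

```python
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-bert")
model = FlaxAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
input_ids = tokenizer("Hello world", return_tensors="np").input_ids

# The misspelled `do_samples` is not consumed by the model, so generation aborts early:
model.generate(input_ids, do_samples=True)
# ValueError: The following `model_kwargs` are not used by the model: ['do_samples'] (note: typos in the
# generate arguments will also show up in this list)
```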