From bf05c3c3a9be5b7c904f1c89885c152143779fc3 Mon Sep 17 00:00:00 2001 From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com> Date: Thu, 18 Aug 2022 10:02:05 +0100 Subject: [PATCH 1/3] Rename method to avoid clash with property --- src/transformers/image_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index 437e7c568558..120d7b3c1bd2 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -131,7 +131,7 @@ def convert_rgb(self, image): return image.convert("RGB") - def rescale(self, image: np.ndarray, scale: Union[float, int]) -> np.ndarray: + def rescale_image(self, image: np.ndarray, scale: Union[float, int]) -> np.ndarray: """ Rescale a numpy image by scale amount """ @@ -163,7 +163,7 @@ def to_numpy_array(self, image, rescale=None, channel_first=True): rescale = isinstance(image.flat[0], np.integer) if rescale is None else rescale if rescale: - image = self.rescale(image.astype(np.float32), 1 / 255.0) + image = self.rescale_image(image.astype(np.float32), 1 / 255.0) if channel_first and image.ndim == 3: image = image.transpose(2, 0, 1) @@ -214,9 +214,9 @@ def normalize(self, image, mean, std, rescale=False): # type it may need rescaling. 
elif rescale: if isinstance(image, np.ndarray): - image = self.rescale(image.astype(np.float32), 1 / 255.0) + image = self.rescale_image(image.astype(np.float32), 1 / 255.0) elif is_torch_tensor(image): - image = self.rescale(image.float(), 1 / 255.0) + image = self.rescale_image(image.float(), 1 / 255.0) if isinstance(image, np.ndarray): if not isinstance(mean, np.ndarray): From e3548659dc072232a50a478ecdfc579e45627966 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 18 Aug 2022 10:56:21 +0100 Subject: [PATCH 2/3] Generate: validate model_kwargs on FLAX (and catch typos in generate arguments) (#18653) --- src/transformers/generation_flax_utils.py | 24 ++++++++++++++++++- .../generation/test_generation_flax_utils.py | 21 ++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation_flax_utils.py b/src/transformers/generation_flax_utils.py index fd26a605c48b..1c052aae7baf 100644 --- a/src/transformers/generation_flax_utils.py +++ b/src/transformers/generation_flax_utils.py @@ -15,9 +15,10 @@ # limitations under the License. +import inspect import warnings from functools import partial -from typing import Dict, Optional +from typing import Any, Dict, Optional import numpy as np @@ -160,6 +161,24 @@ def _adapt_logits_for_beam_search(self, logits): """ return logits + def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): + """Validates model kwargs for generation. Generate argument typos will also be caught here.""" + unused_model_args = [] + model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters) + # `kwargs` is often used to handle optional forward pass inputs like `attention_mask`. 
If + # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;) + if "kwargs" in model_args: + model_args |= set(inspect.signature(self.__call__).parameters) + for key, value in model_kwargs.items(): + if value is not None and key not in model_args: + unused_model_args.append(key) + + if unused_model_args: + raise ValueError( + f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the" + " generate arguments will also show up in this list)" + ) + def generate( self, input_ids: jnp.ndarray, @@ -262,6 +281,9 @@ def generate( >>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ```""" + # Validate model kwargs + self._validate_model_kwargs(model_kwargs.copy()) + # set init values bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id diff --git a/tests/generation/test_generation_flax_utils.py b/tests/generation/test_generation_flax_utils.py index b7b84d8db725..aabab559853b 100644 --- a/tests/generation/test_generation_flax_utils.py +++ b/tests/generation/test_generation_flax_utils.py @@ -13,6 +13,7 @@ # limitations under the License. 
import random +import unittest import numpy as np @@ -26,6 +27,7 @@ import jax.numpy as jnp from jax import jit + from transformers import AutoTokenizer, FlaxAutoModelForCausalLM from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8 @@ -273,3 +275,22 @@ def test_beam_search_generate_attn_mask(self): jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist()) + + +@require_flax +class FlaxGenerationIntegrationTests(unittest.TestCase): + def test_validate_generation_inputs(self): + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-bert") + model = FlaxAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-bert-flax-only") + + encoder_input_str = "Hello world" + input_ids = tokenizer(encoder_input_str, return_tensors="np").input_ids + + # typos are quickly detected (the correct argument is `do_sample`) + with self.assertRaisesRegex(ValueError, "do_samples"): + model.generate(input_ids, do_samples=True) + + # arbitrary arguments that will not be used anywhere are also not accepted + with self.assertRaisesRegex(ValueError, "foo"): + fake_model_kwargs = {"foo": "bar"} + model.generate(input_ids, **fake_model_kwargs) From b8c47ceb4cea174934f53254b9ba9c5d01dc81ee Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Thu, 18 Aug 2022 12:57:18 +0200 Subject: [PATCH 3/3] Ping `detectron2` for CircleCI tests (#18680) Co-authored-by: ydshieh --- .circleci/config.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 666505ab3b43..3e5c6aaa8858 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1003,7 +1003,10 @@ jobs: - run: pip install --upgrade pip - run: pip install .[torch,testing,vision] - run: pip install 
torchvision - - run: python -m pip install 'git+https://github.com/facebookresearch/detectron2.git' + # The commit `36a65a0907d90ed591479b2ebaa8b61cfa0b4ef0` in `detectron2` breaks things. + # See https://github.com/facebookresearch/detectron2/commit/36a65a0907d90ed591479b2ebaa8b61cfa0b4ef0#comments. + # TODO: Revert this change back once the above issue is fixed. + - run: python -m pip install 'git+https://github.com/facebookresearch/detectron2.git@5aeb252b194b93dc2879b4ac34bc51a31b5aee13' - run: sudo apt install tesseract-ocr - run: pip install pytesseract - save_cache: