From 1f1ee4685bd4ae22aa7d4636d345aac2ef8ed94c Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 1 Dec 2022 17:22:37 +0000 Subject: [PATCH 01/17] add v1 with tests --- src/transformers/feature_extraction_utils.py | 116 +++++++++++++++++-- src/transformers/utils/__init__.py | 1 + src/transformers/utils/generic.py | 13 +++ tests/test_feature_extraction_common.py | 44 ++++++- 4 files changed, 163 insertions(+), 11 deletions(-) diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index ec99c0e7c433..9c363e38944b 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -40,6 +40,7 @@ is_tf_available, is_torch_available, is_torch_device, + is_torch_dtype, logging, torch_required, ) @@ -175,25 +176,120 @@ def as_tensor(value): return self @torch_required - # Copied from transformers.tokenization_utils_base.BatchEncoding.to with BatchEncoding->BatchFeature - def to(self, device: Union[str, "torch.device"]) -> "BatchFeature": + def _check_if_dtype(self, dtype): + import torch + + if not is_torch_dtype(dtype): + if isinstance(dtype, str): + if hasattr(torch, dtype): + dtype = getattr(torch, dtype) + # needs a safety checker if you pass "cuda" + if is_torch_dtype(dtype): + return dtype + return None + else: + return dtype + + @torch_required + def _parse_to_args(self, *args, **kwargs) -> Tuple["torch.device", "torch.dtype"]: + r""" + Parse the arguments to return a tuple containing the corresponding `device` and/or `dtype` """ - Send all values to device by calling `v.to(device)` (PyTorch only). + import torch + + # Case 1: no arguments + if len(args) == 0 and kwargs is None: + raise ValueError("Use at least one argument when calling `.to`") + # Case 2: non-keyword arguments AND keyword arguments + elif len(args) >= 0 and len(set(kwargs)) > 0: + # Too many arguments + if len(args) + len(kwargs.keys()) > 2: + raise ValueError( + "Too many arguments for `.to` function. Supported arguments are `'device'` and `'dtype'` and you" + f" provided {args} and {kwargs}" + ) + + # Check correct kwargs + if len(args) == 0 and (not (("device" in set(kwargs)) or ("dtype" in set(kwargs)))): + raise ValueError( + "Please use correct keyword arguments for `.to`. Supported arguments are `'device'` and `'dtype'`" + ) + + # If an argument combined with a keyword argument `e.g. x.to(0, dtype=torch.float16)` + if len(args) == 1: + if "dtype" not in set(kwargs): + raise ValueError( + "Please pass `dtype=dtype` when calling `.to` together with a non-keyword argument for" + " `device`." + ) + device = args[0] + dtype = kwargs.pop("dtype", None) + # else, `e.g. x.to(device=0, dtype=torch.float16)` + else: + device = kwargs.pop("device", None) + dtype = kwargs.pop("dtype", None) + + if not is_torch_dtype(dtype): + if isinstance(dtype, str): + if hasattr(torch, dtype): + dtype = getattr(torch, dtype) + pass + raise ValueError(f"An unvalid `dtype` has been passed. {dtype} is not a supported `dtype`") + + elif len(args) >= 2: + raise ValueError( + "Too many non-keyword arguments provided to `.to` function. Please pass keywords arguments" + " `.to(device=device, dtype=dtype)` when using multiple arguments" + ) + elif len(args) == 1 and len(set(kwargs)) == 0: + # it either a `device` or a `dtype` + # First we check if it's a `dtype` + dtype = self._check_if_dtype(args[0]) + device = None + if dtype is None: + device = args[0] + else: + raise ValueError(f"Invalid arguments passed to `to` function. 
You have provided {args} and {kwargs}") + + return device, dtype + + @torch_required + def to(self, *args, **kwargs) -> "BatchFeature": + """ + Send all values to device by calling `v.to(device)` (PyTorch only). Or cast the values to the indicated dtype Args: device (`str` or `torch.device`): The device to put the tensors on. + dtype (`str` or `torch.dtype`): The dtype to cast the tensors on. Returns: [`BatchFeature`]: The same instance after modification. """ + import torch - # This check catches things like APEX blindly calling "to" on all inputs to a module - # Otherwise it passes the casts down and casts the LongTensor containing the token idxs - # into a HalfTensor - if isinstance(device, str) or is_torch_device(device) or isinstance(device, int): - self.data = {k: v.to(device=device) for k, v in self.data.items()} - else: - logger.warning(f"Attempting to cast a BatchFeature to type {str(device)}. This is not supported.") + device, dtype = self._parse_to_args(*args, **kwargs) + + if device is not None: + # This check catches things like APEX blindly calling "to" on all inputs to a module + # Otherwise it passes the casts down and casts the LongTensor containing the token idxs + # into a HalfTensor + if isinstance(device, str) or is_torch_device(device) or isinstance(device, int): + self.data = {k: v.to(device=device) for k, v in self.data.items()} + else: + logger.warning(f"Attempting to cast a BatchFeature to type {str(device)}. This is not supported.") + if dtype is not None: + casted_data = {} + if dtype not in [torch.float16, torch.bfloat16, torch.float, torch.double, torch.half]: + logger.warning( + f"Attempting to cast a BatchFeature to type {str(dtype)}. This is not supported and you will" + " encounter unexpected behavior." + ) + for k, v in self.data.items(): + if torch.is_floating_point(v): + casted_data[k] = v.to(dtype) + else: + casted_data[k] = v + self.data = casted_data return self diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 2ca5bf96f715..53220c3fe541 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -47,6 +47,7 @@ is_tensor, is_tf_tensor, is_torch_device, + is_torch_dtype, is_torch_tensor, reshape, squeeze, diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index b2725b314825..321e3ac29bc8 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -123,6 +123,19 @@ def is_torch_device(x): return False if not is_torch_available() else _is_torch_device(x) +def _is_torch_dtype(x): + import torch + + return isinstance(x, torch.dtype) + + +def is_torch_dtype(x): + """ + Tests if `x` is a torch device or not. Safe to call even if torch is not installed. 
+ """ + return False if not is_torch_available() else _is_torch_dtype(x) + + def _is_tensorflow(x): import tensorflow as tf diff --git a/tests/test_feature_extraction_common.py b/tests/test_feature_extraction_common.py index 7b7c33a9642c..bb84b64d791a 100644 --- a/tests/test_feature_extraction_common.py +++ b/tests/test_feature_extraction_common.py @@ -25,7 +25,15 @@ from huggingface_hub import HfFolder, delete_repo, set_access_token from requests.exceptions import HTTPError from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor -from transformers.testing_utils import TOKEN, USER, check_json_file_has_correct_format, get_tests_dir, is_staging_test +from transformers.testing_utils import ( + TOKEN, + USER, + check_json_file_has_correct_format, + get_tests_dir, + is_staging_test, + require_torch, + require_vision, +) from transformers.utils import is_torch_available, is_vision_available @@ -164,6 +172,40 @@ def test_init_without_params(self): feat_extract = self.feature_extraction_class() self.assertIsNotNone(feat_extract) + @require_torch + @require_vision + def test_cast_dtype_device(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random PyTorch tensors + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True) + + encoding = feature_extractor(image_inputs, return_tensors="pt") + self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) + self.assertEqual(encoding.pixel_values.dtype, torch.float32) + + encoding = feature_extractor(image_inputs, return_tensors="pt").to(device="cpu", dtype=torch.float16) + self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) + self.assertEqual(encoding.pixel_values.dtype, torch.float16) + + encoding = feature_extractor(image_inputs, return_tensors="pt").to(torch.float16) + self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) + self.assertEqual(encoding.pixel_values.dtype, torch.float16) + + encoding = feature_extractor(image_inputs, return_tensors="pt").to("cpu") + self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) + self.assertEqual(encoding.pixel_values.dtype, torch.float32) + + encoding = feature_extractor(image_inputs, return_tensors="pt").to("cpu", dtype=torch.bfloat16) + self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) + self.assertEqual(encoding.pixel_values.dtype, torch.bfloat16) + + with self.assertRaises(ValueError): + _ = feature_extractor(image_inputs, return_tensors="pt").to("cpu", torch.bfloat16) + + encoding = feature_extractor(image_inputs, return_tensors="pt").to("float16") + self.assertEqual(encoding.pixel_values.dtype, torch.float16) + class FeatureExtractorUtilTester(unittest.TestCase): def test_cached_files_are_used_when_internet_is_down(self): From a8053dd1a3ba578e6d57db2046fd69b7e6512ec6 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 1 Dec 2022 17:47:01 +0000 Subject: [PATCH 02/17] add checker --- tests/test_feature_extraction_common.py | 38 +++++++++++++------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/tests/test_feature_extraction_common.py b/tests/test_feature_extraction_common.py index bb84b64d791a..71772759a65d 100644 --- a/tests/test_feature_extraction_common.py +++ b/tests/test_feature_extraction_common.py @@ -181,30 +181,32 @@ def test_cast_dtype_device(self): image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True) encoding = 
feature_extractor(image_inputs, return_tensors="pt") - self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) - self.assertEqual(encoding.pixel_values.dtype, torch.float32) + # for layoutLM compatiblity + if not isinstance(encoding, list): + self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) + self.assertEqual(encoding.pixel_values.dtype, torch.float32) - encoding = feature_extractor(image_inputs, return_tensors="pt").to(device="cpu", dtype=torch.float16) - self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) - self.assertEqual(encoding.pixel_values.dtype, torch.float16) + encoding = feature_extractor(image_inputs, return_tensors="pt").to(device="cpu", dtype=torch.float16) + self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) + self.assertEqual(encoding.pixel_values.dtype, torch.float16) - encoding = feature_extractor(image_inputs, return_tensors="pt").to(torch.float16) - self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) - self.assertEqual(encoding.pixel_values.dtype, torch.float16) + encoding = feature_extractor(image_inputs, return_tensors="pt").to(torch.float16) + self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) + self.assertEqual(encoding.pixel_values.dtype, torch.float16) - encoding = feature_extractor(image_inputs, return_tensors="pt").to("cpu") - self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) - self.assertEqual(encoding.pixel_values.dtype, torch.float32) + encoding = feature_extractor(image_inputs, return_tensors="pt").to("cpu") + self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) + self.assertEqual(encoding.pixel_values.dtype, torch.float32) - encoding = feature_extractor(image_inputs, return_tensors="pt").to("cpu", dtype=torch.bfloat16) - self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) - self.assertEqual(encoding.pixel_values.dtype, torch.bfloat16) + encoding = feature_extractor(image_inputs, return_tensors="pt").to("cpu", dtype=torch.bfloat16) + self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) + self.assertEqual(encoding.pixel_values.dtype, torch.bfloat16) - with self.assertRaises(ValueError): - _ = feature_extractor(image_inputs, return_tensors="pt").to("cpu", torch.bfloat16) + with self.assertRaises(ValueError): + _ = feature_extractor(image_inputs, return_tensors="pt").to("cpu", torch.bfloat16) - encoding = feature_extractor(image_inputs, return_tensors="pt").to("float16") - self.assertEqual(encoding.pixel_values.dtype, torch.float16) + encoding = feature_extractor(image_inputs, return_tensors="pt").to("float16") + self.assertEqual(encoding.pixel_values.dtype, torch.float16) class FeatureExtractorUtilTester(unittest.TestCase): From 94c2a93e5c7347ac0f69b69944041c193917070f Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 1 Dec 2022 18:19:34 +0000 Subject: [PATCH 03/17] simplified version --- src/transformers/feature_extraction_utils.py | 132 ++++-------------- .../deit/test_feature_extraction_deit.py | 1 + tests/test_feature_extraction_common.py | 35 ++--- 3 files changed, 39 insertions(+), 129 deletions(-) diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index 9c363e38944b..376e7053d07d 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -48,7 +48,7 @@ if TYPE_CHECKING: if is_torch_available(): - import torch + import torch # noqa logger = logging.get_logger(__name__) @@ -139,7 +139,7 
@@ def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = Non elif tensor_type == TensorType.PYTORCH: if not is_torch_available(): raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.") - import torch + import torch # noqa def as_tensor(value): if isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], np.ndarray): @@ -176,120 +176,40 @@ def as_tensor(value): return self @torch_required - def _check_if_dtype(self, dtype): - import torch - - if not is_torch_dtype(dtype): - if isinstance(dtype, str): - if hasattr(torch, dtype): - dtype = getattr(torch, dtype) - # needs a safety checker if you pass "cuda" - if is_torch_dtype(dtype): - return dtype - return None - else: - return dtype - - @torch_required - def _parse_to_args(self, *args, **kwargs) -> Tuple["torch.device", "torch.dtype"]: - r""" - Parse the arguments to return a tuple containing the corresponding `device` and/or `dtype` - """ - import torch - - # Case 1: no arguments - if len(args) == 0 and kwargs is None: - raise ValueError("Use at least one argument when calling `.to`") - # Case 2: non-keyword arguments AND keyword arguments - elif len(args) >= 0 and len(set(kwargs)) > 0: - # Too many arguments - if len(args) + len(kwargs.keys()) > 2: - raise ValueError( - "Too many arguments for `.to` function. Supported arguments are `'device'` and `'dtype'` and you" - f" provided {args} and {kwargs}" - ) - - # Check correct kwargs - if len(args) == 0 and (not (("device" in set(kwargs)) or ("dtype" in set(kwargs)))): - raise ValueError( - "Please use correct keyword arguments for `.to`. Supported arguments are `'device'` and `'dtype'`" - ) - - # If an argument combined with a keyword argument `e.g. x.to(0, dtype=torch.float16)` - if len(args) == 1: - if "dtype" not in set(kwargs): - raise ValueError( - "Please pass `dtype=dtype` when calling `.to` together with a non-keyword argument for" - " `device`." - ) - device = args[0] - dtype = kwargs.pop("dtype", None) - # else, `e.g. x.to(device=0, dtype=torch.float16)` - else: - device = kwargs.pop("device", None) - dtype = kwargs.pop("dtype", None) - - if not is_torch_dtype(dtype): - if isinstance(dtype, str): - if hasattr(torch, dtype): - dtype = getattr(torch, dtype) - pass - raise ValueError(f"An unvalid `dtype` has been passed. {dtype} is not a supported `dtype`") - - elif len(args) >= 2: - raise ValueError( - "Too many non-keyword arguments provided to `.to` function. Please pass keywords arguments" - " `.to(device=device, dtype=dtype)` when using multiple arguments" - ) - elif len(args) == 1 and len(set(kwargs)) == 0: - # it either a `device` or a `dtype` - # First we check if it's a `dtype` - dtype = self._check_if_dtype(args[0]) - device = None - if dtype is None: - device = args[0] - else: - raise ValueError(f"Invalid arguments passed to `to` function. You have provided {args} and {kwargs}") - - return device, dtype - - @torch_required - def to(self, *args, **kwargs) -> "BatchFeature": + def to(self, target_dtype_or_device) -> "BatchFeature": """ Send all values to device by calling `v.to(device)` (PyTorch only). Or cast the values to the indicated dtype Args: - device (`str` or `torch.device`): The device to put the tensors on. - dtype (`str` or `torch.dtype`): The dtype to cast the tensors on. + device (`str` or `torch.device`, or `torch.dtype`): The device to put the tensors on. Returns: [`BatchFeature`]: The same instance after modification. 
""" - import torch - - device, dtype = self._parse_to_args(*args, **kwargs) - - if device is not None: - # This check catches things like APEX blindly calling "to" on all inputs to a module - # Otherwise it passes the casts down and casts the LongTensor containing the token idxs - # into a HalfTensor - if isinstance(device, str) or is_torch_device(device) or isinstance(device, int): - self.data = {k: v.to(device=device) for k, v in self.data.items()} - else: - logger.warning(f"Attempting to cast a BatchFeature to type {str(device)}. This is not supported.") - if dtype is not None: - casted_data = {} - if dtype not in [torch.float16, torch.bfloat16, torch.float, torch.double, torch.half]: - logger.warning( - f"Attempting to cast a BatchFeature to type {str(dtype)}. This is not supported and you will" - " encounter unexpected behavior." - ) + import torch # noqa + + # This check catches things like APEX blindly calling "to" on all inputs to a module + # Otherwise it passes the casts down and casts the LongTensor containing the token idxs + # into a HalfTensor + if ( + isinstance(target_dtype_or_device, str) + or is_torch_device(target_dtype_or_device) + or isinstance(target_dtype_or_device, int) + or is_torch_dtype(target_dtype_or_device) + ): + new_data = {} for k, v in self.data.items(): - if torch.is_floating_point(v): - casted_data[k] = v.to(dtype) + # for `dtype` cast it only on non-floating points + if is_torch_dtype(target_dtype_or_device) and not torch.is_floating_point(v): + new_data[k] = v else: - casted_data[k] = v - self.data = casted_data + new_data[k] = v.to(target_dtype_or_device) + self.data = new_data + else: + logger.warning( + f"Attempting to cast a BatchFeature to type {str(target_dtype_or_device)}. This is not supported." + ) + return self diff --git a/tests/models/deit/test_feature_extraction_deit.py b/tests/models/deit/test_feature_extraction_deit.py index 03b869a967fc..32b107756e7e 100644 --- a/tests/models/deit/test_feature_extraction_deit.py +++ b/tests/models/deit/test_feature_extraction_deit.py @@ -84,6 +84,7 @@ def prepare_feat_extract_dict(self): class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase): feature_extraction_class = DeiTFeatureExtractor if is_vision_available() else None + test_cast_dtype = True def setUp(self): self.feature_extract_tester = DeiTFeatureExtractionTester(self) diff --git a/tests/test_feature_extraction_common.py b/tests/test_feature_extraction_common.py index 71772759a65d..8c81805a3da5 100644 --- a/tests/test_feature_extraction_common.py +++ b/tests/test_feature_extraction_common.py @@ -142,6 +142,8 @@ def prepare_video_inputs(feature_extract_tester, equal_resolution=False, numpify class FeatureExtractionSavingTestMixin: + test_cast_dtype = None + def test_feat_extract_to_json_string(self): feat_extract = self.feature_extraction_class(**self.feat_extract_dict) obj = json.loads(feat_extract.to_json_string()) @@ -175,39 +177,26 @@ def test_init_without_params(self): @require_torch @require_vision def test_cast_dtype_device(self): - # Initialize feature_extractor - feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) - # create random PyTorch tensors - image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True) - - encoding = feature_extractor(image_inputs, return_tensors="pt") - # for layoutLM compatiblity - if not isinstance(encoding, list): - self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) - 
self.assertEqual(encoding.pixel_values.dtype, torch.float32) + if self.test_cast_dtype is not None: + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) - encoding = feature_extractor(image_inputs, return_tensors="pt").to(device="cpu", dtype=torch.float16) + # create random PyTorch tensors + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True) + + encoding = feature_extractor(image_inputs, return_tensors="pt") + # for layoutLM compatiblity self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) - self.assertEqual(encoding.pixel_values.dtype, torch.float16) + self.assertEqual(encoding.pixel_values.dtype, torch.float32) encoding = feature_extractor(image_inputs, return_tensors="pt").to(torch.float16) self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) self.assertEqual(encoding.pixel_values.dtype, torch.float16) - encoding = feature_extractor(image_inputs, return_tensors="pt").to("cpu") - self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) - self.assertEqual(encoding.pixel_values.dtype, torch.float32) - - encoding = feature_extractor(image_inputs, return_tensors="pt").to("cpu", dtype=torch.bfloat16) + encoding = feature_extractor(image_inputs, return_tensors="pt").to(torch.bfloat16) self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) self.assertEqual(encoding.pixel_values.dtype, torch.bfloat16) - with self.assertRaises(ValueError): - _ = feature_extractor(image_inputs, return_tensors="pt").to("cpu", torch.bfloat16) - - encoding = feature_extractor(image_inputs, return_tensors="pt").to("float16") - self.assertEqual(encoding.pixel_values.dtype, torch.float16) - class FeatureExtractorUtilTester(unittest.TestCase): def test_cached_files_are_used_when_internet_is_down(self): From f9045b115e3932db8effdf42dc0eda780527ba5a Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 1 Dec 2022 18:20:34 +0000 Subject: [PATCH 04/17] update docstring --- src/transformers/feature_extraction_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index 376e7053d07d..9eda187ce6f5 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -181,7 +181,7 @@ def to(self, target_dtype_or_device) -> "BatchFeature": Send all values to device by calling `v.to(device)` (PyTorch only). Or cast the values to the indicated dtype Args: - device (`str` or `torch.device`, or `torch.dtype`): The device to put the tensors on. + target_dtype_or_device (`str` or `torch.device`, or `torch.dtype`): The device to put the tensors on. Returns: [`BatchFeature`]: The same instance after modification. 
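After patches 03 and 04, `BatchFeature.to` takes a single `target_dtype_or_device` argument: a dtype casts only the floating-point tensors, while a string, `torch.device`, or int moves every tensor to that device. A minimal usage sketch of this intermediate API follows; the checkpoint name and the CUDA device are illustrative assumptions, not part of the patch.

```python
import torch
from PIL import Image
from transformers import AutoFeatureExtractor

# Any image feature extractor works the same way; this checkpoint is only an example.
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
image = Image.new("RGB", (224, 224))

batch = feature_extractor(image, return_tensors="pt")

# Single positional argument: a dtype casts only the floating-point tensors ...
batch = batch.to(torch.float16)
print(batch.pixel_values.dtype)   # torch.float16

# ... and a device string moves every tensor (`to` returns self, so calls can be chained).
batch = batch.to("cuda:0")        # assumes a GPU is available
print(batch.pixel_values.device)  # cuda:0
```

Combining a device and a dtype in one call is not possible with this single-argument form; restoring the combined form (for example `.to("cpu", torch.float16)`) is what patches 05 onward add back.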
From 4af5971355bec2e43389c3818876f498cef0f247 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 2 Dec 2022 08:53:29 +0100 Subject: [PATCH 05/17] better version --- src/transformers/feature_extraction_utils.py | 55 ++++++++++++-------- src/transformers/utils/generic.py | 5 ++ 2 files changed, 37 insertions(+), 23 deletions(-) diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index 9eda187ce6f5..6bb94b8b73cc 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -176,7 +176,7 @@ def as_tensor(value): return self @torch_required - def to(self, target_dtype_or_device) -> "BatchFeature": + def to(self, *args, **kwargs) -> "BatchFeature": """ Send all values to device by calling `v.to(device)` (PyTorch only). Or cast the values to the indicated dtype @@ -188,28 +188,37 @@ def to(self, target_dtype_or_device) -> "BatchFeature": """ import torch # noqa - # This check catches things like APEX blindly calling "to" on all inputs to a module - # Otherwise it passes the casts down and casts the LongTensor containing the token idxs - # into a HalfTensor - if ( - isinstance(target_dtype_or_device, str) - or is_torch_device(target_dtype_or_device) - or isinstance(target_dtype_or_device, int) - or is_torch_dtype(target_dtype_or_device) - ): - new_data = {} - for k, v in self.data.items(): - # for `dtype` cast it only on non-floating points - if is_torch_dtype(target_dtype_or_device) and not torch.is_floating_point(v): - new_data[k] = v - else: - new_data[k] = v.to(target_dtype_or_device) - self.data = new_data - else: - logger.warning( - f"Attempting to cast a BatchFeature to type {str(target_dtype_or_device)}. This is not supported." - ) - + new_data = {} + # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor` + for k, v in self.items(): + # check if v is a floating point + if torch.is_floating_point(v): + # cast and send to device + new_data[k] = v.to(*args, **kwargs) + else: + # just send to device + device = kwargs.pop("device", None) + # Check if the args are a device or a dtype + if device is None: + for arg in args: + if isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int): + device = arg + break + elif is_torch_dtype(arg): + # Ignore the dtype + logger.warning( + "Attempting to cast a non-floating point element of BatchFeature to `dtype`" + f" {str(arg)}. This is not supported." + ) + else: + # it's something else + logger.warning( + f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported." 
+ ) + # Finally send to device + if device is not None: + new_data[k] = v.to(device=device) + self.data = new_data return self diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 321e3ac29bc8..345c87477d40 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -126,6 +126,11 @@ def is_torch_device(x): def _is_torch_dtype(x): import torch + if isinstance(x, str): + if hasattr(torch, x): + return True + else: + return False return isinstance(x, torch.dtype) From e595cdf8be3eece17bc63649df210c3e8250ce65 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 2 Dec 2022 08:56:43 +0100 Subject: [PATCH 06/17] fix docstring + change order --- src/transformers/feature_extraction_utils.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index 6bb94b8b73cc..fc5b0d8c6fdd 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -178,10 +178,15 @@ def as_tensor(value): @torch_required def to(self, *args, **kwargs) -> "BatchFeature": """ - Send all values to device by calling `v.to(device)` (PyTorch only). Or cast the values to the indicated dtype + Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This + should support casting in different `dtypes` and sending the `BatchFeature` to a + different `device`. Args: - target_dtype_or_device (`str` or `torch.device`, or `torch.dtype`): The device to put the tensors on. + args (`Tuple`): + Will be passed to the `to(...)` function of the tensors. + kwargs (`Dict`, *optional*): + Will be passed to the `to(...)` function of the tensors. Returns: [`BatchFeature`]: The same instance after modification. @@ -201,15 +206,15 @@ def to(self, *args, **kwargs) -> "BatchFeature": # Check if the args are a device or a dtype if device is None: for arg in args: - if isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int): - device = arg - break - elif is_torch_dtype(arg): + if is_torch_dtype(arg): # Ignore the dtype logger.warning( "Attempting to cast a non-floating point element of BatchFeature to `dtype`" f" {str(arg)}. This is not supported." ) + elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int): + device = arg + break else: # it's something else logger.warning( From 4d05a90b2c3df36649ccbc6ed35a777f5e783d41 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 2 Dec 2022 09:02:52 +0100 Subject: [PATCH 07/17] make style --- src/transformers/feature_extraction_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index fc5b0d8c6fdd..0bf8199337d6 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -178,9 +178,8 @@ def as_tensor(value): @torch_required def to(self, *args, **kwargs) -> "BatchFeature": """ - Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This - should support casting in different `dtypes` and sending the `BatchFeature` to a - different `device`. + Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in + different `dtypes` and sending the `BatchFeature` to a different `device`. 
Args: args (`Tuple`): From 4b9fddd6e93c9265463ec059ef00a53ece4eaa57 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 2 Dec 2022 09:30:26 +0100 Subject: [PATCH 08/17] tests + change conditions --- src/transformers/feature_extraction_utils.py | 4 +++- tests/test_feature_extraction_common.py | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index 0bf8199337d6..40f4fa7fd4f8 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -204,7 +204,9 @@ def to(self, *args, **kwargs) -> "BatchFeature": device = kwargs.pop("device", None) # Check if the args are a device or a dtype if device is None: - for arg in args: + if len(args) > 0: + # device should be always the first argument + arg = args[0] if is_torch_dtype(arg): # Ignore the dtype logger.warning( diff --git a/tests/test_feature_extraction_common.py b/tests/test_feature_extraction_common.py index 8c81805a3da5..e5542f06deb2 100644 --- a/tests/test_feature_extraction_common.py +++ b/tests/test_feature_extraction_common.py @@ -193,10 +193,13 @@ def test_cast_dtype_device(self): self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) self.assertEqual(encoding.pixel_values.dtype, torch.float16) - encoding = feature_extractor(image_inputs, return_tensors="pt").to(torch.bfloat16) + encoding = feature_extractor(image_inputs, return_tensors="pt").to("cpu", torch.bfloat16) self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) self.assertEqual(encoding.pixel_values.dtype, torch.bfloat16) + with self.assertRaises(TypeError): + _ = feature_extractor(image_inputs, return_tensors="pt").to(torch.bfloat16, "cpu") + class FeatureExtractorUtilTester(unittest.TestCase): def test_cached_files_are_used_when_internet_is_down(self): From 09e4c494766c64ca28efe0199421c111bab4ff31 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 2 Dec 2022 11:30:10 +0100 Subject: [PATCH 09/17] final tests --- src/transformers/feature_extraction_utils.py | 3 ++- tests/test_feature_extraction_common.py | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index 40f4fa7fd4f8..cb954f4caacd 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -215,7 +215,6 @@ def to(self, *args, **kwargs) -> "BatchFeature": ) elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int): device = arg - break else: # it's something else logger.warning( @@ -224,6 +223,8 @@ def to(self, *args, **kwargs) -> "BatchFeature": # Finally send to device if device is not None: new_data[k] = v.to(device=device) + else: + new_data[k] = v self.data = new_data return self diff --git a/tests/test_feature_extraction_common.py b/tests/test_feature_extraction_common.py index e5542f06deb2..fe8d02480644 100644 --- a/tests/test_feature_extraction_common.py +++ b/tests/test_feature_extraction_common.py @@ -200,6 +200,15 @@ def test_cast_dtype_device(self): with self.assertRaises(TypeError): _ = feature_extractor(image_inputs, return_tensors="pt").to(torch.bfloat16, "cpu") + # Try with text + image feature + encoding = feature_extractor(image_inputs, return_tensors="pt") + encoding.update({"input_ids": torch.LongTensor([[1, 2, 3], [4, 5, 6]])}) + encoding = encoding.to(torch.float16) + + self.assertEqual(encoding.pixel_values.device, 
torch.device("cpu")) + self.assertEqual(encoding.pixel_values.dtype, torch.float16) + self.assertEqual(encoding.input_ids.dtype, torch.long) + class FeatureExtractorUtilTester(unittest.TestCase): def test_cached_files_are_used_when_internet_is_down(self): From ba9441879c16f882dc55320773d86c30295f438f Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 2 Dec 2022 17:50:08 +0100 Subject: [PATCH 10/17] modify docstring --- src/transformers/utils/generic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 345c87477d40..14d8cda204fa 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -136,7 +136,7 @@ def _is_torch_dtype(x): def is_torch_dtype(x): """ - Tests if `x` is a torch device or not. Safe to call even if torch is not installed. + Tests if `x` is a torch dtype or not. Safe to call even if torch is not installed. """ return False if not is_torch_available() else _is_torch_dtype(x) From f331dac31108be0b041749ebc8515ecf5b75f929 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 2 Dec 2022 23:46:09 +0100 Subject: [PATCH 11/17] Update src/transformers/feature_extraction_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/feature_extraction_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index cb954f4caacd..194e97f573af 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -203,8 +203,7 @@ def to(self, *args, **kwargs) -> "BatchFeature": # just send to device device = kwargs.pop("device", None) # Check if the args are a device or a dtype - if device is None: - if len(args) > 0: + if device is None and len(args) > 0: # device should be always the first argument arg = args[0] if is_torch_dtype(arg): From af0a4c3cb1d4070fdea3c2a797f45390c586e8e5 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 2 Dec 2022 23:49:08 +0100 Subject: [PATCH 12/17] replace by `ValueError` --- src/transformers/feature_extraction_utils.py | 30 ++++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index 194e97f573af..676943632147 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -204,21 +204,21 @@ def to(self, *args, **kwargs) -> "BatchFeature": device = kwargs.pop("device", None) # Check if the args are a device or a dtype if device is None and len(args) > 0: - # device should be always the first argument - arg = args[0] - if is_torch_dtype(arg): - # Ignore the dtype - logger.warning( - "Attempting to cast a non-floating point element of BatchFeature to `dtype`" - f" {str(arg)}. This is not supported." - ) - elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int): - device = arg - else: - # it's something else - logger.warning( - f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported." - ) + # device should be always the first argument + arg = args[0] + if is_torch_dtype(arg): + # Ignore the dtype + logger.warning( + "Attempting to cast a non-floating point element of BatchFeature to `dtype`" + f" {str(arg)}. This is not supported." 
+ ) + elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int): + device = arg + else: + # it's something else + raise ValueError( + f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported." + ) # Finally send to device if device is not None: new_data[k] = v.to(device=device) From da35051ed08070eeff691891b5eb7232a1c12dc2 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 2 Dec 2022 23:51:50 +0100 Subject: [PATCH 13/17] fix logic --- src/transformers/utils/generic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 14d8cda204fa..b601d1e61b1f 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -128,7 +128,7 @@ def _is_torch_dtype(x): if isinstance(x, str): if hasattr(torch, x): - return True + x = getattr(torch, x) else: return False return isinstance(x, torch.dtype) From af0417936b8472f2a2114dc61e1fec06573147a2 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 5 Dec 2022 16:38:16 +0100 Subject: [PATCH 14/17] apply suggestions --- src/transformers/feature_extraction_utils.py | 40 ++++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index 676943632147..01fbeee2275a 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -193,6 +193,7 @@ def to(self, *args, **kwargs) -> "BatchFeature": import torch # noqa new_data = {} + device, _ = self._parse_to_args(*args, **kwargs) # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor` for k, v in self.items(): # check if v is a floating point @@ -200,26 +201,7 @@ def to(self, *args, **kwargs) -> "BatchFeature": # cast and send to device new_data[k] = v.to(*args, **kwargs) else: - # just send to device - device = kwargs.pop("device", None) - # Check if the args are a device or a dtype - if device is None and len(args) > 0: - # device should be always the first argument - arg = args[0] - if is_torch_dtype(arg): - # Ignore the dtype - logger.warning( - "Attempting to cast a non-floating point element of BatchFeature to `dtype`" - f" {str(arg)}. This is not supported." - ) - elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int): - device = arg - else: - # it's something else - raise ValueError( - f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported." - ) - # Finally send to device + # Just send to device for int tensors if device is not None: new_data[k] = v.to(device=device) else: @@ -227,6 +209,24 @@ def to(self, *args, **kwargs) -> "BatchFeature": self.data = new_data return self + def _parse_to_args(self, *args, **kwargs): + # just send to device + device = kwargs.get("device") + dtype = kwargs.get("dtype") + # Check if the args are a device or a dtype + if device is None and len(args) > 0: + # device should be always the first argument + arg = args[0] + if is_torch_dtype(arg): + # Assign the dtype + dtype = arg + elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int): + device = arg + else: + # it's something else + raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. 
This is not supported.") + return device, dtype + class FeatureExtractionMixin(PushToHubMixin): """ From fee0585ae12662743de52aeafcb608293c7307ad Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 5 Dec 2022 16:40:16 +0100 Subject: [PATCH 15/17] `dtype` is not needed --- src/transformers/feature_extraction_utils.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index 01fbeee2275a..239e89a7da82 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -193,7 +193,7 @@ def to(self, *args, **kwargs) -> "BatchFeature": import torch # noqa new_data = {} - device, _ = self._parse_to_args(*args, **kwargs) + device = self._parse_to_args(*args, **kwargs) # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor` for k, v in self.items(): # check if v is a floating point @@ -212,20 +212,19 @@ def to(self, *args, **kwargs) -> "BatchFeature": def _parse_to_args(self, *args, **kwargs): # just send to device device = kwargs.get("device") - dtype = kwargs.get("dtype") # Check if the args are a device or a dtype if device is None and len(args) > 0: # device should be always the first argument arg = args[0] if is_torch_dtype(arg): - # Assign the dtype - dtype = arg + # The first argument is a dtype + pass elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int): device = arg else: # it's something else raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.") - return device, dtype + return device class FeatureExtractionMixin(PushToHubMixin): From 475f5ef3898d298fe8fc255556ff2a6d8fb62ebe Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 5 Dec 2022 17:34:23 +0100 Subject: [PATCH 16/17] adapt suggestions --- src/transformers/feature_extraction_utils.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index 239e89a7da82..b348c447f131 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -193,24 +193,22 @@ def to(self, *args, **kwargs) -> "BatchFeature": import torch # noqa new_data = {} - device = self._parse_to_args(*args, **kwargs) + device = self._parse_args_to_device(*args, **kwargs) # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor` for k, v in self.items(): # check if v is a floating point if torch.is_floating_point(v): # cast and send to device new_data[k] = v.to(*args, **kwargs) + elif device is not None: + new_data[k] = v.to(device=device) else: - # Just send to device for int tensors - if device is not None: - new_data[k] = v.to(device=device) - else: - new_data[k] = v + new_data[k] = v self.data = new_data return self - def _parse_to_args(self, *args, **kwargs): - # just send to device + def _parse_args_to_device(self, *args, **kwargs): + # Retrieve device from kwargs device = kwargs.get("device") # Check if the args are a device or a dtype if device is None and len(args) > 0: From c5b4b378958e0c0c3e9165886c864f4456a03e72 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 5 Dec 2022 18:02:05 +0100 Subject: [PATCH 17/17] remove `_parse_args_to_device` --- src/transformers/feature_extraction_utils.py | 29 ++++++++------------ 1 file changed, 12 insertions(+), 17 deletions(-) 
diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index b348c447f131..70b8475bcad4 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -193,22 +193,6 @@ def to(self, *args, **kwargs) -> "BatchFeature": import torch # noqa new_data = {} - device = self._parse_args_to_device(*args, **kwargs) - # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor` - for k, v in self.items(): - # check if v is a floating point - if torch.is_floating_point(v): - # cast and send to device - new_data[k] = v.to(*args, **kwargs) - elif device is not None: - new_data[k] = v.to(device=device) - else: - new_data[k] = v - self.data = new_data - return self - - def _parse_args_to_device(self, *args, **kwargs): - # Retrieve device from kwargs device = kwargs.get("device") # Check if the args are a device or a dtype if device is None and len(args) > 0: @@ -222,7 +206,18 @@ def _parse_args_to_device(self, *args, **kwargs): else: # it's something else raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.") - return device + # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor` + for k, v in self.items(): + # check if v is a floating point + if torch.is_floating_point(v): + # cast and send to device + new_data[k] = v.to(*args, **kwargs) + elif device is not None: + new_data[k] = v.to(device=device) + else: + new_data[k] = v + self.data = new_data + return self class FeatureExtractionMixin(PushToHubMixin):
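The end state of the series: `BatchFeature.to(*args, **kwargs)` forwards its arguments to `Tensor.to` for floating-point tensors, only moves integer tensors to the parsed device, and raises a `ValueError` when the first positional argument is neither a dtype, a string, a `torch.device`, nor an int. A short sketch of the resulting behaviour with the final patch applied; the tensor shapes and values below are made up for illustration.

```python
import torch
from transformers import BatchFeature

# A mixed batch: float pixel values plus integer token ids, as produced by a
# combined image + text processor.
batch = BatchFeature(
    {
        "pixel_values": torch.rand(2, 3, 224, 224),
        "input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]),
    }
)

# Device and dtype in one call: floating-point tensors get `v.to("cpu", torch.float16)`,
# integer tensors only get `v.to(device="cpu")`, so the token ids stay `torch.long`.
batch = batch.to("cpu", torch.float16)
print(batch.pixel_values.dtype)  # torch.float16
print(batch.input_ids.dtype)     # torch.int64, unchanged

# A dtype alone is also accepted; integer tensors are left untouched.
batch = batch.to(torch.bfloat16)

# The dtype must come after the device, matching `Tensor.to`; the reversed order raises
# (a TypeError from `Tensor.to`, as exercised in the test added in patch 08):
# batch.to(torch.bfloat16, "cpu")
```

The `torch.is_floating_point` gate is the same guard `BatchEncoding.to` needed for APEX-style blanket `.to` calls: casting `LongTensor` token ids down to a half-precision dtype would corrupt them, so only floating-point tensors receive the full `*args, **kwargs`.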