diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py
index ec99c0e7c433..70b8475bcad4 100644
--- a/src/transformers/feature_extraction_utils.py
+++ b/src/transformers/feature_extraction_utils.py
@@ -40,6 +40,7 @@
     is_tf_available,
     is_torch_available,
     is_torch_device,
+    is_torch_dtype,
     logging,
     torch_required,
 )
@@ -47,7 +48,7 @@
 
 if TYPE_CHECKING:
     if is_torch_available():
-        import torch
+        import torch  # noqa
 
 
 logger = logging.get_logger(__name__)
@@ -138,7 +139,7 @@ def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = Non
         elif tensor_type == TensorType.PYTORCH:
             if not is_torch_available():
                 raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
-            import torch
+            import torch  # noqa
 
             def as_tensor(value):
                 if isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], np.ndarray):
@@ -175,25 +176,47 @@ def as_tensor(value):
         return self
 
     @torch_required
-    # Copied from transformers.tokenization_utils_base.BatchEncoding.to with BatchEncoding->BatchFeature
-    def to(self, device: Union[str, "torch.device"]) -> "BatchFeature":
+    def to(self, *args, **kwargs) -> "BatchFeature":
         """
-        Send all values to device by calling `v.to(device)` (PyTorch only).
+        Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
+        different `dtypes` and sending the `BatchFeature` to a different `device`.
 
         Args:
-            device (`str` or `torch.device`): The device to put the tensors on.
+            args (`Tuple`):
+                Will be passed to the `to(...)` function of the tensors.
+            kwargs (`Dict`, *optional*):
+                Will be passed to the `to(...)` function of the tensors.
 
         Returns:
             [`BatchFeature`]: The same instance after modification.
         """
-
-        # This check catches things like APEX blindly calling "to" on all inputs to a module
-        # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
-        # into a HalfTensor
-        if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
-            self.data = {k: v.to(device=device) for k, v in self.data.items()}
-        else:
-            logger.warning(f"Attempting to cast a BatchFeature to type {str(device)}. This is not supported.")
+        import torch  # noqa
+
+        new_data = {}
+        device = kwargs.get("device")
+        # Check if the args are a device or a dtype
+        if device is None and len(args) > 0:
+            # device should be always the first argument
+            arg = args[0]
+            if is_torch_dtype(arg):
+                # The first argument is a dtype
+                pass
+            elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
+                device = arg
+            else:
+                # it's something else
+                raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
+        # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
+        for k, v in self.items():
+            # check if v is a floating point
+            if torch.is_floating_point(v):
+                # cast and send to device
+                new_data[k] = v.to(*args, **kwargs)
+            elif device is not None:
+                new_data[k] = v.to(device=device)
+            else:
+                new_data[k] = v
+        self.data = new_data
         return self
 
 
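Not part of the patch: a minimal usage sketch of the behaviour the reworked `BatchFeature.to` above is meant to provide, assuming the change is applied and `torch` is installed. Only floating-point tensors are cast, so integer entries such as `input_ids` keep their dtype, while a device argument still moves every tensor.

```python
# Illustrative sketch only -- assumes the patched `BatchFeature.to` and an installed torch.
import numpy as np
import torch

from transformers import BatchFeature

features = BatchFeature(
    {
        "pixel_values": np.zeros((2, 3, 224, 224), dtype=np.float32),
        "input_ids": np.array([[1, 2, 3], [4, 5, 6]]),
    },
    tensor_type="pt",
)

# Casting to a dtype only touches floating-point tensors ...
features = features.to(torch.float16)
assert features["pixel_values"].dtype == torch.float16
# ... so integer tensors keep their original dtype.
assert features["input_ids"].dtype == torch.long

# Passing a device (optionally followed by a dtype) moves every tensor.
features = features.to("cpu", torch.float32)
assert features["input_ids"].device == torch.device("cpu")
```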
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 2ca5bf96f715..53220c3fe541 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -47,6 +47,7 @@
     is_tensor,
     is_tf_tensor,
     is_torch_device,
+    is_torch_dtype,
     is_torch_tensor,
     reshape,
     squeeze,
diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py
index b2725b314825..b601d1e61b1f 100644
--- a/src/transformers/utils/generic.py
+++ b/src/transformers/utils/generic.py
@@ -123,6 +123,24 @@ def is_torch_device(x):
     return False if not is_torch_available() else _is_torch_device(x)
 
 
+def _is_torch_dtype(x):
+    import torch
+
+    if isinstance(x, str):
+        if hasattr(torch, x):
+            x = getattr(torch, x)
+        else:
+            return False
+    return isinstance(x, torch.dtype)
+
+
+def is_torch_dtype(x):
+    """
+    Tests if `x` is a torch dtype or not. Safe to call even if torch is not installed.
+    """
+    return False if not is_torch_available() else _is_torch_dtype(x)
+
+
 def _is_tensorflow(x):
     import tensorflow as tf
 
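Also not part of the patch: a quick sanity check of the new `is_torch_dtype` helper, based on the implementation above (it assumes the patch is applied so the helper is re-exported from `transformers.utils`). It accepts both `torch.dtype` objects and their string names, and returns `False` for anything else or when torch is missing.

```python
# Illustrative sketch only -- mirrors the `is_torch_dtype` helper added above.
import torch

from transformers.utils import is_torch_dtype

assert is_torch_dtype(torch.float16)  # a real dtype object
assert is_torch_dtype("bfloat16")  # strings are resolved via getattr(torch, ...)
assert not is_torch_dtype("not_a_dtype")  # unknown attribute name
assert not is_torch_dtype(torch.device("cpu"))  # a device is not a dtype
```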
diff --git a/tests/models/deit/test_feature_extraction_deit.py b/tests/models/deit/test_feature_extraction_deit.py
index 03b869a967fc..32b107756e7e 100644
--- a/tests/models/deit/test_feature_extraction_deit.py
+++ b/tests/models/deit/test_feature_extraction_deit.py
@@ -84,6 +84,7 @@ def prepare_feat_extract_dict(self):
 class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
 
     feature_extraction_class = DeiTFeatureExtractor if is_vision_available() else None
+    test_cast_dtype = True
 
     def setUp(self):
         self.feature_extract_tester = DeiTFeatureExtractionTester(self)
diff --git a/tests/test_feature_extraction_common.py b/tests/test_feature_extraction_common.py
index 7b7c33a9642c..fe8d02480644 100644
--- a/tests/test_feature_extraction_common.py
+++ b/tests/test_feature_extraction_common.py
@@ -25,7 +25,15 @@
 from huggingface_hub import HfFolder, delete_repo, set_access_token
 from requests.exceptions import HTTPError
 from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
-from transformers.testing_utils import TOKEN, USER, check_json_file_has_correct_format, get_tests_dir, is_staging_test
+from transformers.testing_utils import (
+    TOKEN,
+    USER,
+    check_json_file_has_correct_format,
+    get_tests_dir,
+    is_staging_test,
+    require_torch,
+    require_vision,
+)
 from transformers.utils import is_torch_available, is_vision_available
 
 
@@ -134,6 +142,8 @@ def prepare_video_inputs(feature_extract_tester, equal_resolution=False, numpify
 
 
 class FeatureExtractionSavingTestMixin:
+    test_cast_dtype = None
+
     def test_feat_extract_to_json_string(self):
         feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
         obj = json.loads(feat_extract.to_json_string())
@@ -164,6 +174,41 @@ def test_init_without_params(self):
         feat_extract = self.feature_extraction_class()
         self.assertIsNotNone(feat_extract)
 
+    @require_torch
+    @require_vision
+    def test_cast_dtype_device(self):
+        if self.test_cast_dtype is not None:
+            # Initialize feature_extractor
+            feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+
+            # create random PyTorch tensors
+            image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
+
+            encoding = feature_extractor(image_inputs, return_tensors="pt")
+            # for LayoutLM compatibility
+            self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
+            self.assertEqual(encoding.pixel_values.dtype, torch.float32)
+
+            encoding = feature_extractor(image_inputs, return_tensors="pt").to(torch.float16)
+            self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
+            self.assertEqual(encoding.pixel_values.dtype, torch.float16)
+
+            encoding = feature_extractor(image_inputs, return_tensors="pt").to("cpu", torch.bfloat16)
+            self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
+            self.assertEqual(encoding.pixel_values.dtype, torch.bfloat16)
+
+            with self.assertRaises(TypeError):
+                _ = feature_extractor(image_inputs, return_tensors="pt").to(torch.bfloat16, "cpu")
+
+            # Try with text + image feature
+            encoding = feature_extractor(image_inputs, return_tensors="pt")
+            encoding.update({"input_ids": torch.LongTensor([[1, 2, 3], [4, 5, 6]])})
+            encoding = encoding.to(torch.float16)
+
+            self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
+            self.assertEqual(encoding.pixel_values.dtype, torch.float16)
+            self.assertEqual(encoding.input_ids.dtype, torch.long)
+
 
 class FeatureExtractorUtilTester(unittest.TestCase):
     def test_cached_files_are_used_when_internet_is_down(self):
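A note on the `assertRaises(TypeError)` line in the new test: `torch.Tensor.to` expects the device before the dtype when both are passed positionally, which is why `.to(torch.bfloat16, "cpu")` is expected to fail. A standalone illustration (not part of the patch):

```python
# Illustrative sketch only -- shows the positional-argument order torch.Tensor.to expects.
import torch

x = torch.zeros(2, 2)
x = x.to("cpu", torch.bfloat16)  # fine: device first, then dtype

try:
    x.to(torch.bfloat16, "cpu")  # dtype first -> no matching overload, TypeError
except TypeError as exc:
    print(f"raised as expected: {exc}")
```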