51 changes: 37 additions & 14 deletions src/transformers/feature_extraction_utils.py
@@ -40,14 +40,15 @@
     is_tf_available,
     is_torch_available,
     is_torch_device,
+    is_torch_dtype,
     logging,
     torch_required,
 )


 if TYPE_CHECKING:
     if is_torch_available():
-        import torch
+        import torch  # noqa


 logger = logging.get_logger(__name__)
@@ -138,7 +139,7 @@ def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
         elif tensor_type == TensorType.PYTORCH:
             if not is_torch_available():
                 raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
-            import torch
+            import torch  # noqa

             def as_tensor(value):
                 if isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], np.ndarray):
@@ -175,25 +176,47 @@ def as_tensor(value):
         return self

     @torch_required
-    # Copied from transformers.tokenization_utils_base.BatchEncoding.to with BatchEncoding->BatchFeature
-    def to(self, device: Union[str, "torch.device"]) -> "BatchFeature":
+    def to(self, *args, **kwargs) -> "BatchFeature":
         """
-        Send all values to device by calling `v.to(device)` (PyTorch only).
+        Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
+        different `dtypes` and sending the `BatchFeature` to a different `device`.

         Args:
-            device (`str` or `torch.device`): The device to put the tensors on.
+            args (`Tuple`):
+                Will be passed to the `to(...)` function of the tensors.
+            kwargs (`Dict`, *optional*):
+                Will be passed to the `to(...)` function of the tensors.

         Returns:
             [`BatchFeature`]: The same instance after modification.
         """
-
-        # This check catches things like APEX blindly calling "to" on all inputs to a module
-        # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
-        # into a HalfTensor
-        if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
-            self.data = {k: v.to(device=device) for k, v in self.data.items()}
-        else:
-            logger.warning(f"Attempting to cast a BatchFeature to type {str(device)}. This is not supported.")
+        import torch  # noqa
+
+        new_data = {}
+        device = kwargs.get("device")
+        # Check if the args are a device or a dtype
+        if device is None and len(args) > 0:
+            # device should be always the first argument
+            arg = args[0]
+            if is_torch_dtype(arg):
+                # The first argument is a dtype
+                pass
+            elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
+                device = arg
+            else:
+                # it's something else
+                raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
+        # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
+        for k, v in self.items():
+            # check if v is a floating point
+            if torch.is_floating_point(v):
+                # cast and send to device
+                new_data[k] = v.to(*args, **kwargs)
+            elif device is not None:
+                new_data[k] = v.to(device=device)
+            else:
+                new_data[k] = v
+        self.data = new_data
         return self


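A minimal usage sketch of the new `to(...)` semantics (illustrative, not part of the diff; the checkpoint name is just an example, and the last call assumes a CUDA device is available):

```python
import torch
from PIL import Image
from transformers import AutoFeatureExtractor

# example checkpoint; any image feature extractor behaves the same way
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/deit-base-distilled-patch16-224")
image = Image.new("RGB", (224, 224))

batch = feature_extractor(image, return_tensors="pt")

# casts only floating-point tensors; integer tensors (e.g. input_ids) are left untouched
batch = batch.to(torch.float16)

# move to a device, optionally casting at the same time
# (device must come before dtype, matching torch.Tensor.to)
batch = batch.to("cuda:0", torch.float16)
```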
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -47,6 +47,7 @@
     is_tensor,
     is_tf_tensor,
     is_torch_device,
+    is_torch_dtype,
     is_torch_tensor,
     reshape,
     squeeze,
18 changes: 18 additions & 0 deletions src/transformers/utils/generic.py
@@ -123,6 +123,24 @@ def is_torch_device(x):
     return False if not is_torch_available() else _is_torch_device(x)


+def _is_torch_dtype(x):
+    import torch
+
+    if isinstance(x, str):
+        if hasattr(torch, x):
+            x = getattr(torch, x)
+        else:
+            return False
+    return isinstance(x, torch.dtype)
+
+
+def is_torch_dtype(x):
+    """
+    Tests if `x` is a torch dtype or not. Safe to call even if torch is not installed.
+    """
+    return False if not is_torch_available() else _is_torch_dtype(x)
+
+
 def _is_tensorflow(x):
     import tensorflow as tf
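A quick sketch of how the new helper behaves (assuming `torch` is installed; these calls are illustrative, not from the PR):

```python
import torch

from transformers.utils import is_torch_dtype

print(is_torch_dtype(torch.float16))        # True
print(is_torch_dtype("bfloat16"))           # True: strings resolve via getattr(torch, "bfloat16")
print(is_torch_dtype("not_a_dtype"))        # False: no such torch attribute
print(is_torch_dtype(torch.device("cpu")))  # False: a device is not a dtype
```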
1 change: 1 addition & 0 deletions tests/models/deit/test_feature_extraction_deit.py
@@ -84,6 +84,7 @@ def prepare_feat_extract_dict(self):
 class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):

     feature_extraction_class = DeiTFeatureExtractor if is_vision_available() else None
+    test_cast_dtype = True

Review comment (Contributor):
Why is it needed here for DeiT and not other models?

Reply from @younesbelkada (Contributor, Author), Dec 2, 2022:
I just waited for everyone's approval on the concept and planned to add it to more models! Probably the main models only, as it may slow down the CI test suite.

     def setUp(self):
         self.feature_extract_tester = DeiTFeatureExtractionTester(self)
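As a sketch of the rollout mentioned in the reply above (hypothetical, not part of this PR), opting another model's test class into the new check only requires setting the mixin flag, e.g. for ViT:

```python
# hypothetical follow-up, not in this diff
import unittest

from transformers import ViTFeatureExtractor
from transformers.utils import is_vision_available

from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin


class ViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
    # the usual tester/setUp boilerplate is omitted here
    feature_extraction_class = ViTFeatureExtractor if is_vision_available() else None
    test_cast_dtype = True  # opt in to FeatureExtractionSavingTestMixin.test_cast_dtype_device
```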
47 changes: 46 additions & 1 deletion tests/test_feature_extraction_common.py
@@ -25,7 +25,15 @@
 from huggingface_hub import HfFolder, delete_repo, set_access_token
 from requests.exceptions import HTTPError
 from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
-from transformers.testing_utils import TOKEN, USER, check_json_file_has_correct_format, get_tests_dir, is_staging_test
+from transformers.testing_utils import (
+    TOKEN,
+    USER,
+    check_json_file_has_correct_format,
+    get_tests_dir,
+    is_staging_test,
+    require_torch,
+    require_vision,
+)
 from transformers.utils import is_torch_available, is_vision_available


@@ -134,6 +142,8 @@ def prepare_video_inputs(feature_extract_tester, equal_resolution=False, numpify


 class FeatureExtractionSavingTestMixin:
+    test_cast_dtype = None
+
     def test_feat_extract_to_json_string(self):
         feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
         obj = json.loads(feat_extract.to_json_string())
@@ -164,6 +174,41 @@ def test_init_without_params(self):
         feat_extract = self.feature_extraction_class()
         self.assertIsNotNone(feat_extract)

+    @require_torch
+    @require_vision
+    def test_cast_dtype_device(self):
+        if self.test_cast_dtype is not None:
+            # Initialize feature_extractor
+            feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+
+            # create random PyTorch tensors
+            image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
+
+            encoding = feature_extractor(image_inputs, return_tensors="pt")
+            # for LayoutLM compatibility
+            self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
+            self.assertEqual(encoding.pixel_values.dtype, torch.float32)
+
+            encoding = feature_extractor(image_inputs, return_tensors="pt").to(torch.float16)
+            self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
+            self.assertEqual(encoding.pixel_values.dtype, torch.float16)
+
+            encoding = feature_extractor(image_inputs, return_tensors="pt").to("cpu", torch.bfloat16)
+            self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
+            self.assertEqual(encoding.pixel_values.dtype, torch.bfloat16)
+
+            with self.assertRaises(TypeError):
+                _ = feature_extractor(image_inputs, return_tensors="pt").to(torch.bfloat16, "cpu")
+
+            # Try with text + image feature
+            encoding = feature_extractor(image_inputs, return_tensors="pt")
+            encoding.update({"input_ids": torch.LongTensor([[1, 2, 3], [4, 5, 6]])})
+            encoding = encoding.to(torch.float16)
+
+            self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
+            self.assertEqual(encoding.pixel_values.dtype, torch.float16)
+            self.assertEqual(encoding.input_ids.dtype, torch.long)


 class FeatureExtractorUtilTester(unittest.TestCase):
     def test_cached_files_are_used_when_internet_is_down(self):
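For readers wondering why the `assertRaises(TypeError)` case above is expected: `torch.Tensor.to` accepts `(device, dtype)` positionally but not `(dtype, device)`, so PyTorch itself rejects the reversed order before the `BatchFeature` logic is ever involved. A standalone check (illustrative, not part of the PR):

```python
import torch

t = torch.zeros(2, 2)
t.to("cpu", torch.bfloat16)      # OK: device first, then dtype

try:
    t.to(torch.bfloat16, "cpu")  # TypeError: dtype cannot precede device positionally
except TypeError as err:
    print(f"raised as expected: {err}")
```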