103 changes: 102 additions & 1 deletion src/transformers/feature_extraction_utils.py
@@ -70,9 +70,15 @@ class BatchFeature(UserDict):
initialization.
"""

def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
def __init__(
self,
data: Optional[Dict[str, Any]] = None,
tensor_type: Union[None, str, TensorType] = None,
float_precision: Optional[str] = None,
):
super().__init__(data)
self.convert_to_tensors(tensor_type=tensor_type)
self.cast_to_dtype(tensor_type=tensor_type, float_precision=float_precision)

def __getitem__(self, item: str) -> Union[Any]:
"""
@@ -109,6 +115,101 @@ def values(self):
def items(self):
return self.data.items()

def cast_to_dtype(
self, tensor_type: Optional[Union[str, TensorType]] = None, float_precision: Optional[str] = None
):
"""
Optionally cast the input tensors (floating-point tensors only) to the desired precision.

Args:
tensor_type (`str` or [`~utils.TensorType`], *optional*):
The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
`None`, no modification is done.
float_precision (`str`, *optional*):
The output floating-point precision. One of `"float16"`, `"float32"`, `"double"` or `"bfloat16"`. If `None`, no cast is performed.
"""
if (float_precision is None) or (tensor_type is None):
return self

# Convert to TensorType
tensor_type = TensorType(tensor_type)

# Get a function reference for the correct framework
if tensor_type == TensorType.TENSORFLOW:
if not is_tf_available():
raise ImportError(
"Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
)
import tensorflow as tf

target_framework = tf
cast_fun = tf.cast

def is_floating(x):
return x.dtype in (tf.float16, tf.float32, tf.double, tf.bfloat16)

is_tensor = tf.is_tensor

elif tensor_type == TensorType.PYTORCH:
if not is_torch_available():
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
import torch

target_framework = torch

def cast_fun(x, dtype):
return x.to(dtype=dtype)

def is_floating(x):
return x.dtype in (torch.float16, torch.float32, torch.double, torch.bfloat16)

is_tensor = torch.is_tensor

# Jax tensors
elif tensor_type == TensorType.JAX:
if not is_flax_available():
raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
import jax.numpy as jnp # noqa: F811

target_framework = jnp

def cast_fun(x, dtype):
return x.astype(dtype=dtype)

def is_floating(x):
return x.dtype in (jnp.float16, jnp.float32, jnp.double, jnp.bfloat16)

is_tensor = is_jax_tensor
# np arrays
else:
target_framework = np

def cast_fun(x, dtype):
return x.astype(dtype=dtype)

def is_floating(x):
return x.dtype in (np.half, np.single, np.double, np.longdouble)

is_tensor = is_numpy_array

if hasattr(target_framework, float_precision):
target_dtype = getattr(target_framework, float_precision)
else:
raise ValueError(
f"Failed to import the `dtype` {target_dtype} from the framework {target_framework} - please use a"
" supported `dtype` for your targeted framework.",
)

# Cast the floating-point tensors in the batch to the target dtype
for key, value in self.items():
# sanity check: only cast values that are actually tensors
if is_tensor(value):
if is_floating(value):
Review thread on the lines above:

Collaborator:
I am not very comfortable calling these `is_tensor` and `is_floating` without checking that `value` actually comes from `target_framework`. For example, what about `tensor_type` being `pt` while `value` is a TF tensor?

Contributor (author):
Thanks for the feedback! I think this cannot happen, since that check is already done in `convert_to_tensors`, which is called right before.

Collaborator:
I agree with you @younesbelkada. However, this method is added as a public method, so the concern remains (even though I doubt any user will call it directly). If it were prefixed with `_`, I wouldn't complain at all :-)

Let @sgugger review and give us his opinion on whether we should make any effort on such things.

Contributor (author):
Ah, I see now! Yes, then it makes sense to prefix it with `_` 💪

Collaborator:
My opinion has been stated above. I don't think any of this is useful, as Flax and TensorFlow deal with dtypes differently; there should only be a slight adaptation of the `to` method.
self[key] = cast_fun(value, target_dtype)

return self
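To make the last review comment above concrete, here is a minimal, hypothetical sketch of the alternative it suggests: adapting a `to`-style method rather than adding a separate public cast helper. The helper name `cast_batch_to` and the PyTorch-only scope are assumptions for illustration, not part of this PR:

import torch

def cast_batch_to(batch_data: dict, *args, **kwargs) -> dict:
    # Sketch: forward dtype/device arguments to torch.Tensor.to, but only cast
    # floating-point tensors so integer fields (e.g. attention masks) keep their dtype.
    device = kwargs.get("device")
    new_data = {}
    for key, value in batch_data.items():
        if torch.is_tensor(value) and torch.is_floating_point(value):
            new_data[key] = value.to(*args, **kwargs)
        elif torch.is_tensor(value) and device is not None:
            # non-floating tensors only follow a device move, never a dtype cast
            new_data[key] = value.to(device=device)
        else:
            new_data[key] = value
    return new_data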

def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
"""
Convert the inner content to tensors.
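For reference, a short usage sketch of the behavior this diff proposes for `BatchFeature`; the values are illustrative and the expected dtype is inferred from the added `cast_to_dtype` code above (PyTorch assumed installed):

import numpy as np
from transformers import BatchFeature

# With the proposed `float_precision` argument, floating-point entries are cast
# right after tensor conversion; non-floating entries are left untouched.
feat = BatchFeature(
    data={"pixel_values": np.random.rand(1, 3, 224, 224).astype(np.float32)},
    tensor_type="pt",
    float_precision="float16",
)
print(feat["pixel_values"].dtype)  # expected: torch.float16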
5 changes: 4 additions & 1 deletion src/transformers/models/vit/image_processing_vit.py
@@ -193,6 +193,7 @@ def preprocess(
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
float_precision: Optional[str] = None,
**kwargs,
):
"""
@@ -231,6 +232,8 @@ def preprocess(
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
float_precision (`str`, *optional*):
The output floating-point precision. One of `"float16"`, `"float32"`, `"double"` or `"bfloat16"`. If `None`, no cast is performed.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
@@ -273,4 +276,4 @@ def preprocess(
images = [to_channel_dimension_format(image, data_format) for image in images]

data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
return BatchFeature(data=data, tensor_type=return_tensors, float_precision=float_precision)
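A corresponding sketch of calling the ViT image processor with the new argument; the checkpoint name is only an example, and `preprocess` is called directly so the routing of the new keyword is explicit:

import numpy as np
from PIL import Image
from transformers import ViTImageProcessor

# Example checkpoint; any ViT image-processor config would do.
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
image = Image.fromarray((np.random.rand(224, 224, 3) * 255).astype(np.uint8))

# `float_precision` is the argument added to preprocess() in this diff.
inputs = processor.preprocess(image, return_tensors="pt", float_precision="float16")
print(inputs["pixel_values"].dtype)  # expected: torch.float16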