Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions src/transformers/image_processing_utils_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
import torch

if is_torchvision_available():
from torchvision.transforms.v2 import functional as F
import torchvision.transforms.v2.functional as TVF

from .image_utils import pil_torch_interpolation_mapping

Expand All @@ -82,7 +82,7 @@ def validate_fast_preprocess_arguments(
crop_size: SizeDict | None = None,
do_resize: bool | None = None,
size: SizeDict | None = None,
interpolation: Optional["F.InterpolationMode"] = None,
interpolation: Optional["TVF.InterpolationMode"] = None,
return_tensors: str | TensorType | None = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
):
Expand Down Expand Up @@ -398,7 +398,7 @@ def pad(
)
if image_size != pad_size:
padding = (0, 0, padding_width, padding_height)
stacked_images = F.pad(stacked_images, padding, fill=fill_value, padding_mode=padding_mode)
stacked_images = TVF.pad(stacked_images, padding, fill=fill_value, padding_mode=padding_mode)
processed_images_grouped[shape] = stacked_images

if return_mask:
Expand All @@ -418,7 +418,7 @@ def resize(
self,
image: "torch.Tensor",
size: SizeDict,
interpolation: Optional["F.InterpolationMode"] = None,
interpolation: Optional["TVF.InterpolationMode"] = None,
antialias: bool = True,
**kwargs,
) -> "torch.Tensor":
Expand All @@ -438,7 +438,7 @@ def resize(
Returns:
`torch.Tensor`: The resized image.
"""
interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
interpolation = interpolation if interpolation is not None else TVF.InterpolationMode.BILINEAR
if size.shortest_edge and size.longest_edge:
# Resize the image so that the shortest edge or the longest edge is of the given size
# while maintaining the aspect ratio of the original image.
Expand Down Expand Up @@ -468,31 +468,31 @@ def resize(
# TODO: remove this once the bug is fixed (detected with torch==2.7.0+git1fee196, torchvision==0.22.0+9eb57cd)
if is_torchdynamo_compiling() and is_rocm_platform():
return self.compile_friendly_resize(image, new_size, interpolation, antialias)
return F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
return TVF.resize(image, new_size, interpolation=interpolation, antialias=antialias)

@staticmethod
def compile_friendly_resize(
image: "torch.Tensor",
new_size: tuple[int, int],
interpolation: Optional["F.InterpolationMode"] = None,
interpolation: Optional["TVF.InterpolationMode"] = None,
antialias: bool = True,
) -> "torch.Tensor":
"""
A wrapper around `F.resize` so that it is compatible with torch.compile when the image is a uint8 tensor.
A wrapper around `TVF.resize` so that it is compatible with torch.compile when the image is a uint8 tensor.
"""
if image.dtype == torch.uint8:
# 256 is used on purpose instead of 255 to avoid numerical differences
# see https://github.com/huggingface/transformers/pull/38540#discussion_r2127165652
image = image.float() / 256
image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
image = TVF.resize(image, new_size, interpolation=interpolation, antialias=antialias)
image = image * 256
# torch.where is used on purpose instead of torch.clamp to avoid bug in torch.compile
# see https://github.com/huggingface/transformers/pull/38540#discussion_r2126888471
image = torch.where(image > 255, 255, image)
image = torch.where(image < 0, 0, image)
image = image.round().to(torch.uint8)
else:
image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
image = TVF.resize(image, new_size, interpolation=interpolation, antialias=antialias)
return image

def rescale(
Expand Down Expand Up @@ -536,7 +536,7 @@ def normalize(
Returns:
`torch.Tensor`: The normalized image.
"""
return F.normalize(image, mean, std)
return TVF.normalize(image, mean, std)

@lru_cache(maxsize=10)
def _fuse_mean_std_and_rescale_factor(
Expand Down Expand Up @@ -615,14 +615,14 @@ def center_crop(
(crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
(crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
]
image = F.pad(image, padding_ltrb, fill=0) # PIL uses fill value 0
image = TVF.pad(image, padding_ltrb, fill=0) # PIL uses fill value 0
image_height, image_width = image.shape[-2:]
if crop_width == image_width and crop_height == image_height:
return image

crop_top = int((image_height - crop_height) / 2.0)
crop_left = int((image_width - crop_width) / 2.0)
return F.crop(image, crop_top, crop_left, crop_height, crop_width)
return TVF.crop(image, crop_top, crop_left, crop_height, crop_width)

def convert_to_rgb(
self,
Expand Down Expand Up @@ -687,9 +687,9 @@ def _process_image(
image = self.convert_to_rgb(image)

if image_type == ImageType.PIL:
image = F.pil_to_tensor(image)
image = TVF.pil_to_tensor(image)
elif image_type == ImageType.NUMPY:
# not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
# not using TVF.to_tensor as it doesn't handle (C, H, W) numpy arrays
image = torch.from_numpy(image).contiguous()

# If the image is 2D, we need to unsqueeze it to add a channel dimension for processing
Expand Down Expand Up @@ -813,7 +813,7 @@ def _validate_preprocess_kwargs(
size: SizeDict | None = None,
do_center_crop: bool | None = None,
crop_size: SizeDict | None = None,
interpolation: Optional["F.InterpolationMode"] = None,
interpolation: Optional["TVF.InterpolationMode"] = None,
return_tensors: str | TensorType | None = None,
data_format: ChannelDimension | None = None,
**kwargs,
Expand Down Expand Up @@ -892,7 +892,7 @@ def _preprocess(
images: list["torch.Tensor"],
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
interpolation: Optional["TVF.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/beit/image_processing_beit_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import Optional, Union

import torch
from torchvision.transforms.v2 import functional as F
import torchvision.transforms.v2.functional as TVF

from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
Expand Down Expand Up @@ -124,7 +124,7 @@ def _preprocess(
do_reduce_labels: bool,
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
interpolation: Optional["TVF.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Optional

import torch
from torchvision.transforms.v2 import functional as F
import torchvision.transforms.v2.functional as TVF

from ...image_processing_utils_fast import (
BaseImageProcessorFast,
Expand Down Expand Up @@ -113,7 +113,7 @@ def resize(
image: "torch.Tensor",
size: SizeDict,
size_divisor: int = 32,
interpolation: Optional["F.InterpolationMode"] = None,
interpolation: Optional["TVF.InterpolationMode"] = None,
antialias: bool = True,
**kwargs,
) -> "torch.Tensor":
Expand All @@ -137,7 +137,7 @@ def resize(
Returns:
`torch.Tensor`: The resized image.
"""
interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
interpolation = interpolation if interpolation is not None else TVF.InterpolationMode.BILINEAR
if not size.shortest_edge:
raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}")
shorter = size.shortest_edge
Expand Down Expand Up @@ -172,7 +172,7 @@ def center_crop(
Size of the output image in the form `{"height": h, "width": w}`.
"""
output_size = size.shortest_edge
return F.center_crop(
return TVF.center_crop(
image,
output_size=(output_size, output_size),
**kwargs,
Expand All @@ -193,7 +193,7 @@ def _pad_image(
pad_bottom = output_height - input_height
pad_right = output_width - input_width
padding = (0, 0, pad_right, pad_bottom)
padded_image = F.pad(
padded_image = TVF.pad(
image,
padding,
fill=constant_values,
Expand All @@ -206,7 +206,7 @@ def _preprocess(
do_resize: bool,
size: SizeDict,
size_divisor: int | None,
interpolation: Optional["F.InterpolationMode"],
interpolation: Optional["TVF.InterpolationMode"],
do_pad: bool,
do_center_crop: bool,
crop_size: SizeDict,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np
import PIL
import torch
from torchvision.transforms.v2 import functional as F
import torchvision.transforms.v2.functional as TVF

from ...image_processing_utils_fast import BaseImageProcessorFast
from ...image_utils import ImageInput, PILImageResampling, SizeDict
Expand Down Expand Up @@ -74,7 +74,7 @@ def resize(
self,
image: "torch.Tensor",
size: SizeDict,
interpolation: Optional["F.InterpolationMode"] = None,
interpolation: Optional["TVF.InterpolationMode"] = None,
**kwargs,
) -> "torch.Tensor":
"""
Expand All @@ -91,14 +91,14 @@ def resize(
Returns:
`torch.Tensor`: The resized image.
"""
interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
if interpolation == F.InterpolationMode.LANCZOS:
interpolation = interpolation if interpolation is not None else TVF.InterpolationMode.BILINEAR
if interpolation == TVF.InterpolationMode.LANCZOS:
logger.warning_once(
"You have used fast image processor with LANCZOS resample which not yet supported for torch.Tensor. "
"BICUBIC resample will be used as an alternative. Please fall back to slow image processor if you "
"want full consistency with the original model."
)
interpolation = F.InterpolationMode.BICUBIC
interpolation = TVF.InterpolationMode.BICUBIC

return super().resize(
image=image,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import numpy as np
import torch
from torchvision.transforms.v2 import functional as F
import torchvision.transforms.v2.functional as TVF

from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import BaseImageProcessorFast, group_images_by_shape, reorder_images
Expand Down Expand Up @@ -136,7 +136,7 @@ def crop_image_to_patches(
max_patches: int,
use_thumbnail: bool = True,
patch_size: tuple | int | dict | None = None,
interpolation: Optional["F.InterpolationMode"] = None,
interpolation: Optional["TVF.InterpolationMode"] = None,
):
"""
Crop the images to patches and return a list of cropped images.
Expand Down Expand Up @@ -207,7 +207,7 @@ def _preprocess(
crop_to_patches: bool,
min_patches: int,
max_patches: int,
interpolation: Optional["F.InterpolationMode"],
interpolation: Optional["TVF.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from typing import Any, Optional

import torch
import torchvision.transforms.v2.functional as TVF
from torch import nn
from torchvision.io import read_image
from torchvision.transforms.v2 import functional as F

from ...image_processing_utils import BatchFeature, get_size_dict
from ...image_processing_utils_fast import (
Expand Down Expand Up @@ -313,7 +313,7 @@ def resize(
self,
image: torch.Tensor,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"] = None,
interpolation: Optional["TVF.InterpolationMode"] = None,
**kwargs,
) -> torch.Tensor:
"""
Expand All @@ -336,7 +336,7 @@ def resize(
interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
Resampling filter to use if resizing the image.
"""
interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
interpolation = interpolation if interpolation is not None else TVF.InterpolationMode.BILINEAR
if size.shortest_edge and size.longest_edge:
# Resize the image so that the shortest edge or the longest edge is of the given size
# while maintaining the aspect ratio of the original image.
Expand All @@ -355,7 +355,7 @@ def resize(
f" {size.keys()}."
)

image = F.resize(
image = TVF.resize(
image,
size=new_size,
interpolation=interpolation,
Expand All @@ -369,7 +369,7 @@ def resize_annotation(
orig_size: tuple[int, int],
target_size: tuple[int, int],
threshold: float = 0.5,
interpolation: Optional["F.InterpolationMode"] = None,
interpolation: Optional["TVF.InterpolationMode"] = None,
):
"""
Resizes an annotation to a target size.
Expand All @@ -383,10 +383,10 @@ def resize_annotation(
The target size of the image, as returned by the preprocessing `resize` step.
threshold (`float`, *optional*, defaults to 0.5):
The threshold used to binarize the segmentation masks.
resample (`InterpolationMode`, defaults to `F.InterpolationMode.NEAREST_EXACT`):
resample (`InterpolationMode`, defaults to `TVF.InterpolationMode.NEAREST_EXACT`):
The resampling filter to use when resizing the masks.
"""
interpolation = interpolation if interpolation is not None else F.InterpolationMode.NEAREST_EXACT
interpolation = interpolation if interpolation is not None else TVF.InterpolationMode.NEAREST_EXACT
ratio_height, ratio_width = [target / orig for target, orig in zip(target_size, orig_size)]

new_annotation = {}
Expand All @@ -405,7 +405,7 @@ def resize_annotation(
new_annotation["area"] = scaled_area
elif key == "masks":
masks = value[:, None]
masks = [F.resize(mask, target_size, interpolation=interpolation) for mask in masks]
masks = [TVF.resize(mask, target_size, interpolation=interpolation) for mask in masks]
masks = torch.stack(masks).to(torch.float32)
masks = masks[:, 0] > threshold
new_annotation["masks"] = masks
Expand Down Expand Up @@ -449,7 +449,7 @@ def _update_annotation_for_padded_image(
for key, value in annotation.items():
if key == "masks":
masks = value
masks = F.pad(
masks = TVF.pad(
masks,
padding,
fill=0,
Expand Down Expand Up @@ -484,7 +484,7 @@ def pad(
)
if original_size != padded_size:
padding = [0, 0, padding_right, padding_bottom]
image = F.pad(image, padding, fill=fill)
image = TVF.pad(image, padding, fill=fill)
if annotation is not None:
annotation = self._update_annotation_for_padded_image(
annotation, original_size, padded_size, padding, update_bboxes
Expand All @@ -504,7 +504,7 @@ def _preprocess(
return_segmentation_masks: bool,
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
interpolation: Optional["TVF.InterpolationMode"],
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import Optional

import torch
from torchvision.transforms.v2 import functional as F
import torchvision.transforms.v2.functional as TVF

from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import BaseImageProcessorFast, group_images_by_shape, reorder_images
Expand Down Expand Up @@ -121,7 +121,7 @@ def _preprocess(
do_resize: bool,
size: dict[str, int],
crop_pct: float,
interpolation: Optional["F.InterpolationMode"],
interpolation: Optional["TVF.InterpolationMode"],
do_center_crop: bool,
crop_size: int,
do_rescale: bool,
Expand Down
Loading