Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/en/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
"{processor_class}": "FakeProcessorClass",
"{model_class}": "FakeModelClass",
"{object_class}": "FakeObjectClass",
}
}
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/pixtral.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up
[[autodoc]] PixtralImageProcessor
- preprocess

## PixtralImageProcessorFast

[[autodoc]] PixtralImageProcessorFast
- preprocess

## PixtralProcessor

[[autodoc]] PixtralProcessor
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1260,6 +1260,7 @@
_import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"]
_import_structure["models.deformable_detr"].append("DeformableDetrImageProcessorFast")
_import_structure["models.detr"].append("DetrImageProcessorFast")
_import_structure["models.pixtral"].append("PixtralImageProcessorFast")
_import_structure["models.rt_detr"].append("RTDetrImageProcessorFast")
_import_structure["models.vit"].append("ViTImageProcessorFast")

Expand Down Expand Up @@ -6189,6 +6190,7 @@
from .image_processing_utils_fast import BaseImageProcessorFast
from .models.deformable_detr import DeformableDetrImageProcessorFast
from .models.detr import DetrImageProcessorFast
from .models.pixtral import PixtralImageProcessorFast
from .models.rt_detr import RTDetrImageProcessorFast
from .models.vit import ViTImageProcessorFast

Expand Down
39 changes: 39 additions & 0 deletions src/transformers/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from .utils import (
ExplicitEnum,
TensorType,
is_jax_tensor,
is_numpy_array,
is_tf_tensor,
Expand Down Expand Up @@ -447,6 +448,44 @@ def validate_preprocess_arguments(
raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")


def validate_fast_preprocess_arguments(
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[float] = None,
    do_normalize: Optional[bool] = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    do_pad: Optional[bool] = None,
    size_divisibility: Optional[int] = None,
    do_center_crop: Optional[bool] = None,
    crop_size: Optional[Dict[str, int]] = None,
    do_resize: Optional[bool] = None,
    size: Optional[Dict[str, int]] = None,
    resample: Optional["PILImageResampling"] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
):
    """
    Checks validity of typically used arguments in an `ImageProcessorFast` `preprocess` method.
    Raises `ValueError` if arguments incompatibility is caught.

    The shared preprocessing arguments are delegated to `validate_preprocess_arguments`; two
    additional restrictions specific to fast image processors are then enforced here.

    Args:
        do_rescale, rescale_factor, do_normalize, image_mean, image_std, do_pad,
        size_divisibility, do_center_crop, crop_size, do_resize, size, resample:
            Forwarded unchanged to `validate_preprocess_arguments`.
        return_tensors (`str` or `TensorType`, *optional*):
            Requested tensor type. Must compare equal to `"pt"`; note that the default `None`
            is therefore rejected.
        data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
            Requested channel layout. Must compare equal to `ChannelDimension.FIRST`.

    Raises:
        ValueError: If `validate_preprocess_arguments` rejects the shared arguments, if
            `return_tensors` is not `"pt"`, or if `data_format` is not `ChannelDimension.FIRST`.
    """
    # Forward *every* shared argument. Previously `do_pad`, `size_divisibility`,
    # `do_center_crop` and `crop_size` were accepted by this function but silently dropped,
    # so inconsistent padding / cropping settings were never checked.
    # NOTE(review): assumes `validate_preprocess_arguments` accepts these four keywords, as
    # this signature mirrors it — confirm against its definition earlier in this module.
    validate_preprocess_arguments(
        do_rescale=do_rescale,
        rescale_factor=rescale_factor,
        do_normalize=do_normalize,
        image_mean=image_mean,
        image_std=image_std,
        do_pad=do_pad,
        size_divisibility=size_divisibility,
        do_center_crop=do_center_crop,
        crop_size=crop_size,
        do_resize=do_resize,
        size=size,
        resample=resample,
    )
    # Extra checks for ImageProcessorFast
    if return_tensors != "pt":
        raise ValueError("Only returning PyTorch tensors is currently supported.")

    if data_format != ChannelDimension.FIRST:
        raise ValueError("Only channel first data format is currently supported.")


# In the future we can add a TF implementation here when we have TF models.
class ImageFeatureExtractionMixin:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/auto/image_processing_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
("paligemma", ("SiglipImageProcessor",)),
("perceiver", ("PerceiverImageProcessor",)),
("pix2struct", ("Pix2StructImageProcessor",)),
("pixtral", ("PixtralImageProcessor",)),
("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
("poolformer", ("PoolFormerImageProcessor",)),
("pvt", ("PvtImageProcessor",)),
("pvt_v2", ("PvtImageProcessor",)),
Expand Down
24 changes: 23 additions & 1 deletion src/transformers/models/pixtral/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
is_torchvision_available,
is_vision_available,
)


_import_structure = {
Expand Down Expand Up @@ -41,6 +47,14 @@
else:
_import_structure["image_processing_pixtral"] = ["PixtralImageProcessor"]

# Register `PixtralImageProcessorFast` in `_import_structure` only when
# torchvision is available; when it is missing the entry is simply omitted.
try:
if not is_torchvision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_pixtral_fast"] = ["PixtralImageProcessorFast"]


if TYPE_CHECKING:
from .configuration_pixtral import PixtralVisionConfig
Expand All @@ -65,6 +79,14 @@
else:
from .image_processing_pixtral import PixtralImageProcessor

try:
if not is_torchvision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_pixtral_fast import PixtralImageProcessorFast

else:
import sys

Expand Down
9 changes: 5 additions & 4 deletions src/transformers/models/pixtral/image_processing_pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
"""Image processor class for Pixtral."""

import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -179,7 +180,7 @@ def _num_image_tokens(image_size: Tuple[int, int], patch_size: Tuple[int, int])


def get_resize_output_image_size(
input_image: np.ndarray,
input_image: ImageInput,
size: Union[int, Tuple[int, int], List[int], Tuple[int]],
patch_size: Union[int, Tuple[int, int], List[int], Tuple[int]],
input_data_format: Optional[Union[str, ChannelDimension]] = None,
Expand All @@ -189,7 +190,7 @@ def get_resize_output_image_size(
size.

Args:
input_image (`np.ndarray`):
input_image (`ImageInput`):
The image to resize.
size (`int` or `Tuple[int, int]`):
Max image size an input image can be. Must be a dictionary with the key "longest_edge".
Expand All @@ -210,8 +211,8 @@ def get_resize_output_image_size(

if ratio > 1:
# Original implementation uses `round` which utilises banker's rounding, which can lead to surprising results
height = int(np.ceil(height / ratio))
width = int(np.ceil(width / ratio))
height = int(math.ceil(height / ratio))
width = int(math.ceil(width / ratio))

num_height_tokens, num_width_tokens = _num_image_tokens((height, width), (patch_height, patch_width))
return num_height_tokens * patch_height, num_width_tokens * patch_width
Expand Down
Loading