Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/transformers/models/clip/feature_extraction_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
The sequence of means for each channel, to be used when normalizing images.
image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
The sequence of standard deviations for each channel, to be used when normalizing images.
convert_rgb (`bool`, defaults to `True`):
Whether or not to convert `PIL.Image.Image` into `RGB` format
"""

model_input_names = ["pixel_values"]
Expand All @@ -68,6 +70,7 @@ def __init__(
do_normalize=True,
image_mean=None,
image_std=None,
do_convert_rgb=True,
**kwargs
):
super().__init__(**kwargs)
Expand All @@ -79,6 +82,7 @@ def __init__(
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
self.do_convert_rgb = do_convert_rgb

def __call__(
self,
Expand Down Expand Up @@ -141,7 +145,9 @@ def __call__(
if not is_batched:
images = [images]

# transformations (resizing + center cropping + normalization)
# transformations (convert rgb + resizing + center cropping + normalization)
if self.do_convert_rgb:
images = [self.convert_rgb(image) for image in images]
if self.do_resize and self.size is not None and self.resample is not None:
images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
if self.do_center_crop and self.crop_size is not None:
Expand All @@ -155,6 +161,20 @@ def __call__(

return encoded_inputs

def convert_rgb(self, image):
"""
Converts `image` to RGB format. Note that this will trigger a conversion of `image` to a PIL Image.

Args:
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
The image to convert.
"""
self._ensure_format_supported(image)
if not isinstance(image, Image.Image):
return image

return image.convert("RGB")

def center_crop(self, image, size):
"""
Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the
Expand Down
65 changes: 65 additions & 0 deletions tests/models/clip/test_feature_extraction_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
do_normalize=True,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
do_convert_rgb=True,
):
self.parent = parent
self.batch_size = batch_size
Expand All @@ -63,6 +64,7 @@ def __init__(
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.do_convert_rgb = do_convert_rgb

def prepare_feat_extract_dict(self):
return {
Expand All @@ -73,6 +75,7 @@ def prepare_feat_extract_dict(self):
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
"do_convert_rgb": self.do_convert_rgb,
}

def prepare_inputs(self, equal_resolution=False, numpify=False, torchify=False):
Expand Down Expand Up @@ -128,6 +131,7 @@ def test_feat_extract_properties(self):
self.assertTrue(hasattr(feature_extractor, "do_normalize"))
self.assertTrue(hasattr(feature_extractor, "image_mean"))
self.assertTrue(hasattr(feature_extractor, "image_std"))
self.assertTrue(hasattr(feature_extractor, "do_convert_rgb"))

def test_batch_feature(self):
pass
Expand Down Expand Up @@ -227,3 +231,64 @@ def test_call_pytorch(self):
self.feature_extract_tester.crop_size,
),
)


@require_torch
@require_vision
class CLIPFeatureExtractionTestFourChannels(FeatureExtractionSavingTestMixin, unittest.TestCase):

feature_extraction_class = CLIPFeatureExtractor if is_vision_available() else None

def setUp(self):
self.feature_extract_tester = CLIPFeatureExtractionTester(self, num_channels=4)
self.expected_encoded_image_num_channels = 3

@property
def feat_extract_dict(self):
return self.feature_extract_tester.prepare_feat_extract_dict()

def test_feat_extract_properties(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
self.assertTrue(hasattr(feature_extractor, "do_resize"))
self.assertTrue(hasattr(feature_extractor, "size"))
self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
self.assertTrue(hasattr(feature_extractor, "center_crop"))
self.assertTrue(hasattr(feature_extractor, "do_normalize"))
self.assertTrue(hasattr(feature_extractor, "image_mean"))
self.assertTrue(hasattr(feature_extractor, "image_std"))
self.assertTrue(hasattr(feature_extractor, "do_convert_rgb"))

def test_batch_feature(self):
pass

def test_call_pil_four_channels(self):
# Initialize feature_extractor
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
# create random PIL images
image_inputs = self.feature_extract_tester.prepare_inputs(equal_resolution=False)
for image in image_inputs:
self.assertIsInstance(image, Image.Image)

# Test not batched input
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
self.assertEqual(
encoded_images.shape,
(
1,
self.expected_encoded_image_num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
),
)

# Test batched
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
self.assertEqual(
encoded_images.shape,
(
self.feature_extract_tester.batch_size,
self.expected_encoded_image_num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
),
)