diff --git a/src/transformers/models/imagebind/image_processing_imagebind.py b/src/transformers/models/imagebind/image_processing_imagebind.py
index ed20d8fa9e76..b787f0572697 100644
--- a/src/transformers/models/imagebind/image_processing_imagebind.py
+++ b/src/transformers/models/imagebind/image_processing_imagebind.py
@@ -13,10 +13,12 @@
 # limitations under the License.
 """Image processor class for ImageBind."""
 
+import math
 from fractions import Fraction
 from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
+import torch
 
 from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
 from ...image_transforms import (
@@ -46,7 +48,6 @@
 
 logger = logging.get_logger(__name__)
 
-
 if is_vision_available():
     import PIL
 
@@ -77,19 +78,42 @@ def uniform_chunk_sampling(
     Args:
         total_duration (float): Total duration of the audio/video.
-        chunk_duration (float): Duration of each chunk.
-        num_chunks (int): Number of chunks to sample.
+        chunk_duration (float): Duration of each chunk (clip duration).
+        num_chunks (int): Number of chunks to sample (number of clips per video).
 
     Returns:
         List[Tuple[float, float]]: List of tuples where each tuple contains the start and end time of a chunk.
     """
+    _current_clip_index = 0
+    _current_aug_index = 0
+    _augs_per_clip: int = 1
+
     chunk_duration_fraction = Fraction(chunk_duration)
-    max_possible_clip_start = Fraction(max(total_duration - chunk_duration, 0))
+    max_possible_clip_start = Fraction(
+        max(total_duration - chunk_duration_fraction, 0)
+    )  # Subtract `chunk_duration_fraction` rather than the float `chunk_duration`; using the float here previously led to mismatched pixel values.
    uniform_clip = Fraction(max_possible_clip_start / max(num_chunks - 1, 1))
 
     result = []
-    for clip_index in range(num_chunks):
-        clip_start_sec = uniform_clip * clip_index
+    is_last_clip = False
+    while not is_last_clip:
+        clip_start_sec = uniform_clip * _current_clip_index
+        _current_aug_index += 1
+        if _current_aug_index >= _augs_per_clip:
+            _current_clip_index += 1
+            _current_aug_index = 0
+
+        # The last clip is reached once `num_chunks` clips have been sampled or the end of the video is reached.
+        is_last_clip = False
+        if _current_clip_index >= num_chunks or uniform_clip * _current_clip_index > max_possible_clip_start:
+            _current_clip_index = 0
+            is_last_clip = True
+
+        # reset the sampling state after the last clip
+        if is_last_clip:
+            _current_clip_index = 0
+            _current_aug_index = 0
+
         clip_end_sec = clip_start_sec + chunk_duration_fraction
         result.append((clip_start_sec, clip_end_sec))
 
@@ -109,13 +133,13 @@ def uniform_temporal_subsample(video: VideoInput, num_samples: int) -> VideoInpu
         num_samples (`int`):
             Number of frames to sample.
     """
-    num_frames = len(video)
-
+    num_frames = video.shape[-3]  # len(video) would give the first dim of the tensor (channels) rather than the number of frames
+    assert num_samples > 0 and num_frames > 0
     # Sample by nearest neighbor interpolation if num_samples > t.
     indices = np.linspace(0, num_frames - 1, num_samples)
     indices = np.clip(indices, 0, num_frames - 1).astype(int)
-    return [video[i] for i in indices]
+    return video[:, indices, :, :]  # the second dim holds the frames, so slice it instead of looping
 
 
 class ImageBindImageProcessor(BaseImageProcessor):
@@ -130,7 +154,7 @@ class ImageBindImageProcessor(BaseImageProcessor):
            Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
            the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
            method.
-        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess`
            method.
        do_center_crop (`bool`, *optional*, defaults to `True`):
            Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
@@ -158,13 +182,15 @@ class ImageBindImageProcessor(BaseImageProcessor):
        do_chunk (`bool`, *optional*, defaults to `False`):
            Whether to chunk the video into multiple clips.
        chunk_duration (`float`, *optional*, defaults to 2.0):
-            Duration of each chunk in seconds.
+            Duration of each chunk in seconds (clip duration).
        num_chunks (`int`, *optional*, defaults to 5):
-            Number of chunks to sample.
+            Number of chunks to sample (number of clips per video).
        num_frames_per_chunk (`int`, *optional*, defaults to 2):
            Number of frames to sample per chunk.
-        fps (`int`, *optional*, defaults to 30):
+        fps (`List[int]`, *optional*, defaults to `[30]`):
            Frame rate of the video. It's assumed that all videos have the same frame rate.
+        duration (`List[float]`, *optional*, defaults to `[10.0]`):
+            Durations of the videos in seconds.
    """
 
    model_input_names = ["pixel_values"]
@@ -173,7 +199,7 @@ def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
-        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_center_crop: bool = True,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = True,
@@ -186,7 +212,8 @@ def __init__(
        chunk_duration: float = 2.0,
        num_chunks: int = 5,
        num_frames_per_chunk: int = 2,
-        fps: int = 30,
+        fps: List[int] = [30],
+        duration: List[float] = [10.0],
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
@@ -211,6 +238,7 @@ def __init__(
        self.num_chunks = num_chunks
        self.num_frames_per_chunk = num_frames_per_chunk
        self.fps = fps
+        self.duration = duration
        self._valid_processor_keys = [
            "images",
            "do_resize",
@@ -228,6 +256,7 @@ def __init__(
            "chunk_duration",
            "num_chunks",
            "fps",
+            "duration",
            "return_tensors",
            "data_format",
            "input_data_format",
@@ -246,7 +275,7 @@ def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
-        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        resample: PILImageResampling = PILImageResampling.BILINEAR,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
@@ -260,7 +289,7 @@ def resize(
                Image to resize.
            size (`Dict[str, int]`):
                Size of the output image.
-            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
                Resampling filter to use when resiizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
@@ -291,8 +320,146 @@ def resize(
            **kwargs,
        )
 
+    # Adapted from https://github.com/facebookresearch/pytorchvideo/blob/1fadaef40dd393ca09680f55582399f4679fc9b7/pytorchvideo/transforms/functional.py#L92
+    def short_side_scale(
+        self,
+        image: np.ndarray,
+        size: int = 224,
+        resample: str = "bilinear",
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> np.ndarray:
+        """
+        Determines the shorter spatial dim of the video (i.e. width or height) and scales
+        it to the given size. To maintain aspect ratio, the longer side is then scaled
+        accordingly.
+        Args:
+            image (np.ndarray): A video tensor of shape (C, T, H, W) and type numpy.float32.
+            size (int): The size the shorter side is scaled to.
+            resample (str): Algorithm used for upsampling,
+                options: 'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area'
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
+        Returns:
+            An image-like numpy array with scaled spatial dims.
+        """  # noqa
+        assert len(image.shape) == 4
+        assert image.dtype == np.float32
+        _, _, h, w = image.shape
+        if w < h:
+            new_h = int(math.floor((float(h) / w) * size))
+            new_w = size
+        else:
+            new_h = size
+            new_w = int(math.floor((float(w) / h) * size))
+
+        data_format = input_data_format if data_format is None else data_format
+        resized_image = torch.nn.functional.interpolate(
+            torch.tensor(image).contiguous(), size=(new_h, new_w), mode=resample, align_corners=False
+        ).numpy()
+        # the input image is always in FIRST channel dim at this point
+        resized_image = np.array(
+            [
+                to_channel_dimension_format(img, data_format, input_channel_dim=ChannelDimension.FIRST)
+                for img in resized_image
+            ]
+        )
+        return resized_image
+
+    def uniform_crop(
+        self,
+        images: np.ndarray,
+        crop_size: int = 224,
+        num_crops: int = 3,
+        scale_size=None,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> List[np.ndarray]:
+        """
+        Perform uniform spatial sampling on the images.
+        Args:
+            images (np.ndarray): images to perform uniform crop. The dimension is
+                `num frames` x `channel` x `height` x `width`.
+            crop_size (int): size of height/width to crop the images.
+            num_crops (int): number of crops to extract; 3 takes the left, center and right crops if width
+                is larger than height (or the top, center and bottom crops if height is larger than width),
+                and 1 takes only the center crop.
+            scale_size (int): optional. If not None, resize the images to scale_size before
+                performing any crop.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
+        Returns:
+            cropped (List[np.ndarray]): images with dimension of
+                `num frames` x `channel` x `size` x `size`.
+ """ + data_format = input_data_format if data_format is None else data_format + + crop_size = crop_size["height"] + uniform_cropped = [] + if num_crops == 3: + crops_to_ext = [0, 1, 2] + elif num_crops == 1: + crops_to_ext = [1] + for spatial_idx in crops_to_ext: + assert spatial_idx in [0, 1, 2] + ndim = len(images.shape) + if ndim == 3: + images = images.unsqueeze(0) + height = images.shape[2] + width = images.shape[3] + + if scale_size is not None: + if width <= height: + width, height = scale_size, int(height / width * scale_size) + else: + width, height = int(width / height * scale_size), scale_size + images = torch.nn.functional.interpolate( + images, + size=(height, width), + mode="bilinear", + align_corners=False, + ) + + y_offset = int(math.ceil((height - crop_size) / 2)) + x_offset = int(math.ceil((width - crop_size) / 2)) + + if height > width: + if spatial_idx == 0: + y_offset = 0 + elif spatial_idx == 2: + y_offset = height - crop_size + else: + if spatial_idx == 0: + x_offset = 0 + elif spatial_idx == 2: + x_offset = width - crop_size + cropped = images[:, :, y_offset : y_offset + crop_size, x_offset : x_offset + crop_size] + if ndim == 3: + cropped = cropped.squeeze(0) + # input image in always in FIRST channel dim + cropped = np.array( + [ + to_channel_dimension_format(img, data_format, input_channel_dim=ChannelDimension.FIRST) + for img in cropped + ] + ) + + uniform_cropped.append(cropped) + + return uniform_cropped + def chunk( - self, video: VideoInput, fps: int, chunk_duration: float, num_chunks: int, num_frames_per_chunk: int + self, + video: VideoInput, + fps: int, + duration: float, + chunk_duration: float, + num_chunks: int, + num_frames_per_chunk: int, ) -> List[VideoInput]: """ Uniformly sample `num_chunks` chunks of duration `chunk_duration` from a video. @@ -302,14 +469,17 @@ def chunk( Video to chunk. fps (`int`): Frame rate of the video + duration('float', *optional*, defaults to 10.0): + Durations of videos chunk_duration (`float`): - Duration of each chunk. + Duration of each chunk(clip duration). num_chunks (`int`): - Number of chunks to sample. + Number of chunks to sample(number of clips per video). num_frames_per_chunk (`int`): Number of frames to sample per chunk. """ - video_duration = len(video) / fps + fps = float(fps) + video_duration = duration if video_duration < chunk_duration: logger.warning_once( "Chunk duration is greater than audio duration. 
            )
 
@@ -320,8 +490,18 @@ def chunk(
        all_clips = []
        for clip_timepoints in all_clips_timepoints:
-            video_clip = video[int(clip_timepoints[0] * fps) : int(clip_timepoints[1] * fps)]
-            video_clip = uniform_temporal_subsample(video_clip, num_samples=num_frames_per_chunk)
+            # The video tensor has shape (channels, frames, height, width), so the frames dim is indexed at position 1.
+
+            start_idx = math.ceil(fps * clip_timepoints[0])
+            end_idx = math.ceil(fps * clip_timepoints[1])
+            end_idx = min(end_idx, int(duration * fps))
+            frame_idxs = list(range(start_idx, end_idx))
+            frame_idxs = torch.tensor(frame_idxs).contiguous()
+            video_clip = video[:, frame_idxs, :, :]
+            if video_clip is None:
+                raise ValueError("No clip found")
+            video_clip = uniform_temporal_subsample(video_clip.numpy(), num_samples=num_frames_per_chunk)
+            video_clip = video_clip / 255.0  # values are float, so rescale them to the [0, 1] range
            all_clips.append(video_clip)
 
        return all_clips
@@ -330,6 +510,7 @@ def chunk(
    def _preprocess_image(
        self,
        images: ImageInput,
+        is_video: bool = None,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
@@ -362,40 +543,34 @@ def _preprocess_image(
        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]
-
        if is_scaled_image(images[0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )
-
        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])
 
-        if do_resize:
-            images = [
-                self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
-                for image in images
-            ]
-
-        if do_center_crop:
-            images = [
-                self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
-            ]
+        images = self.short_side_scale(image=np.array(images), input_data_format=input_data_format)
 
        if do_rescale:
            images = [
                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
                for image in images
            ]
-
+        images = (
+            torch.tensor(images).permute(1, 0, 2, 3).numpy()
+        )  # interchange the channel and frame dims so `normalize` sees 3 channels (mean and std have length 3)
        if do_normalize:
            images = [
                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
                for image in images
            ]
+        if do_center_crop:
+            images = self.uniform_crop(np.array(images), crop_size, num_crops=3, input_data_format=input_data_format)
+
        images = [
            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
        ]
@@ -422,7 +597,8 @@ def preprocess(
        chunk_duration: float = None,
        num_chunks: int = None,
        num_frames_per_chunk: int = None,
-        fps: int = None,
+        fps: List[int] = None,
+        duration: List[float] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
@@ -467,14 +643,16 @@ def preprocess(
                Whether to convert the image to RGB.
            do_chunk (`bool`, *optional*, defaults to `self.do_chunk`):
                Whether to chunk the video into multiple clips.
-            chunk_duration (`float`, *optional*, defaults to `self.chunk_duration`):
-                Duration of each chunk in seconds.
+            chunk_duration (`float`, *optional*, defaults to `self.chunk_duration`):
+                Duration of each chunk in seconds (clip duration).
            num_chunks (`int`, *optional*, defaults to `self.num_chunks`):
-                Number of chunks to sample.
+                Number of chunks to sample (number of clips per video).
            num_frames_per_chunk (`int`, *optional*, defaults to `self.num_frames_per_chunk`):
                Number of frames to sample per chunk.
-            fps (`int`, *optional*, defaults to `self.fps`):
+            fps (`List[int]`, *optional*, defaults to `self.fps`):
                Frame rate of the video. It's assumed that all videos have the same frame rate.
+            duration (`List[float]`, *optional*, defaults to `self.duration`):
+                Durations of the videos in seconds.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
@@ -518,10 +696,13 @@ def preprocess(
        num_chunks = num_chunks if num_chunks is not None else self.num_chunks
        num_frames_per_chunk = num_frames_per_chunk if num_frames_per_chunk is not None else self.num_frames_per_chunk
        fps = fps if fps is not None else self.fps
+        duration = duration if duration is not None else self.duration
 
        if images is not None:
+            is_video = False
            images = make_list_of_images(images)
        if videos is not None:
+            is_video = True
            videos = make_batched_videos(videos)
 
        validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
@@ -535,6 +716,7 @@ def preprocess(
        if images is not None:
            pixel_values = self._preprocess_image(
                images=images,
+                is_video=is_video,
                do_resize=do_resize,
                size=size,
                resample=resample,
@@ -551,11 +733,13 @@ def preprocess(
            )
        else:
            pixel_values = []
-            for video in videos:
+
+            for idx, video in enumerate(videos):
                if do_chunk:
                    clips = self.chunk(
-                        video=video,
-                        fps=fps,
+                        video=video[0],
+                        fps=fps[idx],
+                        duration=duration[idx],
                        chunk_duration=chunk_duration,
                        num_chunks=num_chunks,
                        num_frames_per_chunk=num_frames_per_chunk,
@@ -564,6 +748,7 @@ def preprocess(
                    _pixel_values = [
                        self._preprocess_image(
                            images=clip,
+                            is_video=is_video,
                            do_resize=do_resize,
                            size=size,
                            resample=PILImageResampling.BILINEAR,
@@ -584,6 +769,7 @@ def preprocess(
                    _pixel_values = [
                        self._preprocess_image(
                            images=video,
+                            is_video=is_video,
                            do_resize=do_resize,
                            size=size,
                            resample=resample,
@@ -599,11 +785,17 @@ def preprocess(
                            input_data_format=input_data_format,
                        )
                    ]
-
-                # Avoid List[List[List[np.ndarray]]]
-                _pixel_values = np.stack(_pixel_values)
-                # Make it shape (num_chunks, num_channels, num_frames_per_chunk, height, width)
-                _pixel_values = np.swapaxes(_pixel_values, 1, 2)
+                _pixel_values = np.stack(np.array(_pixel_values))
+                # Exchange the frames and channels dims
+                _pixel_values = np.swapaxes(_pixel_values, 2, 3)
                pixel_values.append(_pixel_values)
-
+            pixel_values = np.stack(pixel_values)
+            # Combine the second and third dimensions so that `num_crops` is merged into a single dim
+            pixel_values_shape = pixel_values.shape
+            pixel_values_shape = (
+                pixel_values_shape[0],
+                pixel_values_shape[1] * pixel_values_shape[2],
+                *pixel_values_shape[3:],
+            )
+            pixel_values = pixel_values.reshape(pixel_values_shape)
        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)
diff --git a/src/transformers/models/imagebind/processing_imagebind.py b/src/transformers/models/imagebind/processing_imagebind.py
index 1d8162852d24..fa79abb3d8a5 100644
--- a/src/transformers/models/imagebind/processing_imagebind.py
+++ b/src/transformers/models/imagebind/processing_imagebind.py
@@ -31,7 +31,7 @@ class ImageBindProcessorKwargs(ProcessingKwargs, total=False):
 
 
 class ImageBindProcessor(ProcessorMixin):
    r"""
-    Constructs a ImageBind processor which wraps a ImageBind image processor and feature extracotr and a CLIP tokenizer into a single processor.
+    Constructs an ImageBind processor which wraps an ImageBind image processor, an ImageBind feature extractor and a CLIP tokenizer into a single processor.
 
    [`ImageBindProcessor`] offers all the functionalities of [`ImageBindImageProcessor`], [`ImageBindFeatureExtractor`] and [`CLIPTokenizerFast`]. See the
    [`~ImageBindProcessor.__call__`] and [`~ImageBindProcessor.decode`] for more information.
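
Reviewer note on `uniform_chunk_sampling`: the clip-timing arithmetic in the first hunk is easier to check in isolation. The sketch below reproduces the evenly spaced start times the function is intended to produce, with starts spread over [0, total_duration - chunk_duration]; `expected_chunk_timepoints` is a hypothetical helper written for this note, not code from the branch.

from fractions import Fraction
from typing import List, Tuple


def expected_chunk_timepoints(
    total_duration: float, chunk_duration: float, num_chunks: int
) -> List[Tuple[Fraction, Fraction]]:
    # Start times are spread evenly over [0, total_duration - chunk_duration];
    # each chunk ends chunk_duration later.
    chunk = Fraction(chunk_duration)
    max_start = max(Fraction(total_duration) - chunk, Fraction(0))
    step = max_start / max(num_chunks - 1, 1)
    return [(step * i, step * i + chunk) for i in range(num_chunks)]


# A 10 s video, 2 s clips, 5 clips -> starts at 0, 2, 4, 6 and 8 seconds.
print(expected_chunk_timepoints(10.0, 2.0, 5))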
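
Reviewer note on `uniform_temporal_subsample`: the updated function now slices the frame axis of a (channels, frames, height, width) array instead of looping over it. The sketch below shows which frame indices the np.linspace/np.clip combination selects; the array and the `subsample_indices` helper are illustrative only.

import numpy as np


def subsample_indices(num_frames: int, num_samples: int) -> np.ndarray:
    # Evenly spaced positions over [0, num_frames - 1], truncated to integer frame indices.
    indices = np.linspace(0, num_frames - 1, num_samples)
    return np.clip(indices, 0, num_frames - 1).astype(int)


print(subsample_indices(60, 2))  # [ 0 59]: first and last frame of a 2 s clip at 30 fps
print(subsample_indices(60, 5))  # [ 0 14 29 44 59]

# Slicing a (C, T, H, W) array along the frame axis, as the updated function does:
video = np.zeros((3, 60, 224, 224), dtype=np.float32)
print(video[:, subsample_indices(60, 2), :, :].shape)  # (3, 2, 224, 224)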
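
Reviewer note on `short_side_scale`: the sizing math scales the shorter spatial side to `size` and the longer side proportionally. A standalone sketch of that arithmetic, using a hypothetical `scaled_hw` helper:

import math


def scaled_hw(height: int, width: int, size: int = 224) -> tuple:
    # The shorter side becomes `size`; the longer side keeps the aspect ratio.
    if width < height:
        return int(math.floor(float(height) / width * size)), size
    return size, int(math.floor(float(width) / height * size))


print(scaled_hw(360, 480))  # (224, 298): height is the shorter side
print(scaled_hw(480, 360))  # (298, 224): width is the shorter side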