-
Notifications
You must be signed in to change notification settings - Fork 33.5k
🚨Default to fast image processors for all models #41388
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f48a47b
d5d5c58
dd505b5
6a1448f
a292900
63a255d
ef73759
b5e8b2e
f14ff3c
0306430
01cb815
49ec906
7dd5682
946cc5c
b0cb3e0
3b9e846
53de7a4
93d2c4d
feeec28
4a6b080
02402a0
9204b4c
1ed7c56
757e1f1
bf763b2
0799a0a
98ead2c
59234ee
54bf8e0
bf1a4b6
e3f130d
cc45a7e
3810196
34bfc74
a0c5c1a
8979645
6cc30f9
447b598
ac72ba2
d5bf14a
773342b
12c854c
12a01fd
7d7c6b2
fa94bcb
74492e5
9bd9da1
e4e36d9
1fd0cd5
1532913
1c21d90
d931a2b
2e8003a
e26adb7
572b26d
57fa154
2f6b12a
7769112
3d48ee1
7e0125a
cabac7f
2fff041
1613f29
08249a2
6210f0e
a48b577
c29f9b0
e24559c
f04e642
88610fc
6d56e56
04d4145
624aad6
587209c
21d2fc4
1602059
0bbe085
ff4e4c7
aade130
acdb89f
aca384b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -863,31 +863,43 @@ def _group_images_by_shape(nested_images, *paired_inputs, is_nested: bool = Fals | |
| paired_grouped_values[paired_index][shape].append(paired_value) | ||
| grouped_images_index[key] = (shape, len(grouped_images[shape]) - 1) | ||
|
|
||
| # Store structure size for nested inputs to handle empty sublists during reconstruction | ||
| if is_nested: | ||
| grouped_images_index["_num_sublists"] = len(normalized_images) | ||
|
|
||
| return grouped_images, *paired_grouped_values, grouped_images_index | ||
|
|
||
|
|
||
| def _reconstruct_nested_structure(indices, processed_images): | ||
| """Helper function to reconstruct a single level nested structure.""" | ||
| # Find the maximum outer index | ||
| max_outer_idx = max(idx[0] for idx in indices) | ||
|
|
||
| # Create the outer list | ||
| result = [None] * (max_outer_idx + 1) | ||
| # Get the number of sublists (handles empty sublists like in [[], [image]]) | ||
| num_sublists = indices.pop("_num_sublists", None) | ||
|
|
||
| # Group indices by outer index | ||
| nested_indices = defaultdict(list) | ||
| for i, j in indices: | ||
| nested_indices[i].append(j) | ||
|
|
||
| # Determine the number of outer sublists | ||
| if num_sublists is not None: | ||
| max_outer_idx = num_sublists - 1 | ||
| elif nested_indices: | ||
| max_outer_idx = max(nested_indices.keys()) | ||
| else: | ||
| return [] | ||
|
|
||
| # Create the result structure | ||
| result = [] | ||
| for i in range(max_outer_idx + 1): | ||
| if i in nested_indices: | ||
| if i not in nested_indices: | ||
| result.append([]) | ||
| else: | ||
| inner_max_idx = max(nested_indices[i]) | ||
| inner_list = [None] * (inner_max_idx + 1) | ||
| for j in range(inner_max_idx + 1): | ||
| if (i, j) in indices: | ||
| shape, idx = indices[(i, j)] | ||
| inner_list[j] = processed_images[shape][idx] | ||
| result[i] = inner_list | ||
| for j in nested_indices[i]: | ||
| shape, idx = indices[(i, j)] | ||
| inner_list[j] = processed_images[shape][idx] | ||
| result.append(inner_list) | ||
|
|
||
| return result | ||
|
|
||
|
|
@@ -908,6 +920,21 @@ def _iterate_items(items, is_nested: bool): | |
| yield i, item | ||
|
|
||
|
|
||
| def _get_device_from_images(images, is_nested: bool) -> "torch.device": | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If the processor is the one creating the torch tensor, then I would suppose there is a way to store the device in the data structure it creates instead of having this function.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is mainly to avoid having to pass a device argument around to all `group_images_by_shape` calls when the device is easy to deduce. |
||
| """ | ||
| Get the device from the first non-empty element in a (potentially nested) list of images. | ||
|
|
||
| Handles cases like `images = [[], [image]]` where the first sublist may be empty. | ||
| """ | ||
| if is_nested: | ||
| for row in images: | ||
| if isinstance(row, torch.Tensor): | ||
| return row.device | ||
| if isinstance(row, list) and len(row) > 0: | ||
| return row[0].device | ||
| return images[0].device | ||
|
|
||
|
|
||
| def group_images_by_shape( | ||
| images: Union[list["torch.Tensor"], "torch.Tensor"], | ||
| *paired_inputs, | ||
|
|
@@ -945,17 +972,21 @@ def group_images_by_shape( | |
| """ | ||
| # If disable grouping is not explicitly provided, we favor disabling it if the images are on CPU, and enabling it otherwise. | ||
| if disable_grouping is None: | ||
| device = images[0][0].device if is_nested else images[0].device | ||
| device = _get_device_from_images(images, is_nested) | ||
| disable_grouping = device == "cpu" | ||
|
|
||
| if disable_grouping: | ||
| grouped_images_index = {key: (key, 0) for key, _ in _iterate_items(images, is_nested)} | ||
| if is_nested: | ||
| grouped_images_index["_num_sublists"] = len(images) | ||
|
|
||
| return ( | ||
| {key: img.unsqueeze(0) for key, img in _iterate_items(images, is_nested)}, | ||
| *[ | ||
| {key: item.unsqueeze(0) for key, item in _iterate_items(paired_list, is_nested)} | ||
| for paired_list in paired_inputs | ||
| ], | ||
| {key: (key, 0) for key, _ in _iterate_items(images, is_nested)}, | ||
| grouped_images_index, | ||
| ) | ||
|
|
||
| # Handle single level nested structure | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -50,6 +50,7 @@ | |
| class FuyuImageProcessorFast(BaseImageProcessorFast): | ||
| do_resize = True | ||
| size = {"height": 1080, "width": 1920} | ||
| patch_size = {"height": 30, "width": 30} | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Gosh, I remember this patch size. Good default. |
||
| resample = PILImageResampling.BILINEAR | ||
| do_pad = True | ||
| padding_value = 1.0 | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -152,6 +152,27 @@ def get_max_height_width(images_list: list[list["torch.Tensor"]]) -> tuple[int, | |
| return (max_height, max_width) | ||
|
|
||
|
|
||
| def get_num_channels(images_list: list[list["torch.Tensor"]]) -> int: | ||
| """ | ||
| Get the number of channels across all images in a batch. Handle empty sublists like in [[], [image]]. | ||
| """ | ||
| for images in images_list: | ||
| if images: | ||
| return images[0].shape[0] | ||
|
|
||
| raise ValueError("No images found in the batch.") | ||
|
|
||
|
|
||
| def get_device_from_images(images_list: list[list["torch.Tensor"]]) -> "torch.device": | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not sure why we need this, when extracting the device should be exactly the same for every single image processor, no?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Some processors pass nested images to `group_images_by_shape`, and some have structures with empty lists, so this is needed for edge cases. |
||
| """ | ||
| Get the device from the first non-empty element in a nested list of images. | ||
| Handle empty sublists like in [[], [image]]. | ||
| """ | ||
| for images in images_list: | ||
| if images: | ||
| return images[0].device | ||
|
|
||
|
|
||
| def make_pixel_mask(image: "torch.Tensor", output_size: tuple[int, int]) -> "torch.Tensor": | ||
| """ | ||
| Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding. | ||
|
|
@@ -183,11 +204,14 @@ class Idefics3ImageProcessorFast(BaseImageProcessorFast): | |
| do_pad = True | ||
| return_row_col_info = False | ||
| valid_kwargs = Idefics3ImageProcessorKwargs | ||
| model_input_names = ["pixel_values", "pixel_attention_mask"] | ||
|
|
||
| def _prepare_images_structure(self, images: ImageInput, expected_ndims: int = 3) -> ImageInput: | ||
| """ | ||
| Prepare a nested images structure for processing. | ||
| """ | ||
| # Checks for `str` in case of URL/local path and optionally loads images | ||
| images = self.fetch_images(images) | ||
| return make_nested_list_of_images(images, expected_ndims=expected_ndims) | ||
|
|
||
| def resize( | ||
|
|
@@ -438,18 +462,20 @@ def _preprocess( | |
| # Get max images per batch | ||
| max_num_images = max(len(images_) for images_ in processed_images) | ||
| max_height, max_width = get_max_height_width(processed_images) | ||
| num_channels = get_num_channels(processed_images) | ||
| device = get_device_from_images(processed_images) | ||
|
|
||
| processed_images_padded = torch.zeros( | ||
| len(processed_images), | ||
| max_num_images, | ||
| *(processed_images[0][0].shape[0], max_height, max_width), | ||
| device=processed_images[0][0].device, | ||
| *(num_channels, max_height, max_width), | ||
| device=device, | ||
| ) | ||
| pixel_attention_masks = torch.zeros( | ||
| len(processed_images), | ||
| max_num_images, | ||
| *(max_height, max_width), | ||
| device=processed_images[0][0].device, | ||
| device=device, | ||
| ) | ||
| for i, images in enumerate(processed_images): | ||
| for j, image in enumerate(images): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
clearer!