diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py index 071348cb4330..7ced6b2879e8 100644 --- a/src/transformers/image_processing_utils_fast.py +++ b/src/transformers/image_processing_utils_fast.py @@ -304,9 +304,13 @@ def compile_friendly_resize( A wrapper around `F.resize` so that it is compatible with torch.compile when the image is a uint8 tensor. """ if image.dtype == torch.uint8: - image = image.float() / 255 + # 256 is used on purpose instead of 255 to avoid numerical differences + # see https://github.com/huggingface/transformers/pull/38540#discussion_r2127165652 + image = image.float() / 256 image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias) - image = image * 255 + image = image * 256 + # torch.where is used on purpose instead of torch.clamp to avoid bug in torch.compile + # see https://github.com/huggingface/transformers/pull/38540#discussion_r2126888471 image = torch.where(image > 255, 255, image) image = torch.where(image < 0, 0, image) image = image.round().to(torch.uint8)