diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index 71afaaf268da..d1621492d667 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -110,11 +110,12 @@ def rescale( if not isinstance(image, np.ndarray): raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}") - image = image.astype(dtype) - rescaled_image = image * scale if data_format is not None: rescaled_image = to_channel_dimension_format(rescaled_image, data_format) + + rescaled_image = rescaled_image.astype(dtype) + return rescaled_image