diff --git a/src/transformers/models/llava_onevision/image_processing_llava_onevision.py b/src/transformers/models/llava_onevision/image_processing_llava_onevision.py index 3dddcdd148a4..204755720837 100644 --- a/src/transformers/models/llava_onevision/image_processing_llava_onevision.py +++ b/src/transformers/models/llava_onevision/image_processing_llava_onevision.py @@ -621,6 +621,7 @@ def preprocess( """ do_resize = do_resize if do_resize is not None else self.do_resize size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) image_grid_pinpoints = image_grid_pinpoints if image_grid_pinpoints is not None else self.image_grid_pinpoints resample = resample if resample is not None else self.resample do_rescale = do_rescale if do_rescale is not None else self.do_rescale diff --git a/src/transformers/models/llava_onevision/processing_llava_onevision.py b/src/transformers/models/llava_onevision/processing_llava_onevision.py index e050ec3f31de..d4ae02e0bb15 100644 --- a/src/transformers/models/llava_onevision/processing_llava_onevision.py +++ b/src/transformers/models/llava_onevision/processing_llava_onevision.py @@ -17,6 +17,7 @@ """ import math +import os import sys from typing import Iterable, List, Union @@ -34,6 +35,11 @@ ProcessorMixin, ) from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import logging +from ..auto import AutoImageProcessor + + +logger = logging.get_logger(__name__) class LlavaOnevisionProcessorKwargs(ProcessingKwargs, total=False): @@ -96,7 +102,7 @@ def __init__( chat_template=None, image_token="", video_token="