From cb4fb8e2a41fc5a1f302f06058310c6f03f096f7 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Wed, 16 Oct 2024 23:50:49 +0000 Subject: [PATCH] add back default image resizing --- keras_hub/src/utils/timm/preset_loader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/keras_hub/src/utils/timm/preset_loader.py b/keras_hub/src/utils/timm/preset_loader.py index e662444409..1524db8530 100644 --- a/keras_hub/src/utils/timm/preset_loader.py +++ b/keras_hub/src/utils/timm/preset_loader.py @@ -53,10 +53,11 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs): def load_image_converter(self, cls, **kwargs): pretrained_cfg = self.config.get("pretrained_cfg", None) - if not pretrained_cfg: + if not pretrained_cfg or "input_size" not in pretrained_cfg: return None # This assumes the same basic setup for all timm preprocessing, We may # need to extend this as we cover more model types. + input_size = pretrained_cfg["input_size"] mean = pretrained_cfg["mean"] std = pretrained_cfg["std"] scale = [1.0 / 255.0 / s for s in std] @@ -65,6 +66,7 @@ def load_image_converter(self, cls, **kwargs): if interpolation not in ("bilinear", "nearest", "bicubic"): interpolation = "bilinear" # Unsupported interpolation type. return cls( + image_size=input_size[1:], scale=scale, offset=offset, interpolation=interpolation,