Skip to content

Commit 833154a

Browse files
authored
Two fixes for image resizing in preprocessing (#1927)
1. Properly display when are not resizing the input image in `model.summary()` 2. Allow setting the `image_size` directly on a preprocessing layer. 2. is just to allow a more consistent way to set the input shape across tasks. We now have: ```python text_classifier = keras_hub.models.TextClassifer.from_preset( "bert_base_en", ) text_classifier.preprocessor.sequence_length = 256 image_classifier = keras_hub.models.TextClassifer.from_preset( "bert_base_en", ) image_classifier.preprocessor.image_size = (256, 256) multi_modal_lm = keras_hub.models.CausalLM.from_preset( "some_preset", ) multi_modal_lm.preprocessor.sequence_length = 256 multi_modal_lm.preprocessor.image_size = (256, 256) ```
1 parent 716ef30 commit 833154a

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

keras_hub/src/models/preprocessor.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,22 @@ def image_converter(self):
7171
def image_converter(self, value):
7272
self._image_converter = value
7373

74+
@property
75+
def image_size(self):
76+
"""Shortcut to get/set the image size of the image converter."""
77+
if self.image_converter is None:
78+
return None
79+
return self.image_converter.image_size
80+
81+
@image_size.setter
82+
def image_size(self, value):
83+
if self.image_converter is None:
84+
raise ValueError(
85+
"Cannot set `image_size` on preprocessor if `image_converter` "
86+
" is `None`."
87+
)
88+
self.image_converter.image_size = value
89+
7490
def get_config(self):
7591
config = super().get_config()
7692
if self.tokenizer:

keras_hub/src/models/task.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def summary(
280280

281281
def highlight_number(x):
282282
if x is None:
283-
f"[color(45)]{x}[/]"
283+
return f"[color(45)]{x}[/]"
284284
return f"[color(34)]{x:,}[/]" # Format number with commas.
285285

286286
def highlight_symbol(x):
@@ -339,7 +339,10 @@ def add_layer(layer, info):
339339
add_layer(layer, info)
340340
elif isinstance(layer, ImageConverter):
341341
info = "Image size: "
342-
info += highlight_shape(layer.image_size)
342+
image_size = layer.image_size
343+
if image_size is None:
344+
image_size = (None, None)
345+
info += highlight_shape(image_size)
343346
add_layer(layer, info)
344347
elif isinstance(layer, AudioConverter):
345348
info = "Audio shape: "

0 commit comments

Comments
 (0)