Skip to content

Commit f56d20e

Browse files
divyashreepathihalliushareng
authored andcommitted
Update PaliGemma to remove include_rescaling arg (keras-team#1917)
* update PaliGemma * update conversion script * fix GPU tests
1 parent 0f0dd0c commit f56d20e

File tree

4 files changed

+25
-41
lines changed

4 files changed

+25
-41
lines changed

keras_hub/src/models/pali_gemma/pali_gemma_backbone.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,6 @@ class PaliGemmaBackbone(Backbone):
6161
vit_classifier_activation: activation function. The activation that
6262
is used for final output classification in the vision transformer.
6363
vit_name: string. The name used for vision transformer layers.
64-
include_rescaling: bool. If true, the image input will be rescaled from
65-
the range `[0, 255]`, to the range `[0, 1]`.
6664
layer_norm_epsilon: float. The epsilon value user for every layer norm
6765
in all transformer blocks.
6866
dropout: float. Dropout probability for the Transformer decoder blocks.
@@ -121,7 +119,6 @@ def __init__(
121119
vit_pooling=None,
122120
vit_classifier_activation=None,
123121
vit_name=None,
124-
include_rescaling=True,
125122
layer_norm_epsilon=1e-6,
126123
dropout=0,
127124
dtype=None,
@@ -145,7 +142,6 @@ def __init__(
145142
vit_intermediate_dim = vit_intermediate_dim or 4304
146143
self.vit_encoder = PaliGemmaVit(
147144
image_size=image_size,
148-
include_rescaling=include_rescaling,
149145
patch_size=vit_patch_size,
150146
num_heads=vit_num_heads,
151147
hidden_dim=vit_hidden_dim,
@@ -215,7 +211,6 @@ def __init__(
215211
# === Config ===
216212
self.vocabulary_size = vocabulary_size
217213
self.image_size = image_size
218-
self.include_rescaling = include_rescaling
219214
self.num_layers = num_layers
220215
self.num_query_heads = num_query_heads
221216
self.num_key_value_heads = num_key_value_heads
@@ -242,7 +237,6 @@ def get_config(self):
242237
{
243238
"vocabulary_size": self.vocabulary_size,
244239
"image_size": self.image_size,
245-
"include_rescaling": self.include_rescaling,
246240
"num_layers": self.num_layers,
247241
"num_query_heads": self.num_query_heads,
248242
"num_key_value_heads": self.num_key_value_heads,

keras_hub/src/models/pali_gemma/pali_gemma_vit.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -410,8 +410,6 @@ class PaliGemmaVit(keras.Model):
410410
Args:
411411
image_size: int. The height/width of the image. Both height and width is
412412
expected to be the same.
413-
include_rescaling: bool. If true, the image input will be rescaled from
414-
the range `[0, 255]`, to the range `[0, 1]`.
415413
patch_size: int. The size of each square patch in the input image.
416414
num_heads: int. The number of attention heads for the vision(image)
417415
transformer encoder.
@@ -452,7 +450,6 @@ def __init__(
452450
num_layers,
453451
intermediate_dim,
454452
num_classes,
455-
include_rescaling=True,
456453
pooling=None,
457454
classifier_activation=None,
458455
dtype=None,
@@ -463,14 +460,6 @@ def __init__(
463460
shape=(image_size, image_size, 3), name="images"
464461
)
465462
x = image_input # Intermediate result.
466-
# TODO we have moved this rescaling to preprocessing layers for most
467-
# models. We should consider removing it here, though it would break
468-
# compatibility.
469-
if include_rescaling:
470-
rescaling = keras.layers.Rescaling(
471-
scale=1.0 / 127.5, offset=-1.0, name="rescaling"
472-
)
473-
x = rescaling(image_input)
474463
x = PaliGemmaVitEncoder(
475464
hidden_dim=hidden_dim,
476465
num_layers=num_layers,
@@ -520,7 +509,6 @@ def __init__(
520509
self.pooling = pooling
521510
self.num_classes = num_classes
522511
self.image_size = image_size
523-
self.include_rescaling = include_rescaling
524512
self.patch_size = patch_size
525513
self.classifier_activation = keras.activations.get(
526514
classifier_activation
@@ -549,7 +537,6 @@ def get_config(self):
549537
self.classifier_activation
550538
),
551539
"image_size": self.image_size,
552-
"include_rescaling": self.include_rescaling,
553540
"patch_size": self.patch_size,
554541
}
555542
)

keras_hub/src/models/pali_gemma/pali_gemma_vit_test.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,23 +30,6 @@ def test_vit_encoder(self):
3030
output.shape, (batch_size, intermediate_dim, hidden_dim)
3131
)
3232

33-
def test_vit_rescaling(self):
34-
vit_encoder = PaliGemmaVit(
35-
image_size=16,
36-
patch_size=4,
37-
hidden_dim=8,
38-
num_layers=2,
39-
num_heads=2,
40-
intermediate_dim=16,
41-
num_classes=32,
42-
)
43-
self.assertIsNotNone(vit_encoder.get_layer("rescaling"))
44-
with self.assertRaises(ValueError):
45-
config = vit_encoder.get_config()
46-
config["include_rescaling"] = False
47-
vit_encoder = PaliGemmaVit.from_config(config)
48-
vit_encoder.get_layer("rescaling")
49-
5033
def test_vision_embeddings(self):
5134
embeddings_layer = PaliGemmaVitEmbeddings(
5235
image_size=16,

tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,24 @@
33

44
import numpy as np
55

6+
from keras_hub.src.models.pali_gemma.pali_gemma_backbone import (
7+
PaliGemmaBackbone,
8+
)
9+
from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm import (
10+
PaliGemmaCausalLM,
11+
)
12+
from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm_preprocessor import (
13+
PaliGemmaCausalLMPreprocessor,
14+
)
15+
from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
16+
PaliGemmaImageConverter,
17+
)
18+
619
os.environ["KERAS_BACKEND"] = "jax"
720

821
import keras # noqa: E402
922
from keras import ops # noqa: E402
1023

11-
from keras_hub.src.models.pali_gemma.pali_gemma_backbone import ( # noqa: E402
12-
PaliGemmaBackbone,
13-
)
14-
1524
# No GPU for conversion, makes memory management easier.
1625
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
1726

@@ -301,7 +310,18 @@ def main(args):
301310
"vit_hidden_dim": 1152,
302311
"image_size": args.image_size,
303312
}
304-
keras_model = PaliGemmaBackbone(**pali_gemma_backbone_config)
313+
pg_image_converter = PaliGemmaImageConverter(
314+
image_size=(args.image_size, args.image_size),
315+
scale=1.0 / 127.5,
316+
offset=-1,
317+
)
318+
pg_presprocessor = PaliGemmaCausalLMPreprocessor(
319+
image_converter=pg_image_converter
320+
)
321+
pg_backbone = PaliGemmaBackbone(**pali_gemma_backbone_config)
322+
keras_model = PaliGemmaCausalLM(
323+
preprocessor=pg_presprocessor, backbone=pg_backbone
324+
)
305325
# This could be from kaggle or provide local dir path
306326
weights = np.load(args.weights_path)
307327
jax_weights = get_weights_as_numpy(weights, **pali_gemma_backbone_config)

0 commit comments

Comments
 (0)