Skip to content

Commit 81b989a

Browse files
fix gpu test (#1939)
* fix gpu test * cast input * update dtype * change to resnet preset * remove arg
1 parent 0e8a995 commit 81b989a

File tree

2 files changed

+10
-18
lines changed

2 files changed

+10
-18
lines changed

keras_hub/src/layers/preprocessing/image_converter_test.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,10 @@
66
from keras import ops
77

88
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
9-
from keras_hub.src.models.pali_gemma.pali_gemma_backbone import (
10-
PaliGemmaBackbone,
11-
)
129
from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
1310
PaliGemmaImageConverter,
1411
)
12+
from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
1513
from keras_hub.src.tests.test_case import TestCase
1614

1715

@@ -86,24 +84,19 @@ def test_from_preset_errors(self):
8684
def test_save_to_preset(self):
8785
save_dir = self.get_temp_dir()
8886
converter = ImageConverter.from_preset(
89-
"pali_gemma_3b_mix_224",
87+
"resnet_50_imagenet",
9088
interpolation="nearest",
9189
)
9290
converter.save_to_preset(save_dir)
9391
# Save a tiny backbone so the preset is valid.
94-
backbone = PaliGemmaBackbone(
95-
vocabulary_size=100,
96-
image_size=224,
97-
num_layers=1,
98-
num_query_heads=1,
99-
num_key_value_heads=1,
100-
hidden_dim=8,
101-
intermediate_dim=16,
102-
head_dim=8,
103-
vit_patch_size=14,
104-
vit_num_heads=1,
105-
vit_hidden_dim=8,
106-
vit_num_layers=1,
92+
backbone = ResNetBackbone(
93+
input_conv_filters=[64],
94+
input_conv_kernel_sizes=[7],
95+
stackwise_num_filters=[64, 64, 64],
96+
stackwise_num_blocks=[2, 2, 2],
97+
stackwise_num_strides=[1, 2, 2],
98+
block_type="basic_block",
99+
use_pre_activation=True,
107100
)
108101
backbone.save_to_preset(save_dir)
109102

keras_hub/src/models/resnet/resnet_backbone.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ class ResNetBackbone(FeaturePyramidBackbone):
8080
stackwise_num_strides=[1, 2, 2],
8181
block_type="basic_block",
8282
use_pre_activation=True,
83-
pooling="avg",
8483
)
8584
model(input_data)
8685
```

0 commit comments

Comments
 (0)