Skip to content

Commit ea08ec6

Browse files
committed
Address comments
1 parent 0dab201 commit ea08ec6

File tree

10 files changed

+52
-300
lines changed

10 files changed

+52
-300
lines changed

keras_nlp/api/models/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,6 @@
183183
from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer
184184
from keras_nlp.src.models.preprocessor import Preprocessor
185185
from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone
186-
from keras_nlp.src.models.resnet.resnet_feature_pyramid_backbone import (
187-
ResNetFeaturePyramidBackbone,
188-
)
189186
from keras_nlp.src.models.resnet.resnet_image_classifier import (
190187
ResNetImageClassifier,
191188
)

keras_nlp/src/models/feature_pyramid_backbone.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,26 @@
1919

2020
@keras_nlp_export("keras_nlp.models.FeaturePyramidBackbone")
2121
class FeaturePyramidBackbone(Backbone):
22+
"""A backbone with feature pyramid outputs.
23+
24+
`FeaturePyramidBackbone` extends `Backbone` with a single `pyramid_outputs`
25+
property for accessing the feature pyramid outputs of the model. Subclassers
26+
should set the `pyramid_outputs` property in the model constructor.
27+
28+
Example:
29+
30+
```python
31+
input_data = np.random.uniform(0, 255, size=(2, 224, 224, 3))
32+
33+
# Convert to feature pyramid output format using ResNet.
34+
backbone = ResNetBackbone.from_preset("resnet50")
35+
model = keras.Model(
36+
inputs=backbone.inputs, outputs=backbone.pyramid_outputs
37+
)
38+
model(input_data) # A dict containing the keys ["P2", "P3", "P4", "P5"]
39+
```
40+
"""
41+
2242
@property
2343
def pyramid_outputs(self):
2444
"""A dict for feature pyramid outputs.

keras_nlp/src/models/resnet/resnet_backbone_test.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import pytest
1616
from absl.testing import parameterized
17+
from keras import models
1718
from keras import ops
1819

1920
from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone
@@ -29,8 +30,8 @@ def setUp(self):
2930
"input_image_shape": (None, None, 3),
3031
"pooling": "avg",
3132
}
32-
self.input_size = (16, 16)
33-
self.input_data = ops.ones((2, 16, 16, 3))
33+
self.input_size = 64
34+
self.input_data = ops.ones((2, self.input_size, self.input_size, 3))
3435

3536
@parameterized.named_parameters(
3637
("v1_basic", False, "basic_block"),
@@ -52,6 +53,24 @@ def test_backbone_basics(self, use_pre_activation, block_type):
5253
),
5354
)
5455

56+
def test_pyramid_output_format(self):
57+
init_kwargs = self.init_kwargs.copy()
58+
init_kwargs.update(
59+
{"block_type": "basic_block", "use_pre_activation": False}
60+
)
61+
backbone = ResNetBackbone(**init_kwargs)
62+
model = models.Model(backbone.inputs, backbone.pyramid_outputs)
63+
output_data = model(self.input_data)
64+
65+
self.assertIsInstance(output_data, dict)
66+
self.assertEqual(
67+
list(output_data.keys()), list(backbone.pyramid_outputs.keys())
68+
)
69+
self.assertEqual(list(output_data.keys()), ["P2", "P3", "P4"])
70+
for k, v in output_data.items():
71+
size = self.input_size // (2 ** int(k[1:]))
72+
self.assertEqual(tuple(v.shape[:3]), (2, size, size))
73+
5574
@parameterized.named_parameters(
5675
("v1_basic", False, "basic_block"),
5776
("v1_bottleneck", False, "bottleneck_block"),
@@ -65,7 +84,7 @@ def test_saved_model(self, use_pre_activation, block_type):
6584
{
6685
"block_type": block_type,
6786
"use_pre_activation": use_pre_activation,
68-
"input_image_shape": (16, 16, 3),
87+
"input_image_shape": (None, None, 3),
6988
}
7089
)
7190
self.run_model_saving_test(

keras_nlp/src/models/resnet/resnet_feature_pyramid_backbone.py

Lines changed: 0 additions & 146 deletions
This file was deleted.

keras_nlp/src/models/resnet/resnet_feature_pyramid_backbone_test.py

Lines changed: 0 additions & 70 deletions
This file was deleted.

keras_nlp/src/models/resnet/resnet_image_classifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class ResNetImageClassifier(ImageClassifier):
2929
the `Dense` layer. Set `activation=None` to return the output
3030
logits. Defaults to `"softmax"`.
3131
head_dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The
32-
dtype to use for the head's computations and weights.
32+
dtype to use for the classification head's computations and weights.
3333
3434
To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
3535
where `x` is a tensor and `y` is an integer from `[0, num_classes)`.

keras_nlp/src/utils/preset_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,9 @@ def check_format(preset):
545545
if check_file_exists(preset, SAFETENSOR_FILE) or check_file_exists(
546546
preset, SAFETENSOR_CONFIG_FILE
547547
):
548-
if TIMM_PREFIX in preset:
548+
# Determine the format by parsing the config file.
549+
config = load_config(preset, HF_CONFIG_FILE)
550+
if "hf://timm" in preset or "architecture" in config:
549551
return "timm"
550552
return "transformers"
551553

keras_nlp/src/utils/timm/convert_resnet.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE
1717
from keras_nlp.src.utils.preset_utils import jax_memory_cleanup
1818
from keras_nlp.src.utils.preset_utils import load_config
19-
from keras_nlp.src.utils.timm.safetensor_utils import SafetensorLoader
19+
from keras_nlp.src.utils.transformers.safetensor_utils import SafetensorLoader
2020

2121

2222
def convert_backbone_config(timm_config):
@@ -60,14 +60,11 @@ def convert_backbone_config(timm_config):
6060

6161

6262
def convert_weights(backbone, loader, timm_config):
63-
def transpose_conv2d(x, shape):
64-
return np.transpose(x, (2, 3, 1, 0))
65-
6663
def port_conv2d(keras_layer_name, hf_weight_prefix):
6764
loader.port_weight(
6865
backbone.get_layer(keras_layer_name).kernel,
6966
hf_weight_key=f"{hf_weight_prefix}.weight",
70-
hook_fn=transpose_conv2d,
67+
hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
7168
)
7269

7370
def port_batch_normalization(keras_layer_name, hf_weight_prefix):
@@ -168,6 +165,7 @@ def load_resnet_backbone(cls, preset, load_weights, **kwargs):
168165
backbone = cls(**keras_config, **kwargs)
169166
if load_weights:
170167
jax_memory_cleanup(backbone)
171-
with SafetensorLoader(preset) as loader:
168+
# Use prefix="" to avoid using `get_prefixed_key`.
169+
with SafetensorLoader(preset, prefix="") as loader:
172170
convert_weights(backbone, loader, timm_config)
173171
return backbone

0 commit comments

Comments
 (0)