Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
from keras import layers

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.backbone import Backbone
from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone


@keras_nlp_export("keras_nlp.models.CSPDarkNetBackbone")
class CSPDarkNetBackbone(Backbone):
class CSPDarkNetBackbone(FeaturePyramidBackbone):
"""This class represents Keras Backbone of CSPDarkNet model.

This class implements a CSPDarkNet backbone as described in
Expand Down Expand Up @@ -67,7 +67,7 @@ def __init__(
stackwise_depth,
include_rescaling,
block_type="basic_block",
image_shape=(224, 224, 3),
image_shape=(None, None, 3),
**kwargs,
):
# === Functional Model ===
Expand All @@ -87,6 +87,8 @@ def __init__(
x = apply_darknet_conv_block(
base_channels, kernel_size=3, strides=1, name="stem_conv"
)(x)

pyramid_outputs = {}
for index, (channels, depth) in enumerate(
zip(stackwise_num_filters, stackwise_depth)
):
Expand All @@ -111,6 +113,7 @@ def __init__(
residual=(index != len(stackwise_depth) - 1),
name=f"dark{index + 2}_csp",
)(x)
pyramid_outputs[f"P{index + 2}"] = x

super().__init__(inputs=image_input, outputs=x, **kwargs)

Expand All @@ -120,6 +123,7 @@ def __init__(
self.include_rescaling = include_rescaling
self.block_type = block_type
self.image_shape = image_shape
self.pyramid_outputs = pyramid_outputs

def get_config(self):
config = super().get_config()
Expand Down
14 changes: 9 additions & 5 deletions keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,24 @@
class CSPDarkNetBackboneTest(TestCase):
def setUp(self):
self.init_kwargs = {
"stackwise_num_filters": [32, 64, 128, 256],
"stackwise_num_filters": [2, 4, 6, 8],
"stackwise_depth": [1, 3, 3, 1],
"include_rescaling": False,
"block_type": "basic_block",
"image_shape": (224, 224, 3),
"image_shape": (32, 32, 3),
}
self.input_data = np.ones((2, 224, 224, 3), dtype="float32")
self.input_size = 32
self.input_data = np.ones(
(2, self.input_size, self.input_size, 3), dtype="float32"
)

def test_backbone_basics(self):
self.run_backbone_test(
self.run_vision_backbone_test(
cls=CSPDarkNetBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 7, 7, 256),
expected_output_shape=(2, 1, 1, 8),
expected_pyramid_output_keys=["P2", "P3", "P4", "P5"],
run_mixed_precision_check=False,
)

Expand Down
13 changes: 8 additions & 5 deletions keras_nlp/src/models/densenet/densenet_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
import keras

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.backbone import Backbone
from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone

BN_AXIS = 3
BN_EPSILON = 1.001e-5


@keras_nlp_export("keras_nlp.models.DenseNetBackbone")
class DenseNetBackbone(Backbone):
class DenseNetBackbone(FeaturePyramidBackbone):
"""Instantiates the DenseNet architecture.

This class implements a DenseNet backbone as described in
Expand All @@ -35,7 +35,7 @@ class DenseNetBackbone(Backbone):
include_rescaling: bool, whether to rescale the inputs. If set
to `True`, inputs will be passed through a `Rescaling(1/255.0)`
layer. Defaults to `True`.
image_shape: optional shape tuple, defaults to (224, 224, 3).
image_shape: optional shape tuple, defaults to (None, None, 3).
compression_ratio: float, compression rate at transition layers,
defaults to 0.5.
growth_rate: int, number of filters added by each dense block,
Expand All @@ -62,7 +62,7 @@ def __init__(
self,
stackwise_num_repeats,
include_rescaling=True,
image_shape=(224, 224, 3),
image_shape=(None, None, 3),
compression_ratio=0.5,
growth_rate=32,
**kwargs,
Expand All @@ -85,6 +85,7 @@ def __init__(
3, strides=2, padding="same", name="pool1"
)(x)

pyramid_outputs = {}
for stack_index in range(len(stackwise_num_repeats) - 1):
index = stack_index + 2
x = apply_dense_block(
Expand All @@ -93,6 +94,7 @@ def __init__(
growth_rate,
name=f"conv{index}",
)
pyramid_outputs[f"P{index}"] = x
x = apply_transition_block(
x, compression_ratio, name=f"pool{index}"
)
Expand All @@ -103,7 +105,7 @@ def __init__(
growth_rate,
name=f"conv{len(stackwise_num_repeats) + 1}",
)

pyramid_outputs[f"P{len(stackwise_num_repeats) +1}"] = x
x = keras.layers.BatchNormalization(
axis=BN_AXIS, epsilon=BN_EPSILON, name="bn"
)(x)
Expand All @@ -117,6 +119,7 @@ def __init__(
self.compression_ratio = compression_ratio
self.growth_rate = growth_rate
self.image_shape = image_shape
self.pyramid_outputs = pyramid_outputs

def get_config(self):
config = super().get_config()
Expand Down
16 changes: 10 additions & 6 deletions keras_nlp/src/models/densenet/densenet_backbone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,24 @@
class DenseNetBackboneTest(TestCase):
def setUp(self):
self.init_kwargs = {
"stackwise_num_repeats": [6, 12, 24, 16],
"stackwise_num_repeats": [2, 4, 6, 4],
"include_rescaling": True,
"compression_ratio": 0.5,
"growth_rate": 32,
"image_shape": (224, 224, 3),
"growth_rate": 2,
"image_shape": (32, 32, 3),
}
self.input_data = np.ones((2, 224, 224, 3), dtype="float32")
self.input_size = 32
self.input_data = np.ones(
(2, self.input_size, self.input_size, 3), dtype="float32"
)

def test_backbone_basics(self):
self.run_backbone_test(
self.run_vision_backbone_test(
cls=DenseNetBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 7, 7, 1024),
expected_output_shape=(2, 1, 1, 24),
expected_pyramid_output_keys=["P2", "P3", "P4", "P5"],
run_mixed_precision_check=False,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
patch_sizes,
strides,
include_rescaling=True,
image_shape=(224, 224, 3),
image_shape=(None, None, 3),
hidden_dims=None,
**kwargs,
):
Expand All @@ -63,7 +63,7 @@ def __init__(
include_rescaling: bool, whether to rescale the inputs. If set
to `True`, inputs will be passed through a `Rescaling(1/255.0)`
layer. Defaults to `True`.
image_shape: optional shape tuple, defaults to (224, 224, 3).
image_shape: optional shape tuple, defaults to (None, None, 3).
hidden_dims: the embedding dims per hierarchical layer, used as
the levels of the feature pyramid.
patch_sizes: list of integers, the patch_size to apply for each layer.
Expand Down
20 changes: 1 addition & 19 deletions keras_nlp/src/models/resnet/resnet_backbone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import pytest
from absl.testing import parameterized
from keras import models
from keras import ops

from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone
Expand Down Expand Up @@ -51,26 +50,9 @@ def test_backbone_basics(self, use_pre_activation, block_type):
expected_output_shape=(
(2, 64) if block_type == "basic_block" else (2, 256)
),
expected_pyramid_output_keys=["P2", "P3", "P4"],
)

def test_pyramid_output_format(self):
    """Check that the backbone's pyramid outputs form a dict of P2-P4 maps.

    The backbone is built from the shared `init_kwargs` with
    `block_type="basic_block"` and `use_pre_activation=False`.
    """
    overrides = {"block_type": "basic_block", "use_pre_activation": False}
    backbone_kwargs = {**self.init_kwargs, **overrides}
    backbone = ResNetBackbone(**backbone_kwargs)
    pyramid_model = models.Model(backbone.inputs, backbone.pyramid_outputs)
    outputs = pyramid_model(self.input_data)

    self.assertIsInstance(outputs, dict)
    level_names = list(outputs.keys())
    self.assertEqual(level_names, list(backbone.pyramid_outputs.keys()))
    self.assertEqual(level_names, ["P2", "P3", "P4"])

    for level_name, feature_map in outputs.items():
        # A level named "P<n>" is expected at input_size // 2**n per side.
        expected_side = self.input_size // (2 ** int(level_name[1:]))
        self.assertEqual(
            tuple(feature_map.shape[:3]),
            (2, expected_side, expected_side),
        )

@parameterized.named_parameters(
("v1_basic", False, "basic_block"),
("v1_bottleneck", False, "bottleneck_block"),
Expand Down
37 changes: 37 additions & 0 deletions keras_nlp/src/tests/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@ def run_vision_backbone_test(
init_kwargs,
input_data,
expected_output_shape,
expected_pyramid_output_keys=None,
variable_length_data=None,
run_mixed_precision_check=True,
run_quantization_check=True,
Expand All @@ -491,6 +492,13 @@ def run_vision_backbone_test(
run_mixed_precision_check=run_mixed_precision_check,
run_quantization_check=run_quantization_check,
)
if expected_pyramid_output_keys:
self.run_pyramid_output_test(
cls=cls,
init_kwargs=init_kwargs,
input_data=input_data,
expected_pyramid_output_keys=expected_pyramid_output_keys,
)

# Check data_format. We assume that `input_data` is in "channels_last"
# format.
Expand All @@ -515,6 +523,13 @@ def run_vision_backbone_test(
run_mixed_precision_check=run_mixed_precision_check,
run_quantization_check=run_quantization_check,
)
if expected_pyramid_output_keys:
self.run_pyramid_output_test(
cls=cls,
init_kwargs=init_kwargs,
input_data=input_data,
expected_pyramid_output_keys=expected_pyramid_output_keys,
)

# Restore the original `image_data_format`.
keras.config.set_image_data_format(ori_data_format)
Expand Down Expand Up @@ -604,5 +619,27 @@ def compare(actual, expected):

tree.map_structure(compare, output, expected_partial_output)

def run_pyramid_output_test(
    self,
    cls,
    init_kwargs,
    input_data,
    expected_pyramid_output_keys,
):
    """Check the keys and spatial shapes of a backbone's pyramid outputs.

    Builds `cls(**init_kwargs)`, wraps its `pyramid_outputs` in a
    `keras.models.Model`, runs `input_data` through it and asserts that
    the result is a dict whose keys match, in order, both
    `backbone.pyramid_outputs` and `expected_pyramid_output_keys`.
    Each output named `"P<n>"` must then have leading dimensions
    `(batch, height // 2**n, width // 2**n)`.

    Args:
        cls: backbone class to instantiate.
        init_kwargs: dict of constructor arguments passed to `cls`.
        input_data: batched input. Axis 0 is read as the batch size and
            axes 1 and 2 as height and width.
            NOTE(review): this presumes "channels_last" data; confirm
            for callers that switch the image data format.
        expected_pyramid_output_keys: list of expected level names,
            e.g. `["P2", "P3", "P4"]`.
    """
    backbone = cls(**init_kwargs)
    model = keras.models.Model(backbone.inputs, backbone.pyramid_outputs)
    output_data = model(input_data)

    self.assertIsInstance(output_data, dict)
    self.assertEqual(
        list(output_data.keys()), list(backbone.pyramid_outputs.keys())
    )
    self.assertEqual(list(output_data.keys()), expected_pyramid_output_keys)

    # Take the expected dimensions from the input itself rather than
    # assuming a batch of 2 and a square image.
    batch_size, height, width = tuple(input_data.shape)[:3]
    for k, v in output_data.items():
        # The integer after the leading letter is the log2 of the stride.
        stride = 2 ** int(k[1:])
        self.assertEqual(
            tuple(v.shape[:3]),
            (batch_size, height // stride, width // stride),
        )

def get_test_data_dir(self):
    """Return, as a string, the `test_data` directory beside this file."""
    module_dir = pathlib.Path(__file__).parent
    return str(module_dir.joinpath("test_data"))