diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py
index 783cfd5087..aca1e28538 100644
--- a/keras_nlp/api/models/__init__.py
+++ b/keras_nlp/api/models/__init__.py
@@ -50,6 +50,12 @@
 from keras_nlp.src.models.bloom.bloom_tokenizer import BloomTokenizer
 from keras_nlp.src.models.causal_lm import CausalLM
 from keras_nlp.src.models.classifier import Classifier
+from keras_nlp.src.models.csp_darknet.csp_darknet_backbone import (
+    CSPDarkNetBackbone,
+)
+from keras_nlp.src.models.csp_darknet.csp_darknet_image_classifier import (
+    CSPDarkNetImageClassifier,
+)
 from keras_nlp.src.models.deberta_v3.deberta_v3_backbone import (
     DebertaV3Backbone,
 )
diff --git a/keras_nlp/src/models/csp_darknet/__init__.py b/keras_nlp/src/models/csp_darknet/__init__.py
new file mode 100644
index 0000000000..3364a6bd16
--- /dev/null
+++ b/keras_nlp/src/models/csp_darknet/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py
new file mode 100644
index 0000000000..2745f61d01
--- /dev/null
+++ b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py
@@ -0,0 +1,410 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import keras
+from keras import layers
+
+from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.models.backbone import Backbone
+
+
+@keras_nlp_export("keras_nlp.models.CSPDarkNetBackbone")
+class CSPDarkNetBackbone(Backbone):
+    """A Keras Backbone for the CSPDarkNet model.
+
+    This class implements a CSPDarkNet backbone as described in
+    [CSPNet: A New Backbone that can Enhance Learning Capability of CNN](
+    https://arxiv.org/abs/1911.11929).
+
+    Args:
+        stackwise_num_filters: A list of ints, filter size for each dark
+            level in the model.
+        stackwise_depth: A list of ints, the depth for each dark level in the
+            model.
+        include_rescaling: boolean. If `True`, rescale the input using a
+            `Rescaling(1 / 255.0)` layer. If `False`, do nothing. Defaults to
+            `True`.
+        block_type: str. One of `"basic_block"` or `"depthwise_block"`.
+            Use `"depthwise_block"` for a depthwise conv block and
+            `"basic_block"` for a basic conv block.
+            Defaults to `"basic_block"`.
+        input_image_shape: tuple. The input shape without the batch size.
+            Defaults to `(224, 224, 3)`.
+
+    Examples:
+    ```python
+    input_data = np.ones(shape=(8, 224, 224, 3))
+
+    # Pretrained backbone
+    model = keras_nlp.models.CSPDarkNetBackbone.from_preset(
+        "csp_darknet_tiny_imagenet"
+    )
+    model(input_data)
+
+    # Randomly initialized backbone with a custom config
+    model = keras_nlp.models.CSPDarkNetBackbone(
+        stackwise_num_filters=[128, 256, 512, 1024],
+        stackwise_depth=[3, 9, 9, 3],
+        include_rescaling=False,
+    )
+    model(input_data)
+    ```
+    """
+
+    def __init__(
+        self,
+        stackwise_num_filters,
+        stackwise_depth,
+        include_rescaling,
+        block_type="basic_block",
+        input_image_shape=(224, 224, 3),
+        **kwargs,
+    ):
+        # === Functional Model ===
+        apply_ConvBlock = (
+            apply_darknet_conv_block_depthwise
+            if block_type == "depthwise_block"
+            else apply_darknet_conv_block
+        )
+        base_channels = stackwise_num_filters[0] // 2
+
+        image_input = layers.Input(shape=input_image_shape)
+        x = image_input
+        if include_rescaling:
+            x = layers.Rescaling(scale=1 / 255.0)(x)
+
+        x = apply_focus(name="stem_focus")(x)
+        x = apply_darknet_conv_block(
+            base_channels, kernel_size=3, strides=1, name="stem_conv"
+        )(x)
+        for index, (channels, depth) in enumerate(
+            zip(stackwise_num_filters, stackwise_depth)
+        ):
+            x = apply_ConvBlock(
+                channels,
+                kernel_size=3,
+                strides=2,
+                name=f"dark{index + 2}_conv",
+            )(x)
+
+            if index == len(stackwise_depth) - 1:
+                x = apply_spatial_pyramid_pooling_bottleneck(
+                    channels,
+                    hidden_filters=channels // 2,
+                    name=f"dark{index + 2}_spp",
+                )(x)
+
+            x = apply_cross_stage_partial(
+                channels,
+                num_bottlenecks=depth,
+                block_type=block_type,
+                residual=(index != len(stackwise_depth) - 1),
+                name=f"dark{index + 2}_csp",
+            )(x)
+
+        super().__init__(inputs=image_input, outputs=x, **kwargs)
+
+        # === Config ===
+        self.stackwise_num_filters = stackwise_num_filters
+        self.stackwise_depth = stackwise_depth
+        self.include_rescaling = include_rescaling
+        self.block_type = block_type
+        self.input_image_shape = input_image_shape
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "stackwise_num_filters": self.stackwise_num_filters,
+                "stackwise_depth": self.stackwise_depth,
+                "include_rescaling": self.include_rescaling,
+                "block_type": self.block_type,
+                "input_image_shape": self.input_image_shape,
+            }
+        )
+        return config
+
+
+def apply_focus(name=None):
+    """A block used in CSPDarknet to focus information into channels of the
+    image.
+
+    If the dimensions of a batch input are (batch_size, width, height,
+    channels), this layer converts the image into size (batch_size, width/2,
+    height/2, 4*channels). See [the original discussion on YoloV5 Focus Layer](https://github.com/ultralytics/yolov5/discussions/3181).
+
+    Args:
+        name: the name for the lambda layer used in the block.
+
+    Returns:
+        a function that takes an input Tensor representing a Focus layer.
+    """
+
+    def apply(x):
+        return layers.Concatenate(name=name)(
+            [
+                x[..., ::2, ::2, :],
+                x[..., 1::2, ::2, :],
+                x[..., ::2, 1::2, :],
+                x[..., 1::2, 1::2, :],
+            ],
+        )
+
+    return apply
+
+
+def apply_darknet_conv_block(
+    filters, kernel_size, strides, use_bias=False, activation="silu", name=None
+):
+    """
+    The basic conv block used in Darknet. Applies Conv2D followed by a
+    BatchNorm.
+
+    Args:
+        filters: Integer, the dimensionality of the output space (i.e. the
+            number of output filters in the convolution).
+        kernel_size: An integer or tuple/list of 2 integers, specifying the
+            height and width of the 2D convolution window. Can be a single
+            integer to specify the same value for both dimensions.
+        strides: An integer or tuple/list of 2 integers, specifying the strides
+            of the convolution along the height and width. Can be a single
+            integer to specify the same value for both dimensions.
+        use_bias: Boolean, whether the layer uses a bias vector.
+        activation: the activation applied after the BatchNorm layer. One of
+            "silu", "relu" or "leaky_relu", defaults to "silu".
+        name: the prefix for the layer names used in the block.
+    """
+    if name is None:
+        name = f"conv_block{keras.backend.get_uid('conv_block')}"
+
+    def apply(inputs):
+        x = layers.Conv2D(
+            filters,
+            kernel_size,
+            strides,
+            padding="same",
+            use_bias=use_bias,
+            name=name + "_conv",
+        )(inputs)
+
+        x = layers.BatchNormalization(name=name + "_bn")(x)
+
+        if activation == "silu":
+            x = layers.Lambda(lambda x: keras.activations.silu(x))(x)
+        elif activation == "relu":
+            x = layers.ReLU()(x)
+        elif activation == "leaky_relu":
+            x = layers.LeakyReLU(0.1)(x)
+
+        return x
+
+    return apply
+
+
+def apply_darknet_conv_block_depthwise(
+    filters, kernel_size, strides, activation="silu", name=None
+):
+    """
+    The depthwise conv block used in CSPDarknet.
+
+    Args:
+        filters: Integer, the dimensionality of the output space (i.e. the
+            number of output filters in the final convolution).
+        kernel_size: An integer or tuple/list of 2 integers, specifying the
+            height and width of the 2D convolution window. Can be a single
+            integer to specify the same value for both dimensions.
+        strides: An integer or tuple/list of 2 integers, specifying the strides
+            of the convolution along the height and width. Can be a single
+            integer to specify the same value for both dimensions.
+        activation: the activation applied after the final layer. One of "silu",
+            "relu" or "leaky_relu", defaults to "silu".
+        name: the prefix for the layer names used in the block.
+
+    """
+    if name is None:
+        name = f"conv_block{keras.backend.get_uid('conv_block')}"
+
+    def apply(inputs):
+        x = layers.DepthwiseConv2D(
+            kernel_size, strides, padding="same", use_bias=False
+        )(inputs)
+        x = layers.BatchNormalization()(x)
+
+        if activation == "silu":
+            x = layers.Lambda(lambda x: keras.activations.silu(x))(x)
+        elif activation == "relu":
+            x = layers.ReLU()(x)
+        elif activation == "leaky_relu":
+            x = layers.LeakyReLU(0.1)(x)
+
+        x = apply_darknet_conv_block(
+            filters, kernel_size=1, strides=1, activation=activation
+        )(x)
+
+        return x
+
+    return apply
+
+
+def apply_spatial_pyramid_pooling_bottleneck(
+    filters,
+    hidden_filters=None,
+    kernel_sizes=(5, 9, 13),
+    activation="silu",
+    name=None,
+):
+    """
+    Spatial pyramid pooling layer used in YOLOv3-SPP.
+
+    Args:
+        filters: Integer, the dimensionality of the output space (i.e. the
+            number of output filters used in the blocks).
+        hidden_filters: Integer, the dimensionality of the intermediate
+            bottleneck space (i.e. the number of output filters in the
+            bottleneck convolution). If None, it will be equal to filters.
+            Defaults to None.
+        kernel_sizes: A list or tuple representing all the pool sizes used for
+            the pooling layers, defaults to (5, 9, 13).
+        activation: Activation for the conv layers, defaults to "silu".
+        name: the prefix for the layer names used in the block.
+
+    Returns:
+        a function that takes an input Tensor representing a
+        SpatialPyramidPoolingBottleneck.
+ """ + if name is None: + name = f"spp{keras.backend.get_uid('spp')}" + + if hidden_filters is None: + hidden_filters = filters + + def apply(x): + x = apply_darknet_conv_block( + hidden_filters, + kernel_size=1, + strides=1, + activation=activation, + name=f"{name}_conv1", + )(x) + x = [x] + + for kernel_size in kernel_sizes: + x.append( + layers.MaxPooling2D( + kernel_size, + strides=1, + padding="same", + name=f"{name}_maxpool_{kernel_size}", + )(x[0]) + ) + + x = layers.Concatenate(name=f"{name}_concat")(x) + x = apply_darknet_conv_block( + filters, + kernel_size=1, + strides=1, + activation=activation, + name=f"{name}_conv2", + )(x) + + return x + + return apply + + +def apply_cross_stage_partial( + filters, + num_bottlenecks, + residual=True, + block_type="basic_block", + activation="silu", + name=None, +): + """A block used in Cross Stage Partial Darknet. + + Args: + filters: Integer, the dimensionality of the output space (i.e. the + number of output filters in the final convolution). + num_bottlenecks: an integer representing the number of blocks added in + the layer bottleneck. + residual: a boolean representing whether the value tensor before the + bottleneck should be added to the output of the bottleneck as a + residual, defaults to True. + block_type: str. One of `"basic_block"` or `"depthwise_block"`. + Use `"depthwise_block"` for depthwise conv block + `"basic_block"` for basic conv block. + Defaults to "basic_block". + activation: the activation applied after the final layer. One of "silu", + "relu" or "leaky_relu", defaults to "silu". + """ + + if name is None: + name = f"cross_stage_partial_{keras.backend.get_uid('cross_stage_partial')}" + + def apply(inputs): + hidden_channels = filters // 2 + ConvBlock = ( + apply_darknet_conv_block_depthwise + if block_type == "basic_block" + else apply_darknet_conv_block + ) + + x1 = apply_darknet_conv_block( + hidden_channels, + kernel_size=1, + strides=1, + activation=activation, + name=f"{name}_conv1", + )(inputs) + + x2 = apply_darknet_conv_block( + hidden_channels, + kernel_size=1, + strides=1, + activation=activation, + name=f"{name}_conv2", + )(inputs) + + for i in range(num_bottlenecks): + residual_x = x1 + x1 = apply_darknet_conv_block( + hidden_channels, + kernel_size=1, + strides=1, + activation=activation, + name=f"{name}_bottleneck_{i}_conv1", + )(x1) + x1 = ConvBlock( + hidden_channels, + kernel_size=3, + strides=1, + activation=activation, + name=f"{name}_bottleneck_{i}_conv2", + )(x1) + if residual: + x1 = layers.Add(name=f"{name}_bottleneck_{i}_add")( + [residual_x, x1] + ) + + x = layers.Concatenate(name=f"{name}_concat")([x1, x2]) + x = apply_darknet_conv_block( + filters, + kernel_size=1, + strides=1, + activation=activation, + name=f"{name}_conv3", + )(x) + + return x + + return apply diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py new file mode 100644 index 0000000000..aaad4fe515 --- /dev/null +++ b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py @@ -0,0 +1,50 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pytest
+
+from keras_nlp.src.models.csp_darknet.csp_darknet_backbone import (
+    CSPDarkNetBackbone,
+)
+from keras_nlp.src.tests.test_case import TestCase
+
+
+class CSPDarkNetBackboneTest(TestCase):
+    def setUp(self):
+        self.init_kwargs = {
+            "stackwise_num_filters": [32, 64, 128, 256],
+            "stackwise_depth": [1, 3, 3, 1],
+            "include_rescaling": False,
+            "block_type": "basic_block",
+            "input_image_shape": (224, 224, 3),
+        }
+        self.input_data = np.ones((2, 224, 224, 3), dtype="float32")
+
+    def test_backbone_basics(self):
+        self.run_backbone_test(
+            cls=CSPDarkNetBackbone,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+            expected_output_shape=(2, 7, 7, 256),
+            run_mixed_precision_check=False,
+        )
+
+    @pytest.mark.large
+    def test_saved_model(self):
+        self.run_model_saving_test(
+            cls=CSPDarkNetBackbone,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+        )
diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier.py b/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier.py
new file mode 100644
index 0000000000..6b013bdcc0
--- /dev/null
+++ b/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier.py
@@ -0,0 +1,133 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import keras
+
+from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.models.csp_darknet.csp_darknet_backbone import (
+    CSPDarkNetBackbone,
+)
+from keras_nlp.src.models.image_classifier import ImageClassifier
+
+
+@keras_nlp_export("keras_nlp.models.CSPDarkNetImageClassifier")
+class CSPDarkNetImageClassifier(ImageClassifier):
+    """CSPDarkNet image classifier task model.
+
+    Args:
+        backbone: A `keras_nlp.models.CSPDarkNetBackbone` instance.
+        num_classes: int. The number of classes to predict.
+        activation: `None`, str or callable. The activation function to use on
+            the `Dense` layer. Set `activation=None` to return the output
+            logits. Defaults to `"softmax"`.
+
+    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
+    where `x` is a tensor and `y` is an integer from `[0, num_classes)`.
+    All `ImageClassifier` tasks include a `from_preset()` constructor which can
+    be used to load a pre-trained config and weights.
+
+    Examples:
+
+    Call `predict()` to run inference.
+    ```python
+    # Load preset and run inference
+    images = np.ones((2, 224, 224, 3), dtype="float32")
+    classifier = keras_nlp.models.CSPDarkNetImageClassifier.from_preset(
+        "csp_darknet_tiny_imagenet")
+    classifier.predict(images)
+    ```
+
+    Call `fit()` on a single batch.
+    ```python
+    # Load preset and train
+    images = np.ones((2, 224, 224, 3), dtype="float32")
+    labels = [0, 3]
+    classifier = keras_nlp.models.CSPDarkNetImageClassifier.from_preset(
+        "csp_darknet_tiny_imagenet")
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+
+    Call `fit()` with custom loss, optimizer and backbone.
+    ```python
+    classifier = keras_nlp.models.CSPDarkNetImageClassifier.from_preset(
+        "csp_darknet_tiny_imagenet")
+    classifier.compile(
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        optimizer=keras.optimizers.Adam(5e-5),
+    )
+    classifier.backbone.trainable = False
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+
+    Custom backbone.
+    ```python
+    images = np.ones((2, 224, 224, 3), dtype="float32")
+    labels = [0, 3]
+    backbone = keras_nlp.models.CSPDarkNetBackbone(
+        stackwise_num_filters=[128, 256, 512, 1024],
+        stackwise_depth=[3, 9, 9, 3],
+        include_rescaling=False,
+        block_type="basic_block",
+        input_image_shape=(224, 224, 3),
+    )
+    classifier = keras_nlp.models.CSPDarkNetImageClassifier(
+        backbone=backbone,
+        num_classes=4,
+    )
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+    """
+
+    backbone_cls = CSPDarkNetBackbone
+
+    def __init__(
+        self,
+        backbone,
+        num_classes,
+        activation="softmax",
+        preprocessor=None,  # adding this dummy arg for saved model test
+        # TODO: once preprocessor flow is figured out, this needs to be updated
+        **kwargs,
+    ):
+        # === Layers ===
+        self.backbone = backbone
+        self.output_dense = keras.layers.Dense(
+            num_classes,
+            activation=activation,
+            name="predictions",
+        )
+
+        # === Functional Model ===
+        inputs = self.backbone.input
+        x = self.backbone(inputs)
+        outputs = self.output_dense(x)
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.num_classes = num_classes
+        self.activation = activation
+
+    def get_config(self):
+        # Backbone serialized in `super`
+        config = super().get_config()
+        config.update(
+            {
+                "num_classes": self.num_classes,
+                "activation": self.activation,
+            }
+        )
+        return config
diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier_test.py b/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier_test.py
new file mode 100644
index 0000000000..a07bb017a3
--- /dev/null
+++ b/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier_test.py
@@ -0,0 +1,65 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import pytest
+
+from keras_nlp.src.models.csp_darknet.csp_darknet_backbone import (
+    CSPDarkNetBackbone,
+)
+from keras_nlp.src.models.csp_darknet.csp_darknet_image_classifier import (
+    CSPDarkNetImageClassifier,
+)
+from keras_nlp.src.tests.test_case import TestCase
+
+
+class CSPDarkNetImageClassifierTest(TestCase):
+    def setUp(self):
+        # Setup model.
+        self.images = np.ones((2, 16, 16, 3), dtype="float32")
+        self.labels = [0, 3]
+        self.backbone = CSPDarkNetBackbone(
+            stackwise_num_filters=[2, 16, 16],
+            stackwise_depth=[1, 3, 3],
+            include_rescaling=False,
+            block_type="basic_block",
+            input_image_shape=(16, 16, 3),
+        )
+        self.init_kwargs = {
+            "backbone": self.backbone,
+            "num_classes": 2,
+            "activation": "softmax",
+        }
+        self.train_data = (
+            self.images,
+            self.labels,
+        )
+
+    def test_classifier_basics(self):
+        pytest.skip(
+            reason="TODO: enable after preprocessor flow is figured out"
+        )
+        self.run_task_test(
+            cls=CSPDarkNetImageClassifier,
+            init_kwargs=self.init_kwargs,
+            train_data=self.train_data,
+            expected_output_shape=(2, 2),
+        )
+
+    @pytest.mark.large
+    def test_saved_model(self):
+        self.run_model_saving_test(
+            cls=CSPDarkNetImageClassifier,
+            init_kwargs=self.init_kwargs,
+            input_data=self.images,
+        )