diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py
index 6b85148caf..0fe7b300fa 100644
--- a/keras_hub/api/layers/__init__.py
+++ b/keras_hub/api/layers/__init__.py
@@ -34,6 +34,9 @@
 from keras_hub.src.layers.preprocessing.random_deletion import RandomDeletion
 from keras_hub.src.layers.preprocessing.random_swap import RandomSwap
 from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import (
+    DeepLabV3ImageConverter,
+)
 from keras_hub.src.models.densenet.densenet_image_converter import (
     DenseNetImageConverter,
 )
diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py
index 371277465a..1450ddceb3 100644
--- a/keras_hub/api/models/__init__.py
+++ b/keras_hub/api/models/__init__.py
@@ -85,6 +85,15 @@
 from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import (
     DebertaV3Tokenizer,
 )
+from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
+    DeepLabV3Backbone,
+)
+from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import (
+    DeepLabV3ImageSegmenterPreprocessor,
+)
+from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import (
+    DeepLabV3ImageSegmenter,
+)
 from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone
 from keras_hub.src.models.densenet.densenet_image_classifier import (
     DenseNetImageClassifier,
diff --git a/keras_hub/src/models/deeplab_v3/__init__.py b/keras_hub/src/models/deeplab_v3/__init__.py
new file mode 100644
index 0000000000..0a959e1861
--- /dev/null
+++ b/keras_hub/src/models/deeplab_v3/__init__.py
@@ -0,0 +1,7 @@
+from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
+    DeepLabV3Backbone,
+)
+from keras_hub.src.models.deeplab_v3.deeplab_v3_presets import backbone_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, DeepLabV3Backbone)
diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py
new file mode 100644
index 0000000000..70bf828b01
--- /dev/null
+++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py
@@ -0,0 +1,196 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.deeplab_v3.deeplab_v3_layers import (
+    SpatialPyramidPooling,
+)
+
+
+@keras_hub_export("keras_hub.models.DeepLabV3Backbone")
+class DeepLabV3Backbone(Backbone):
+    """DeepLabV3 and DeepLabV3Plus architecture for semantic segmentation.
+
+    This class implements the DeepLabV3 and DeepLabV3Plus architectures as
+    described in [Encoder-Decoder with Atrous Separable Convolution for
+    Semantic Image Segmentation](https://arxiv.org/abs/1802.02611) (ECCV 2018)
+    and [Rethinking Atrous Convolution for Semantic Image Segmentation](
+    https://arxiv.org/abs/1706.05587) (2017).
+
+    Args:
+        image_encoder: `keras.Model`. An instance used as a feature extractor
+            for the encoder. Should be either a `keras_hub.models.Backbone` or
+            a `keras.Model` that implements the `pyramid_outputs` property
+            with keys such as "P2", "P3", etc. A sensible backbone to use in
+            many cases is
+            `keras_hub.models.ResNetBackbone.from_preset("resnet_v2_50")`.
+        projection_filters: int. Number of filters in the convolution layer
+            projecting low-level features from the `image_encoder`.
+        spatial_pyramid_pooling_key: str. The level at which to extract
+            features and apply `spatial_pyramid_pooling`; one of the keys of
+            the `image_encoder`'s `pyramid_outputs` property, such as "P4",
+            "P5", etc.
+        upsampling_size: int or tuple of 2 integers. The upsampling factors
+            for rows and columns of the `spatial_pyramid_pooling` layer. If
+            `low_level_feature_key` is given, the resolution of the
+            `spatial_pyramid_pooling` layer should match the resolution of
+            the `low_level_feature` layer, so that both layers can be
+            concatenated for the combined encoder outputs.
+        dilation_rates: list. A `list` of integers for the parallel dilated
+            convolutions applied in `SpatialPyramidPooling`. A common choice
+            of rates is `[6, 12, 18]`.
+        low_level_feature_key: str, optional. The level at which to extract
+            features for the decoder block; one of the keys of the
+            `image_encoder`'s `pyramid_outputs` property, such as "P2", "P3",
+            etc. Required only when the DeepLabV3Plus architecture should be
+            applied.
+        image_shape: tuple. The input shape without the batch size.
+            Defaults to `(None, None, 3)`.
+
+    Example:
+    ```python
+    # Load a trained backbone to extract features from its `pyramid_outputs`.
+    image_encoder = keras_hub.models.ResNetBackbone.from_preset(
+        "resnet_50_imagenet"
+    )
+
+    model = keras_hub.models.DeepLabV3Backbone(
+        image_encoder=image_encoder,
+        projection_filters=48,
+        low_level_feature_key="P2",
+        spatial_pyramid_pooling_key="P5",
+        upsampling_size=8,
+        dilation_rates=[6, 12, 18],
+    )
+    ```
+    """
+
+    def __init__(
+        self,
+        image_encoder,
+        spatial_pyramid_pooling_key,
+        upsampling_size,
+        dilation_rates,
+        low_level_feature_key=None,
+        projection_filters=48,
+        image_shape=(None, None, 3),
+        **kwargs,
+    ):
+        if not isinstance(image_encoder, keras.Model):
+            raise ValueError(
+                "Argument `image_encoder` must be a `keras.Model` instance. "
+                f"Received {image_encoder} (of type {type(image_encoder)})."
+            )
+        data_format = keras.config.image_data_format()
+        channel_axis = -1 if data_format == "channels_last" else 1
+
+        # === Layers ===
+        inputs = keras.layers.Input(image_shape, name="inputs")
+
+        fpn_model = keras.Model(
+            image_encoder.inputs, image_encoder.pyramid_outputs
+        )
+
+        fpn_outputs = fpn_model(inputs)
+
+        spatial_pyramid_pooling = SpatialPyramidPooling(
+            dilation_rates=dilation_rates
+        )
+        spatial_backbone_features = fpn_outputs[spatial_pyramid_pooling_key]
+        spp_outputs = spatial_pyramid_pooling(spatial_backbone_features)
+
+        encoder_outputs = keras.layers.UpSampling2D(
+            size=upsampling_size,
+            interpolation="bilinear",
+            name="encoder_output_upsampling",
+            data_format=data_format,
+        )(spp_outputs)
+
+        if low_level_feature_key:
+            decoder_feature = fpn_outputs[low_level_feature_key]
+            low_level_projected_features = apply_low_level_feature_network(
+                decoder_feature, projection_filters, channel_axis
+            )
+
+            encoder_outputs = keras.layers.Concatenate(
+                axis=channel_axis, name="encoder_decoder_concat"
+            )([encoder_outputs, low_level_projected_features])
+        # Upsampling to the original image size.
+        upsampling = (2 ** int(spatial_pyramid_pooling_key[-1])) // (
+            int(upsampling_size[0])
+            if isinstance(upsampling_size, tuple)
+            else upsampling_size
+        )
+        # === Functional Model ===
+        x = keras.layers.Conv2D(
+            name="segmentation_head_conv",
+            filters=256,
+            kernel_size=1,
+            padding="same",
+            use_bias=False,
+            data_format=data_format,
+        )(encoder_outputs)
+        x = keras.layers.BatchNormalization(
+            name="segmentation_head_norm", axis=channel_axis
+        )(x)
+        x = keras.layers.ReLU(name="segmentation_head_relu")(x)
+        x = keras.layers.UpSampling2D(
+            size=upsampling,
+            interpolation="bilinear",
+            data_format=data_format,
+            name="backbone_output_upsampling",
+        )(x)
+
+        super().__init__(inputs=inputs, outputs=x, **kwargs)
+
+        # === Config ===
+        self.image_shape = image_shape
+        self.image_encoder = image_encoder
+        self.projection_filters = projection_filters
+        self.upsampling_size = upsampling_size
+        self.dilation_rates = dilation_rates
+        self.low_level_feature_key = low_level_feature_key
+        self.spatial_pyramid_pooling_key = spatial_pyramid_pooling_key
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "image_encoder": keras.saving.serialize_keras_object(
+                    self.image_encoder
+                ),
+                "projection_filters": self.projection_filters,
+                "dilation_rates": self.dilation_rates,
+                "upsampling_size": self.upsampling_size,
+                "low_level_feature_key": self.low_level_feature_key,
+                "spatial_pyramid_pooling_key": self.spatial_pyramid_pooling_key,
+                "image_shape": self.image_shape,
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        if "image_encoder" in config and isinstance(
+            config["image_encoder"], dict
+        ):
+            config["image_encoder"] = keras.layers.deserialize(
+                config["image_encoder"]
+            )
+        return super().from_config(config)
+
+
+def apply_low_level_feature_network(
+    input_tensor, projection_filters, channel_axis
+):
+    data_format = keras.config.image_data_format()
+    x = keras.layers.Conv2D(
+        name="decoder_conv",
+        filters=projection_filters,
+        kernel_size=1,
+        padding="same",
+        use_bias=False,
+        data_format=data_format,
+    )(input_tensor)
+
+    x = keras.layers.BatchNormalization(name="decoder_norm", axis=channel_axis)(
+        x
+    )
+    x = keras.layers.ReLU(name="decoder_relu")(x)
+    return x
diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py
new file mode 100644
index 0000000000..a7b1809085
--- /dev/null
+++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py
@@ -0,0 +1,73 @@
+import keras
+import numpy as np
+import pytest
+
+from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
+    DeepLabV3Backbone,
+)
+from keras_hub.src.models.deeplab_v3.deeplab_v3_layers import (
+    SpatialPyramidPooling,
+)
+from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
+from keras_hub.src.tests.test_case import TestCase
+
+
+class DeepLabV3Test(TestCase):
+    def setUp(self):
+        self.resnet_kwargs = {
+            "input_conv_filters": [64],
+            "input_conv_kernel_sizes": [7],
+            "stackwise_num_filters": [64, 64, 64],
+            "stackwise_num_blocks": [2, 2, 2],
+            "stackwise_num_strides": [1, 2, 2],
+            "block_type": "basic_block",
+            "use_pre_activation": False,
+        }
+        self.image_encoder = ResNetBackbone(**self.resnet_kwargs)
+        self.init_kwargs = {
+            "image_encoder": self.image_encoder,
+            "low_level_feature_key": "P2",
+            "spatial_pyramid_pooling_key": "P4",
+            "dilation_rates": [6, 12, 18],
+            "upsampling_size": 4,
+            "image_shape": (96, 96, 3),
+        }
+        self.input_data = np.ones((2, 96, 96, 3), dtype="float32")
+
+    def test_segmentation_basics(self):
+        self.run_vision_backbone_test(
+            cls=DeepLabV3Backbone,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+            expected_output_shape=(2, 96, 96, 256),
+            run_mixed_precision_check=False,
+            run_quantization_check=False,
+            run_data_format_check=False,
+        )
+
+    @pytest.mark.large
+    def test_saved_model(self):
+        self.run_model_saving_test(
+            cls=DeepLabV3Backbone,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+        )
+
+
+class SpatialPyramidPoolingTest(TestCase):
+    def test_layer_behaviors(self):
+        self.run_layer_test(
+            cls=SpatialPyramidPooling,
+            init_kwargs={
+                "dilation_rates": [6, 12, 18],
+                "activation": "relu",
+                "num_channels": 256,
+                "dropout": 0.1,
+            },
+            input_data=keras.random.uniform(shape=(1, 4, 4, 6)),
+            expected_output_shape=(1, 4, 4, 256),
+            expected_num_trainable_weights=18,
+            expected_num_non_trainable_variables=13,
+            expected_num_non_trainable_weights=12,
+            run_precision_checks=False,
+        )
diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py
new file mode 100644
index 0000000000..5cb0960c83
--- /dev/null
+++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py
@@ -0,0 +1,10 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
+    DeepLabV3Backbone,
+)
+
+
+@keras_hub_export("keras_hub.layers.DeepLabV3ImageConverter")
+class DeepLabV3ImageConverter(ImageConverter):
+    backbone_cls = DeepLabV3Backbone
diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py
new file mode 100644
index 0000000000..c7e4738e37
--- /dev/null
+++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py
@@ -0,0 +1,16 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
+    DeepLabV3Backbone,
+)
+from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import (
+    DeepLabV3ImageConverter,
+)
+from keras_hub.src.models.image_segmenter_preprocessor import (
+    ImageSegmenterPreprocessor,
+)
+
+
+@keras_hub_export("keras_hub.models.DeepLabV3ImageSegmenterPreprocessor")
+class DeepLabV3ImageSegmenterPreprocessor(ImageSegmenterPreprocessor):
+    backbone_cls = DeepLabV3Backbone
+    image_converter_cls = DeepLabV3ImageConverter
diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py
new file mode 100644
index 0000000000..837e508d2c
--- /dev/null
+++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py
@@ -0,0 +1,215 @@
+import keras
+from keras import ops
+
+
+class SpatialPyramidPooling(keras.layers.Layer):
+    """Implements the Atrous Spatial Pyramid Pooling layer.
+
+    References for Atrous Spatial Pyramid Pooling: [Rethinking Atrous
+    Convolution for Semantic Image Segmentation](
+    https://arxiv.org/pdf/1706.05587.pdf) and [Encoder-Decoder with Atrous
+    Separable Convolution for Semantic Image Segmentation](
+    https://arxiv.org/pdf/1802.02611.pdf)
+
+    Args:
+        dilation_rates: list of ints. The dilation rates for the parallel
+            dilated convolutions. A common choice of rates is `[6, 12, 18]`.
+        num_channels: int. The number of output channels, defaults to `256`.
+        activation: str. Activation to be used, defaults to `"relu"`.
+        dropout: float. The dropout rate of the final projection output after
+            the activations and batch norm, defaults to `0.0`, which means no
+            dropout is applied to the output.
+
+    Example:
+    ```python
+    inp = keras.layers.Input((384, 384, 3))
+    backbone = keras.applications.EfficientNetB0(
+        input_tensor=inp,
+        include_top=False)
+    output = backbone(inp)
+    output = SpatialPyramidPooling(
+        dilation_rates=[6, 12, 18])(output)
+    ```
+    """
+
+    def __init__(
+        self,
+        dilation_rates,
+        num_channels=256,
+        activation="relu",
+        dropout=0.0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.dilation_rates = dilation_rates
+        self.num_channels = num_channels
+        self.activation = activation
+        self.dropout = dropout
+        self.data_format = keras.config.image_data_format()
+        self.channel_axis = -1 if self.data_format == "channels_last" else 1
+
+    def build(self, input_shape):
+        channels = input_shape[self.channel_axis]
+
+        # These are the parallel networks that process the input features
+        # with different dilation rates. The outputs from all channels are
+        # merged together and fed to the final projection.
+        self.aspp_parallel_channels = []
+
+        # Channel 1 uses a Conv2D with a 1x1 kernel.
+        conv_sequential = keras.Sequential(
+            [
+                keras.layers.Conv2D(
+                    filters=self.num_channels,
+                    kernel_size=(1, 1),
+                    use_bias=False,
+                    data_format=self.data_format,
+                    name="aspp_conv_1",
+                ),
+                keras.layers.BatchNormalization(
+                    axis=self.channel_axis, name="aspp_bn_1"
+                ),
+                keras.layers.Activation(
+                    self.activation, name="aspp_activation_1"
+                ),
+            ]
+        )
+        conv_sequential.build(input_shape)
+        self.aspp_parallel_channels.append(conv_sequential)
+
+        # Channels 2 and onward are based on self.dilation_rates, and each
+        # of them uses a Conv2D with a 3x3 kernel.
+        for i, dilation_rate in enumerate(self.dilation_rates):
+            conv_sequential = keras.Sequential(
+                [
+                    keras.layers.Conv2D(
+                        filters=self.num_channels,
+                        kernel_size=(3, 3),
+                        padding="same",
+                        dilation_rate=dilation_rate,
+                        use_bias=False,
+                        data_format=self.data_format,
+                        name=f"aspp_conv_{i+2}",
+                    ),
+                    keras.layers.BatchNormalization(
+                        axis=self.channel_axis, name=f"aspp_bn_{i+2}"
+                    ),
+                    keras.layers.Activation(
+                        self.activation, name=f"aspp_activation_{i+2}"
+                    ),
+                ]
+            )
+            conv_sequential.build(input_shape)
+            self.aspp_parallel_channels.append(conv_sequential)
+
+        # The last channel is global average pooling followed by a Conv2D
+        # with a 1x1 kernel.
+        if self.channel_axis == -1:
+            reshape = keras.layers.Reshape((1, 1, channels), name="reshape")
+        else:
+            reshape = keras.layers.Reshape((channels, 1, 1), name="reshape")
+        pool_sequential = keras.Sequential(
+            [
+                keras.layers.GlobalAveragePooling2D(
+                    data_format=self.data_format, name="average_pooling"
+                ),
+                reshape,
+                keras.layers.Conv2D(
+                    filters=self.num_channels,
+                    kernel_size=(1, 1),
+                    use_bias=False,
+                    data_format=self.data_format,
+                    name="conv_pooling",
+                ),
+                keras.layers.BatchNormalization(
+                    axis=self.channel_axis, name="bn_pooling"
+                ),
+                keras.layers.Activation(
+                    self.activation, name="activation_pooling"
+                ),
+            ]
+        )
+        pool_sequential.build(input_shape)
+        self.aspp_parallel_channels.append(pool_sequential)
+
+        # Final projection layers.
+        projection = keras.Sequential(
+            [
+                keras.layers.Conv2D(
+                    filters=self.num_channels,
+                    kernel_size=(1, 1),
+                    use_bias=False,
+                    data_format=self.data_format,
+                    name="conv_projection",
+                ),
+                keras.layers.BatchNormalization(
+                    axis=self.channel_axis, name="bn_projection"
+                ),
+                keras.layers.Activation(
+                    self.activation, name="activation_projection"
+                ),
+                keras.layers.Dropout(rate=self.dropout, name="dropout"),
+            ],
+        )
+        projection_input_channels = (
+            2 + len(self.dilation_rates)
+        ) * self.num_channels
+        if self.data_format == "channels_first":
+            projection.build(
+                (input_shape[0],)
+                + (projection_input_channels,)
+                + (input_shape[2:])
+            )
+        else:
+            projection.build((input_shape[:-1]) + (projection_input_channels,))
+        self.projection = projection
+        self.built = True
+
+    def call(self, inputs):
+        """Calls the Atrous Spatial Pyramid Pooling layer on an input.
+
+        Args:
+            inputs: A tensor of shape [batch, height, width, channels]
+
+        Returns:
+            A tensor of shape [batch, height, width, num_channels]
+        """
+        result = []
+
+        for channel in self.aspp_parallel_channels:
+            temp = ops.cast(channel(inputs), inputs.dtype)
+            result.append(temp)
+
+        image_shape = ops.shape(inputs)
+        if self.channel_axis == -1:
+            height, width = image_shape[1], image_shape[2]
+        else:
+            height, width = image_shape[2], image_shape[3]
+        result[-1] = keras.layers.Resizing(
+            height,
+            width,
+            interpolation="bilinear",
+            data_format=self.data_format,
+            name="resizing",
+        )(result[-1])
+
+        result = ops.concatenate(result, axis=self.channel_axis)
+        return self.projection(result)
+
+    def compute_output_shape(self, inputs_shape):
+        if self.data_format == "channels_first":
+            return tuple(
+                (inputs_shape[0],) + (self.num_channels,) + (inputs_shape[2:])
+            )
+        else:
+            return tuple((inputs_shape[:-1]) + (self.num_channels,))
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "dilation_rates": self.dilation_rates,
+                "num_channels": self.num_channels,
+                "activation": self.activation,
+                "dropout": self.dropout,
+            }
+        )
+        return config
diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py
new file mode 100644
index 0000000000..1b1dde181d
--- /dev/null
+++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py
@@ -0,0 +1,4 @@
+"""DeepLabV3 preset configurations."""
+
+# TODO: https://github.com/keras-team/keras-hub/issues/1896
+backbone_presets = {}
diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py
new file mode 100644
index 0000000000..7f0f71718c
--- /dev/null
+++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py
@@ -0,0 +1,109 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
+    DeepLabV3Backbone,
+)
+from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import (
+    DeepLabV3ImageSegmenterPreprocessor,
+)
+from keras_hub.src.models.image_segmenter import ImageSegmenter
+
+
+@keras_hub_export("keras_hub.models.DeepLabV3ImageSegmenter")
+class DeepLabV3ImageSegmenter(ImageSegmenter):
+    """DeepLabV3 and DeepLabV3Plus segmentation task.
+
+    Args:
+        backbone: A `keras_hub.models.DeepLabV3Backbone` instance.
+        num_classes: int. The number of classes for the segmentation model.
+            Note that `num_classes` includes the background class, and the
+            classes in the data should be represented by integers in the
+            range `[0, num_classes)`.
+        activation: str or callable. The activation function to use on
+            the output `Conv2D` layer. Set `activation=None` to return the
+            output logits. Defaults to `None`.
+        preprocessor: A `keras_hub.models.DeepLabV3ImageSegmenterPreprocessor`
+            or `None`. If `None`, this model will not apply preprocessing,
+            and inputs should be preprocessed before calling the model.
+
+    Example:
+    Load a DeepLabV3 preset with the full 21-class, pretrained segmentation
+    head.
+    ```python
+    images = np.ones(shape=(1, 96, 96, 3))
+    labels = np.zeros(shape=(1, 96, 96, 1))
+    segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(
+        "deeplabv3_resnet50_pascalvoc",
+    )
+    segmenter.predict(images)
+    ```
+
+    Specify `num_classes` to load a randomly initialized segmentation head.
+    ```python
+    segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(
+        "deeplabv3_resnet50_pascalvoc",
+        num_classes=2,
+    )
+    segmenter.fit(images, labels, epochs=3)
+    segmenter.predict(images)  # Trained 2-class segmentation.
+    ```
+
+    Load a DeepLabV3Plus preset, an extension of DeepLabV3 that adds a simple
+    yet effective decoder module to refine the segmentation results,
+    especially along object boundaries.
+    ```python
+    segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(
+        "deeplabv3_plus_resnet50_pascalvoc",
+    )
+    segmenter.predict(images)
+    ```
+    """
+
+    backbone_cls = DeepLabV3Backbone
+    preprocessor_cls = DeepLabV3ImageSegmenterPreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        num_classes,
+        activation=None,
+        preprocessor=None,
+        **kwargs,
+    ):
+        data_format = keras.config.image_data_format()
+        # === Layers ===
+        self.output_conv = keras.layers.Conv2D(
+            name="segmentation_output",
+            filters=num_classes,
+            kernel_size=1,
+            use_bias=False,
+            padding="same",
+            activation=activation,
+            data_format=data_format,
+        )
+
+        # === Functional Model ===
+        inputs = backbone.input
+        x = backbone(inputs)
+        outputs = self.output_conv(x)
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.backbone = backbone
+        self.num_classes = num_classes
+        self.activation = activation
+        self.preprocessor = preprocessor
+
+    def get_config(self):
+        # The backbone is serialized in `super().get_config()`.
+        config = super().get_config()
+        config.update(
+            {
+                "num_classes": self.num_classes,
+                "activation": self.activation,
+            }
+        )
+        return config
diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter_test.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter_test.py
new file mode 100644
index 0000000000..d8285a0e44
--- /dev/null
+++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter_test.py
@@ -0,0 +1,72 @@
+import numpy as np
+import pytest
+
+from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
+    DeepLabV3Backbone,
+)
+from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import (
+    DeepLabV3ImageConverter,
+)
+from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import (
+    DeepLabV3ImageSegmenterPreprocessor,
+)
+from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import (
+    DeepLabV3ImageSegmenter,
+)
+from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
+from keras_hub.src.tests.test_case import TestCase
+
+
+class DeepLabV3ImageSegmenterTest(TestCase):
+    def setUp(self):
+        self.resnet_kwargs = {
+            "input_conv_filters": [64],
+            "input_conv_kernel_sizes": [7],
+            "stackwise_num_filters": [64, 64, 64],
+            "stackwise_num_blocks": [2, 2, 2],
+            "stackwise_num_strides": [1, 2, 2],
+            "block_type": "basic_block",
+            "use_pre_activation": False,
+        }
+        self.image_encoder = ResNetBackbone(**self.resnet_kwargs)
+        self.deeplab_backbone = DeepLabV3Backbone(
+            image_encoder=self.image_encoder,
+            low_level_feature_key="P2",
+            spatial_pyramid_pooling_key="P4",
+            dilation_rates=[6, 12, 18],
+            upsampling_size=4,
+        )
+        image_converter = DeepLabV3ImageConverter(image_size=(16, 16))
+        self.preprocessor = DeepLabV3ImageSegmenterPreprocessor(
+            image_converter=image_converter,
+            resize_output_mask=True,
+        )
+        self.init_kwargs = {
+            "backbone": self.deeplab_backbone,
+            "num_classes": 2,
+            "activation": "softmax",
+            "preprocessor": self.preprocessor,
+        }
+        self.images = np.ones((2, 96, 96, 3), dtype="float32")
+        self.labels = np.zeros((2, 96, 96, 2), dtype="float32")
+        self.train_data = (
+            self.images,
+            self.labels,
+        )
+
+    def test_classifier_basics(self):
+        self.run_task_test(
+            cls=DeepLabV3ImageSegmenter,
+            init_kwargs=self.init_kwargs,
+            train_data=self.train_data,
+            batch_size=2,
+            expected_output_shape=(2, 16, 16, 2),
+        )
+
+    @pytest.mark.large
+    def test_saved_model(self):
+        self.run_model_saving_test(
+            cls=DeepLabV3ImageSegmenter,
+            init_kwargs=self.init_kwargs,
+            input_data=self.images,
+        )
diff --git a/keras_hub/src/models/image_segmenter_preprocessor.py b/keras_hub/src/models/image_segmenter_preprocessor.py
index 5b1f6be1e2..e3f29bd5b9 100644
--- a/keras_hub/src/models/image_segmenter_preprocessor.py
+++ b/keras_hub/src/models/image_segmenter_preprocessor.py
@@ -19,9 +19,11 @@ class ImageSegmenterPreprocessor(Preprocessor):
     - `x`: The first input, should always be included. It can be an image or
       a batch of images.
-    - `y`: (Optional) Usually the segmentation mask(s), will be passed through
-      unaltered.
+    - `y`: (Optional) Usually the segmentation mask(s). If `resize_output_mask`
+      is set to `True`, this will be resized to the input image shape;
+      otherwise it will be passed through unaltered.
     - `sample_weight`: (Optional) Will be passed through unaltered.
+    - `resize_output_mask`: bool. If set to `True`, the output mask will be
+      resized to the same size as the input image. Defaults to `False`.

     The layer will output either `x`, an `(x, y)` tuple if labels were
     provided, or an `(x, y, sample_weight)` tuple if labels and sample weight
     were
@@ -29,7 +31,7 @@ class ImageSegmenterPreprocessor(Preprocessor):
     been applied.

     All `ImageSegmenterPreprocessor` tasks include a `from_preset()`
-    constructor which can be used to load a pre-trained config and vocabularies.
+    constructor which can be used to load a pre-trained config.
     You can call the `from_preset()` constructor directly on this base class,
     in which case the correct class for your model will be automatically
     instantiated.
@@ -49,7 +51,8 @@ class ImageSegmenterPreprocessor(Preprocessor):
     x, y = preprocessor(x, y)

     # Resize a batch of images and masks.
-    x, y = [np.ones((512, 512, 3)), np.zeros((512, 512, 3))], [np.ones((512, 512, 1)), np.zeros((512, 512, 1))]
+    x, y = (
+        [np.ones((512, 512, 3)), np.zeros((512, 512, 3))],
+        [np.ones((512, 512, 1)), np.zeros((512, 512, 1))],
+    )
     x, y = preprocessor(x, y)

     # Use a `tf.data.Dataset`.
@@ -61,13 +64,35 @@ class ImageSegmenterPreprocessor(Preprocessor):
     def __init__(
         self,
         image_converter=None,
+        resize_output_mask=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.image_converter = image_converter
+        self.resize_output_mask = resize_output_mask

     @preprocessing_function
     def call(self, x, y=None, sample_weight=None):
         if self.image_converter:
             x = self.image_converter(x)
+
+        if y is not None and self.image_converter and self.resize_output_mask:
+            y = keras.layers.Resizing(
+                height=(
+                    self.image_converter.image_size[0]
+                    if self.image_converter.image_size
+                    else None
+                ),
+                width=(
+                    self.image_converter.image_size[1]
+                    if self.image_converter.image_size
+                    else None
+                ),
+                crop_to_aspect_ratio=self.image_converter.crop_to_aspect_ratio,
+                interpolation="nearest",
+                data_format=self.image_converter.data_format,
+                dtype=self.dtype_policy,
+                name="mask_resizing",
+            )(y)
         return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
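Reviewer note: a minimal end-to-end sketch of how the pieces in this PR fit together, mirroring the unit-test configuration above. No preset is used, since `backbone_presets` is still empty (see the TODO in `deeplab_v3_presets.py`); the small ResNet encoder configuration is taken directly from the tests and is illustrative rather than a recommended setup.

```python
import numpy as np

import keras_hub

# A small ResNet encoder, mirroring the configuration used in the unit tests.
image_encoder = keras_hub.models.ResNetBackbone(
    input_conv_filters=[64],
    input_conv_kernel_sizes=[7],
    stackwise_num_filters=[64, 64, 64],
    stackwise_num_blocks=[2, 2, 2],
    stackwise_num_strides=[1, 2, 2],
    block_type="basic_block",
    use_pre_activation=False,
)

# Setting `low_level_feature_key` enables the DeepLabV3Plus decoder branch;
# omit it for plain DeepLabV3.
backbone = keras_hub.models.DeepLabV3Backbone(
    image_encoder=image_encoder,
    low_level_feature_key="P2",
    spatial_pyramid_pooling_key="P4",
    dilation_rates=[6, 12, 18],
    upsampling_size=4,
    image_shape=(96, 96, 3),
)

# No preprocessor attached here, so inputs go straight to the backbone.
segmenter = keras_hub.models.DeepLabV3ImageSegmenter(
    backbone=backbone,
    num_classes=2,
    activation="softmax",
)
masks = segmenter.predict(np.ones((2, 96, 96, 3), dtype="float32"))
print(masks.shape)  # Expected (2, 96, 96, 2), per the shapes in the tests.
```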