import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone


@keras_hub_export("keras_hub.models.SegFormerBackbone")
class SegFormerBackbone(Backbone):
    """A Keras model implementing the SegFormer architecture for semantic
    segmentation.

    This class implements the majority of the SegFormer architecture described
    in [SegFormer: Simple and Efficient Design for Semantic Segmentation with
    Transformers](https://arxiv.org/abs/2105.15203) and is [based on the
    TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/segmentation/segformer).

    SegFormers are meant to be used with the MixTransformer (MiT) encoder
    family, and use a very lightweight all-MLP decoder head.

    The MiT encoder uses a hierarchical transformer which outputs features at
    multiple scales, similar to the hierarchical outputs typically associated
    with CNNs.
    Args:
        image_encoder: `keras.Model`. The backbone network for the model that
            is used as a feature extractor for the SegFormer encoder. Should
            be used with the MiT backbone model
            (`keras_hub.models.MiTBackbone`), which was created specifically
            for SegFormers.
        projection_filters: int. The number of filters in the convolution
            layer projecting the concatenated features into a segmentation
            map. Defaults to `256`.
    Example:

    Using the class with a custom `backbone`:

    ```python
    import keras_hub

    backbone = keras_hub.models.MiTBackbone(
        depths=[2, 2, 2, 2],
        image_shape=(224, 224, 3),
        hidden_dims=[32, 64, 160, 256],
        num_layers=4,
        blockwise_num_heads=[1, 2, 5, 8],
        blockwise_sr_ratios=[8, 4, 2, 1],
        max_drop_path_rate=0.1,
        patch_sizes=[7, 3, 3, 3],
        strides=[4, 2, 2, 2],
    )

    segformer_backbone = keras_hub.models.SegFormerBackbone(
        image_encoder=backbone, projection_filters=256
    )
    ```

    Using the class with a preset `backbone`:

    ```python
    import keras_hub

    backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512")
    segformer_backbone = keras_hub.models.SegFormerBackbone(
        image_encoder=backbone, projection_filters=256
    )
    ```
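
    Once constructed, the backbone maps a batch of images to the fused,
    projected feature map that a segmentation head would consume. A minimal
    sketch of a forward pass, reusing `segformer_backbone` from above (the
    printed shape assumes a 224x224 input, whose "P1" pyramid level is at
    1/4 resolution, and `projection_filters=256`):

    ```python
    import numpy as np

    images = np.random.uniform(size=(1, 224, 224, 3)).astype("float32")
    outputs = segformer_backbone(images)
    # All pyramid levels are projected, resized to the "P1" resolution,
    # concatenated, and fused by a 1x1 convolution.
    print(outputs.shape)  # => (1, 56, 56, 256) under the assumptions above
    ```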
    """

    def __init__(
        self,
        image_encoder,
        projection_filters,
        **kwargs,
    ):
        if not isinstance(image_encoder, keras.Model):
            raise ValueError(
                "Argument `image_encoder` must be a `keras.Model` instance. "
                f"Received instead image_encoder={image_encoder} "
                f"(of type {type(image_encoder)})."
            )

        # === Layers ===
        inputs = keras.layers.Input(shape=image_encoder.input.shape[1:])

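        # Expose the encoder's multi-scale pyramid outputs (a dict of feature
        # maps such as "P1"..."P4", from highest to lowest resolution) instead
        # of its final output.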
        self.feature_extractor = keras.Model(
            image_encoder.inputs, image_encoder.pyramid_outputs
        )

        features = self.feature_extractor(inputs)
        # Get the height and width of the level one ("P1") output.
        _, height, width, _ = features["P1"].shape

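        # One Dense ("MLP") projection per pyramid level, mapping every
        # level's channel dimension onto a shared `projection_filters` width.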
        self.mlp_blocks = []

        for feature_dim, feature in zip(image_encoder.hidden_dims, features):
            self.mlp_blocks.append(
                keras.layers.Dense(
                    projection_filters, name=f"linear_{feature_dim}"
                )
            )

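        # Every projected level is bilinearly resized to the "P1" (highest
        # resolution) spatial dimensions so the maps can be concatenated.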
        self.resizing = keras.layers.Resizing(
            height, width, interpolation="bilinear"
        )
        self.concat = keras.layers.Concatenate(axis=-1)
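        # Fuse the concatenated projections with a 1x1 convolution followed
        # by batch normalization and ReLU.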
        self.linear_fuse = keras.Sequential(
            [
                keras.layers.Conv2D(
                    filters=projection_filters, kernel_size=1, use_bias=False
                ),
                keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9),
                keras.layers.Activation("relu"),
            ]
        )

        # === Functional Model ===
        # Project all multi-level outputs onto the same dimensionality
        # and feature map shape.
        multi_layer_outs = []
        for index, (feature_dim, feature) in enumerate(
            zip(image_encoder.hidden_dims, features)
        ):
            out = self.mlp_blocks[index](features[feature])
            out = self.resizing(out)
            multi_layer_outs.append(out)

        # Concatenate the now equally-shaped feature maps, reversed so the
        # deepest (coarsest) level comes first.
        concatenated_outs = self.concat(multi_layer_outs[::-1])

        # Fuse the concatenated features into a segmentation feature map.
        seg = self.linear_fuse(concatenated_outs)

        super().__init__(
            inputs=inputs,
            outputs=seg,
            **kwargs,
        )

        # === Config ===
        self.projection_filters = projection_filters
        self.image_encoder = image_encoder

    def get_config(self):
        config = super().get_config()
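        # Serialize the nested encoder so the backbone can be saved and
        # restored without a live `keras.Model` object (see `from_config`).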
        config.update(
            {
                "projection_filters": self.projection_filters,
                "image_encoder": keras.saving.serialize_keras_object(
                    self.image_encoder
                ),
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
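        # A serialized `image_encoder` arrives as a plain dict here; rebuild
        # it into a `keras.Model` before constructing the backbone.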
        if "image_encoder" in config and isinstance(
            config["image_encoder"], dict
        ):
            config["image_encoder"] = keras.layers.deserialize(
                config["image_encoder"]
            )
        return super().from_config(config)