From 4b5e4ab9dc24f19ea277bbbeef1843e53e8a970c Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Fri, 16 Aug 2024 23:08:34 +0000 Subject: [PATCH 1/6] Add MixTransformer --- keras_nlp/api/models/__init__.py | 6 + .../src/models/mix_transformer/__init__.py | 13 + .../mix_transformer_backbone.py | 154 +++++++++ .../mix_transformer_backbone_test.py | 66 ++++ .../mix_transformer_classifier.py | 133 ++++++++ .../mix_transformer_classifier_test.py | 64 ++++ .../mix_transformer/mix_transformer_layers.py | 311 ++++++++++++++++++ 7 files changed, 747 insertions(+) create mode 100644 keras_nlp/src/models/mix_transformer/__init__.py create mode 100644 keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py create mode 100644 keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py create mode 100644 keras_nlp/src/models/mix_transformer/mix_transformer_classifier.py create mode 100644 keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py create mode 100644 keras_nlp/src/models/mix_transformer/mix_transformer_layers.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index e079aa7c9e..3bda1dac5a 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -161,6 +161,12 @@ MistralPreprocessor, ) from keras_nlp.src.models.mistral.mistral_tokenizer import MistralTokenizer +from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import ( + MiTBackbone, +) +from keras_nlp.src.models.mix_transformer.mix_transformer_classifier import ( + MixTransformerImageClassifier, +) from keras_nlp.src.models.opt.opt_backbone import OPTBackbone from keras_nlp.src.models.opt.opt_causal_lm import OPTCausalLM from keras_nlp.src.models.opt.opt_causal_lm_preprocessor import ( diff --git a/keras_nlp/src/models/mix_transformer/__init__.py b/keras_nlp/src/models/mix_transformer/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/mix_transformer/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py new file mode 100644 index 0000000000..fc5d437118 --- /dev/null +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py @@ -0,0 +1,154 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import keras
+import numpy as np
+from keras import ops
+
+from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
+from keras_nlp.src.models.mix_transformer.mix_transformer_layers import (
+    HierarchicalTransformerEncoder,
+)
+from keras_nlp.src.models.mix_transformer.mix_transformer_layers import (
+    OverlappingPatchingAndEmbedding,
+)
+
+
+@keras_nlp_export("keras_nlp.models.MiTBackbone")
+class MiTBackbone(FeaturePyramidBackbone):
+    def __init__(
+        self,
+        depths,
+        include_rescaling=True,
+        input_image_shape=(224, 224, 3),
+        embedding_dims=None,
+        **kwargs,
+    ):
+        """A Backbone implementing the MixTransformer.
+
+        This architecture is meant to be used as a backbone for the SegFormer
+        architecture [SegFormer: Simple and Efficient Design for Semantic
+        Segmentation with Transformers](https://arxiv.org/abs/2105.15203)
+        [Based on the TensorFlow implementation from DeepVision](
+            https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/classification/mix_transformer)
+
+        Args:
+            depths: the number of transformer encoders to be used per stage in the
+                network.
+            include_rescaling: bool, whether to rescale the inputs. If set
+                to `True`, inputs will be passed through a `Rescaling(1/255.0)`
+                layer. Defaults to `True`.
+            input_image_shape: optional shape tuple, defaults to (224, 224, 3).
+            embedding_dims: the embedding dims per hierarchical stage, used as
+                the levels of the feature pyramid
+
+        Examples:
+
+        Using the class with a `backbone`:
+
+        ```python
+        images = np.ones(shape=(1, 96, 96, 3))
+        labels = np.zeros(shape=(1, 96, 96, 1))
+        model = keras_nlp.models.MiTBackbone.from_preset("mit_b0_imagenet")
+
+        # Evaluate model
+        model(images)
+
+        # Train model
+        model.compile(
+            optimizer="adam",
+            loss=keras.losses.BinaryCrossentropy(from_logits=False),
+            metrics=["accuracy"],
+        )
+        model.fit(images, labels, epochs=3)
+        ```
+        """
+        drop_path_rate = 0.1
+        dpr = [x for x in np.linspace(0.0, drop_path_rate, sum(depths))]
+        blockwise_num_heads = [1, 2, 5, 8]
+        blockwise_sr_ratios = [8, 4, 2, 1]
+        num_stages = 4
+
+        cur = 0
+        patch_embedding_layers = []
+        transformer_blocks = []
+        layer_norms = []
+
+        for i in range(num_stages):
+            patch_embed_layer = OverlappingPatchingAndEmbedding(
+                project_dim=embedding_dims[i],
+                patch_size=7 if i == 0 else 3,
+                stride=4 if i == 0 else 2,
+                name=f"patch_and_embed_{i}",
+            )
+            patch_embedding_layers.append(patch_embed_layer)
+
+            transformer_block = [
+                HierarchicalTransformerEncoder(
+                    project_dim=embedding_dims[i],
+                    num_heads=blockwise_num_heads[i],
+                    sr_ratio=blockwise_sr_ratios[i],
+                    drop_prob=dpr[cur + k],
+                    name=f"hierarchical_encoder_{i}_{k}",
+                )
+                for k in range(depths[i])
+            ]
+            transformer_blocks.append(transformer_block)
+            cur += depths[i]
+            layer_norms.append(keras.layers.LayerNormalization())
+
+        image_input = keras.layers.Input(shape=input_image_shape)
+        x = image_input
+
+        if include_rescaling:
+            x = keras.layers.Rescaling(scale=1 / 255)(x)
+
+        pyramid_outputs = {}
+        for i in range(num_stages):
+            # Compute new height/width after the `proj`
+            # call in `OverlappingPatchingAndEmbedding`
+            stride = 4 if i == 0 else 2
+            new_height, new_width = (
+                int(ops.shape(x)[1] / stride),
+                int(ops.shape(x)[2] / stride),
+            )
+
+            x = patch_embedding_layers[i](x)
+            for blk in transformer_blocks[i]:
+                x = blk(x)
+            x = layer_norms[i](x)
+            x = keras.layers.Reshape(
+                (new_height, new_width, -1), name=f"output_level_{i}"
+            )(x)
+            pyramid_outputs[f"P{i + 1}"] = x
+
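+        # A worked shape example (assuming the 224x224x3 default input and
+        # embedding_dims=[32, 64, 160, 256], as in the tests): the stage
+        # strides (4, 2, 2, 2) shrink the spatial side 224 -> 56 -> 28 ->
+        # 14 -> 7, so the pyramid holds P1: (B, 56, 56, 32),
+        # P2: (B, 28, 28, 64), P3: (B, 14, 14, 160) and P4: (B, 7, 7, 256),
+        # with `x` left as the final P4 tensor.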
super().__init__(inputs=image_input, outputs=x, **kwargs) + + self.depths = depths + self.include_rescaling = include_rescaling + self.input_image_shape = input_image_shape + self.embedding_dims = embedding_dims + self.pyramid_outputs = pyramid_outputs + + def get_config(self): + config = super().get_config() + config.update( + { + "depths": self.depths, + "include_rescaling": self.include_rescaling, + "embedding_dims": self.embedding_dims, + "input_image_shape": self.input_image_shape, + } + ) + return config diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py new file mode 100644 index 0000000000..22b7cec948 --- /dev/null +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py @@ -0,0 +1,66 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +from keras import models + +from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import ( + MiTBackbone, +) +from keras_nlp.src.tests.test_case import TestCase + + +class MiTBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "depths": [2, 2, 2, 2], + "include_rescaling": True, + "input_image_shape": (224, 224, 3), + "embedding_dims": [32, 64, 160, 256], + } + self.input_data = np.ones((2, 224, 224, 3), dtype="float32") + + def test_backbone_basics(self): + self.run_backbone_test( + cls=MiTBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 7, 7, 256), + run_quantization_check=False, + run_mixed_precision_check=False, + ) + + def test_pyramid_output_format(self): + init_kwargs = self.init_kwargs + backbone = MiTBackbone(**init_kwargs) + model = models.Model(backbone.inputs, backbone.pyramid_outputs) + output_data = model(self.input_data) + + self.assertIsInstance(output_data, dict) + self.assertEqual( + list(output_data.keys()), list(backbone.pyramid_outputs.keys()) + ) + self.assertEqual(list(output_data.keys()), ["P1", "P2", "P3", "P4"]) + for k, v in output_data.items(): + size = self.input_size // (2 ** int(k[1:])) + self.assertEqual(tuple(v.shape[:3]), (2, size, size)) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=MiTBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_classifier.py b/keras_nlp/src/models/mix_transformer/mix_transformer_classifier.py new file mode 100644 index 0000000000..2dbf19a596 --- /dev/null +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_classifier.py @@ -0,0 +1,133 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import keras
+
+from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.models.image_classifier import ImageClassifier
+from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import (
+    MiTBackbone,
+)
+
+
+@keras_nlp_export("keras_nlp.models.MixTransformerImageClassifier")
+class MixTransformerImageClassifier(ImageClassifier):
+    """MixTransformerImageClassifier image classifier model.
+
+    Args:
+        backbone: A `keras_nlp.models.MiTBackbone` instance.
+        num_classes: int. The number of classes to predict.
+        activation: `None`, str or callable. The activation function to use on
+            the `Dense` layer. Set `activation=None` to return the output
+            logits. Defaults to `"softmax"`.
+
+    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
+    where `x` is a tensor and `y` is an integer from `[0, num_classes)`.
+    All `ImageClassifier` tasks include a `from_preset()` constructor which can
+    be used to load a pre-trained config and weights.
+
+    Examples:
+
+    Call `predict()` to run inference.
+    ```python
+    # Load preset and run inference
+    images = np.ones((2, 224, 224, 3), dtype="float32")
+    classifier = keras_nlp.models.MixTransformerImageClassifier.from_preset(
+        "mit_b0_imagenet")
+    classifier.predict(images)
+    ```
+
+    Call `fit()` on a single batch.
+    ```python
+    # Load preset and train
+    images = np.ones((2, 224, 224, 3), dtype="float32")
+    labels = [0, 3]
+    classifier = keras_nlp.models.MixTransformerImageClassifier.from_preset(
+        "mit_b0_imagenet")
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+
+    Call `fit()` with custom loss, optimizer and backbone.
+    ```python
+    classifier = keras_nlp.models.MixTransformerImageClassifier.from_preset(
+        "mit_b0_imagenet")
+    classifier.compile(
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        optimizer=keras.optimizers.Adam(5e-5),
+    )
+    classifier.backbone.trainable = False
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+
+    Custom backbone.
+    ```python
+    images = np.ones((2, 224, 224, 3), dtype="float32")
+    labels = [0, 3]
+    backbone = keras_nlp.models.MiTBackbone(
+        depths=[2, 2, 2, 2],
+        include_rescaling=False,
+        input_image_shape=(224, 224, 3),
+        embedding_dims=[32, 64, 160, 256],
+    )
+    classifier = keras_nlp.models.MixTransformerImageClassifier(
+        backbone=backbone,
+        num_classes=4,
+    )
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+    """
+
+    backbone_cls = MiTBackbone
+
+    def __init__(
+        self,
+        backbone,
+        num_classes,
+        activation="softmax",
+        preprocessor=None,  # adding this dummy arg for saved model test
+        # TODO: once preprocessor flow is figured out, this needs to be updated
+        **kwargs,
+    ):
+        # === Layers ===
+        self.backbone = backbone
+        self.output_dense = keras.layers.Dense(
+            num_classes,
+            activation=activation,
+            name="predictions",
+        )
+
+        # === Functional Model ===
+        inputs = self.backbone.input
+        x = self.backbone(inputs)
+        outputs = self.output_dense(x)
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.num_classes = num_classes
+        self.activation = activation
+
+    def get_config(self):
+        # Backbone serialized in `super`
+        config = super().get_config()
+        config.update(
+            {
+                "num_classes": self.num_classes,
+                "activation": self.activation,
+            }
+        )
+        return config
diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py b/keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py
new file mode 100644
index 0000000000..c968c0e532
--- /dev/null
+++ b/keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py
@@ -0,0 +1,64 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import pytest
+
+from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import (
+    MiTBackbone,
+)
+from keras_nlp.src.models.mix_transformer.mix_transformer_classifier import (
+    MixTransformerImageClassifier,
+)
+from keras_nlp.src.tests.test_case import TestCase
+
+
+class MixTransformerImageClassifierTest(TestCase):
+    def setUp(self):
+        # Setup model.
+        self.images = np.ones((2, 224, 224, 3), dtype="float32")
+        self.labels = [0, 3]
+        self.backbone = MiTBackbone(
+            depths=[2, 2, 2, 2],
+            include_rescaling=True,
+            input_image_shape=(224, 224, 3),
+            embedding_dims=[32, 64, 160, 256],
+        )
+        self.init_kwargs = {
+            "backbone": self.backbone,
+            "num_classes": 2,
+            "activation": "softmax",
+        }
+        self.train_data = (
+            self.images,
+            self.labels,
+        )
+
+    def test_classifier_basics(self):
+        pytest.skip(
+            reason="TODO: enable after preprocessor flow is figured out"
+        )
+        self.run_task_test(
+            cls=MixTransformerImageClassifier,
+            init_kwargs=self.init_kwargs,
+            train_data=self.train_data,
+            expected_output_shape=(2, 2),
+        )
+
+    @pytest.mark.large
+    def test_saved_model(self):
+        self.run_model_saving_test(
+            cls=MixTransformerImageClassifier,
+            init_kwargs=self.init_kwargs,
+            input_data=self.images,
+        )
diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_layers.py b/keras_nlp/src/models/mix_transformer/mix_transformer_layers.py
new file mode 100644
index 0000000000..55e3bfc164
--- /dev/null
+++ b/keras_nlp/src/models/mix_transformer/mix_transformer_layers.py
@@ -0,0 +1,311 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import keras
+from keras import ops
+from keras import random
+
+
+class OverlappingPatchingAndEmbedding(keras.layers.Layer):
+    def __init__(self, project_dim=32, patch_size=7, stride=4, **kwargs):
+        """Overlapping Patching and Embedding layer.
+
+        Differs from `PatchingAndEmbedding` in that the patch size does not
+        affect the sequence length. It's fully derived from the `stride`
+        parameter. Additionally, no positional embedding is done
+        as part of the layer - only a projection using a `Conv2D` layer.
+        [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) (CVPR 2021) # noqa: E501
+        , [Official PyTorch implementation](https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py) # noqa: E501
+        and [Ported from the TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/blob/main/deepvision/layers/hierarchical_transformer_encoder.py) # noqa: E501
+
+        Args:
+            project_dim: integer, the dimensionality of the projection.
+                Defaults to `32`.
+            patch_size: integer, the size of the patches to encode.
+                Defaults to `7`.
+            stride: integer, the stride to use for the patching before
+                projection. Defaults to `4`.
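+
+        Example:
+
+        A minimal usage sketch (assuming the defaults above); the output
+        sequence length comes from the `stride`, not the `patch_size`:
+
+        ```python
+        import numpy as np
+
+        layer = OverlappingPatchingAndEmbedding(project_dim=32)
+        images = np.ones((1, 224, 224, 3), dtype="float32")
+        # The strided `Conv2D` gives a 56x56 grid, flattened to
+        # (1, 56 * 56, 32) == (1, 3136, 32).
+        embedded = layer(images)
+        ```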
+ """ + super().__init__(**kwargs) + + self.project_dim = project_dim + self.patch_size = patch_size + self.stride = stride + + self.proj = keras.layers.Conv2D( + filters=project_dim, + kernel_size=patch_size, + strides=stride, + padding="same", + ) + self.norm = keras.layers.LayerNormalization() + + def call(self, x): + x = self.proj(x) + # B, H, W, C + shape = x.shape + x = ops.reshape(x, (-1, shape[1] * shape[2], shape[3])) + x = self.norm(x) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "project_dim": self.project_dim, + "patch_size": self.patch_size, + "stride": self.stride, + } + ) + return config + + +class HierarchicalTransformerEncoder(keras.layers.Layer): + """Hierarchical transformer encoder block implementation as a Keras Layer. + + The layer uses `SegFormerMultiheadAttention` as a `MultiHeadAttention` + alternative for computational efficiency, and is meant to be used + within the SegFormer architecture. + [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) (CVPR 2021) # noqa: E501 + , [Official PyTorch implementation](https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py) # noqa: E501 + and [Ported from the TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/blob/main/deepvision/layers/hierarchical_transformer_encoder.py) # noqa: E501 + + Args: + project_dim: integer, the dimensionality of the projection of the + encoder, and output of the `SegFormerMultiheadAttention` layer. + Due to the residual addition the input dimensionality has to be + equal to the output dimensionality. + num_heads: integer, the number of heads for the + `SegFormerMultiheadAttention` layer. + drop_prob: float, the probability of dropping a random + sample using the `DropPath` layer. Defaults to `0.0`. + layer_norm_epsilon: float, the epsilon for + `LayerNormalization` layers. Defaults to `1e-06` + sr_ratio: integer, the ratio to use within + `SegFormerMultiheadAttention`. If set to > 1, a `Conv2D` + layer is used to reduce the length of the sequence. Defaults to `1`. 
+ """ + + def __init__( + self, + project_dim, + num_heads, + sr_ratio=1, + drop_prob=0.0, + layer_norm_epsilon=1e-6, + **kwargs, + ): + super().__init__(**kwargs) + self.project_dim = project_dim + self.num_heads = num_heads + self.drop_prop = drop_prob + + self.norm1 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon) + self.attn = SegFormerMultiheadAttention( + project_dim, num_heads, sr_ratio + ) + self.drop_path = DropPath(drop_prob) + self.norm2 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon) + self.mlp = self.MixFFN( + channels=project_dim, + mid_channels=int(project_dim * 4), + ) + + def build(self, input_shape): + super().build(input_shape) + self.H = ops.sqrt(ops.cast(input_shape[1], "float32")) + self.W = ops.sqrt(ops.cast(input_shape[2], "float32")) + + def call(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "mlp": keras.saving.serialize_keras_object(self.mlp), + "project_dim": self.project_dim, + "num_heads": self.num_heads, + "drop_prop": self.drop_prop, + } + ) + return config + + class MixFFN(keras.layers.Layer): + def __init__(self, channels, mid_channels): + super().__init__() + self.fc1 = keras.layers.Dense(mid_channels) + self.dwconv = keras.layers.DepthwiseConv2D( + kernel_size=3, + strides=1, + padding="same", + ) + self.fc2 = keras.layers.Dense(channels) + + def call(self, x): + x = self.fc1(x) + shape = ops.shape(x) + H, W = int(math.sqrt(shape[1])), int(math.sqrt(shape[1])) + B, C = shape[0], shape[2] + x = ops.reshape(x, (B, H, W, C)) + x = self.dwconv(x) + x = ops.reshape(x, (B, -1, C)) + x = ops.nn.gelu(x) + x = self.fc2(x) + return x + + +class SegFormerMultiheadAttention(keras.layers.Layer): + def __init__(self, project_dim, num_heads, sr_ratio): + """Efficient MultiHeadAttention implementation as a Keras layer. + + A huge bottleneck in scaling transformers is the self-attention layer + with an O(n^2) complexity. + + SegFormerMultiheadAttention performs a sequence reduction (SR) operation + with a given ratio, to reduce the sequence length before performing key + and value projections, reducing the O(n^2) complexity to O(n^2/R) where + R is the sequence reduction ratio. + References [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) (CVPR 2021) # noqa: E501 + , [NVlabs' official implementation](https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py) # noqa: E501 + , [@sithu31296's reimplementation](https://github.com/sithu31296/semantic-segmentation/blob/main/semseg/models/backbones/mit.py) # noqa: E501 + and [Ported from the TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/blob/main/deepvision/layers/efficient_attention.py) # noqa: E501 + + Args: + project_dim: integer, the dimensionality of the projection + of the `SegFormerMultiheadAttention` layer. + num_heads: integer, the number of heads to use in the + attention computation. + sr_ratio: integer, the sequence reduction ratio to perform + on the sequence before key and value projections. 
+ """ + super().__init__() + self.num_heads = num_heads + self.sr_ratio = sr_ratio + self.scale = (project_dim // num_heads) ** -0.5 + self.q = keras.layers.Dense(project_dim) + self.k = keras.layers.Dense(project_dim) + self.v = keras.layers.Dense(project_dim) + self.proj = keras.layers.Dense(project_dim) + + if sr_ratio > 1: + self.sr = keras.layers.Conv2D( + filters=project_dim, + kernel_size=sr_ratio, + strides=sr_ratio, + padding="same", + ) + self.norm = keras.layers.LayerNormalization() + + def call(self, x): + input_shape = ops.shape(x) + H, W = int(math.sqrt(input_shape[1])), int(math.sqrt(input_shape[1])) + B, C = input_shape[0], input_shape[2] + + q = self.q(x) + q = ops.reshape( + q, + ( + input_shape[0], + input_shape[1], + self.num_heads, + input_shape[2] // self.num_heads, + ), + ) + q = ops.transpose(q, [0, 2, 1, 3]) + + if self.sr_ratio > 1: + x = ops.reshape( + ops.transpose(x, [0, 2, 1]), + (B, H, W, C), + ) + x = self.sr(x) + x = ops.reshape(x, [input_shape[0], input_shape[2], -1]) + x = ops.transpose(x, [0, 2, 1]) + x = self.norm(x) + + k = self.k(x) + v = self.v(x) + + k = ops.transpose( + ops.reshape( + k, + [B, -1, self.num_heads, C // self.num_heads], + ), + [0, 2, 1, 3], + ) + + v = ops.transpose( + ops.reshape( + v, + [B, -1, self.num_heads, C // self.num_heads], + ), + [0, 2, 1, 3], + ) + + attn = (q @ ops.transpose(k, [0, 1, 3, 2])) * self.scale + attn = ops.nn.softmax(attn, axis=-1) + + attn = attn @ v + attn = ops.reshape( + ops.transpose(attn, [0, 2, 1, 3]), + [input_shape[0], input_shape[1], input_shape[2]], + ) + + x = self.proj(attn) + return x + + +class DropPath(keras.layers.Layer): + """Implements the DropPath layer. + + DropPath randomly drops samples during + training with a probability of `rate`. Note that this layer drops individual + samples within a batch and not the entire batch. DropPath randomly drops + some individual samples from a batch, whereas StochasticDepth + randomly drops the entire batch. + [FractalNet](https://arxiv.org/abs/1605.07648v4). + + Args: + rate: float, the probability of the residual branch being dropped. + seed: (Optional) integer. Used to create a random seed. 
+ """ + + def __init__(self, rate=0.5, seed=None, **kwargs): + super().__init__(**kwargs) + self.rate = rate + self._seed_val = seed + self.seed = random.SeedGenerator(seed=seed) + + def call(self, x, training=None): + if self.rate == 0.0 or not training: + return x + else: + batch_size = x.shape[0] or ops.shape(x)[0] + drop_map_shape = (batch_size,) + (1,) * (len(x.shape) - 1) + drop_map = ops.cast( + random.uniform(drop_map_shape, seed=self.seed) > self.rate, + x.dtype, + ) + x = x / (1.0 - self.rate) + x = x * drop_map + return x + + def get_config(self): + config = super().get_config() + config.update({"rate": self.rate, "seed": self._seed_val}) + return config From d7b993a43db5254c3a95970c9e8dc761dda05fa1 Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Fri, 16 Aug 2024 23:24:12 +0000 Subject: [PATCH 2/6] fix testcase --- .../src/models/mix_transformer/mix_transformer_backbone_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py index 22b7cec948..c0a823f6c6 100644 --- a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py @@ -30,6 +30,7 @@ def setUp(self): "input_image_shape": (224, 224, 3), "embedding_dims": [32, 64, 160, 256], } + self.input_size = 112 self.input_data = np.ones((2, 224, 224, 3), dtype="float32") def test_backbone_basics(self): From df9f65ef83f10087eaebad5d24cdfb560e7e5229 Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Mon, 19 Aug 2024 18:39:18 +0000 Subject: [PATCH 3/6] test changes and comments --- .../models/mix_transformer/mix_transformer_backbone.py | 7 +++++-- .../mix_transformer/mix_transformer_backbone_test.py | 8 ++++---- .../mix_transformer/mix_transformer_classifier_test.py | 4 ++-- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py index fc5d437118..981f334414 100644 --- a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py @@ -80,6 +80,7 @@ def __init__( blockwise_sr_ratios = [8, 4, 2, 1] num_stages = 4 + # === Layers === cur = 0 patch_embedding_layers = [] transformer_blocks = [] @@ -107,7 +108,8 @@ def __init__( transformer_blocks.append(transformer_block) cur += depths[i] layer_norms.append(keras.layers.LayerNormalization()) - + + # === Functional Model === image_input = keras.layers.Input(shape=input_image_shape) x = image_input @@ -134,7 +136,8 @@ def __init__( pyramid_outputs[f"P{i + 1}"] = x super().__init__(inputs=image_input, outputs=x, **kwargs) - + + # === Config === self.depths = depths self.include_rescaling = include_rescaling self.input_image_shape = input_image_shape diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py index c0a823f6c6..a01125e0a6 100644 --- a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py @@ -27,18 +27,18 @@ def setUp(self): self.init_kwargs = { "depths": [2, 2, 2, 2], "include_rescaling": True, - "input_image_shape": (224, 224, 3), + "input_image_shape": (64, 64, 3), "embedding_dims": [32, 64, 160, 256], } - self.input_size = 112 - self.input_data = 
np.ones((2, 224, 224, 3), dtype="float32") + self.input_size = 32 + self.input_data = np.ones((2, 64, 64, 3), dtype="float32") def test_backbone_basics(self): self.run_backbone_test( cls=MiTBackbone, init_kwargs=self.init_kwargs, input_data=self.input_data, - expected_output_shape=(2, 7, 7, 256), + expected_output_shape=(2, 2, 2, 256), run_quantization_check=False, run_mixed_precision_check=False, ) diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py b/keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py index c968c0e532..0b2a909403 100644 --- a/keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py @@ -26,12 +26,12 @@ class MixTransformerImageClassifierTest(TestCase): def setUp(self): # Setup model. - self.images = np.ones((2, 224, 224, 3), dtype="float32") + self.images = np.ones((2, 64, 64, 3), dtype="float32") self.labels = [0, 3] self.backbone = MiTBackbone( depths=[2, 2, 2, 2], include_rescaling=True, - input_image_shape=(224, 224, 3), + input_image_shape=(64, 64, 3), embedding_dims=[32, 64, 160, 256], ) self.init_kwargs = { From c228eaae58a9259d06bc470f6c10f0b9e0daf6fd Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Mon, 19 Aug 2024 18:58:56 +0000 Subject: [PATCH 4/6] lint fix --- .../src/models/mix_transformer/mix_transformer_backbone.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py index 981f334414..a56443ff38 100644 --- a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py @@ -108,7 +108,7 @@ def __init__( transformer_blocks.append(transformer_block) cur += depths[i] layer_norms.append(keras.layers.LayerNormalization()) - + # === Functional Model === image_input = keras.layers.Input(shape=input_image_shape) x = image_input @@ -136,7 +136,7 @@ def __init__( pyramid_outputs[f"P{i + 1}"] = x super().__init__(inputs=image_input, outputs=x, **kwargs) - + # === Config === self.depths = depths self.include_rescaling = include_rescaling From 3888b54f98f918a8f0f02c7c4e2940c9d6672abb Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Tue, 20 Aug 2024 00:22:11 +0000 Subject: [PATCH 5/6] update config list --- keras_nlp/api/models/__init__.py | 2 +- .../mix_transformer_backbone.py | 70 +++++++++++++------ .../mix_transformer_backbone_test.py | 10 ++- .../mix_transformer_classifier.py | 14 ++-- .../mix_transformer_classifier_test.py | 18 +++-- .../mix_transformer/mix_transformer_layers.py | 61 +++++++--------- 6 files changed, 100 insertions(+), 75 deletions(-) diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 3bda1dac5a..f352a01b92 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -165,7 +165,7 @@ MiTBackbone, ) from keras_nlp.src.models.mix_transformer.mix_transformer_classifier import ( - MixTransformerImageClassifier, + MiTImageClassifier, ) from keras_nlp.src.models.opt.opt_backbone import OPTBackbone from keras_nlp.src.models.opt.opt_causal_lm import OPTCausalLM diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py index a56443ff38..2cfe7f6761 100644 --- a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py +++ 
b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py
@@ -30,9 +30,15 @@ class MiTBackbone(FeaturePyramidBackbone):
     def __init__(
         self,
         depths,
+        num_layers,
+        blockwise_num_heads,
+        blockwise_sr_ratios,
+        end_value,
+        patch_sizes,
+        strides,
         include_rescaling=True,
-        input_image_shape=(224, 224, 3),
-        embedding_dims=None,
+        image_shape=(224, 224, 3),
+        hidden_dims=None,
         **kwargs,
     ):
         """A Backbone implementing the MixTransformer.
@@ -44,14 +50,26 @@ def __init__(
         https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/classification/mix_transformer)
 
         Args:
-            depths: the number of transformer encoders to be used per stage in the
+            depths: The number of transformer encoders to be used per layer in the
                 network.
+            num_layers: int. The number of Transformer layers.
+            blockwise_num_heads: list of integers, the number of heads to use
+                in the attention computation for each layer.
+            blockwise_sr_ratios: list of integers, the sequence reduction
+                ratio to perform for each layer on the sequence before key and
+                value projections. If set to > 1, a `Conv2D` layer is used to
+                reduce the length of the sequence.
+            end_value: float. The final value of the linearly increasing
+                stochastic depth (drop path) rate; block rates ramp from
+                `0.0` to `end_value` across the encoder blocks.
             include_rescaling: bool, whether to rescale the inputs. If set
                 to `True`, inputs will be passed through a `Rescaling(1/255.0)`
                 layer. Defaults to `True`.
-            input_image_shape: optional shape tuple, defaults to (224, 224, 3).
-            embedding_dims: the embedding dims per hierarchical stage, used as
-                the levels of the feature pyramid
+            image_shape: optional shape tuple, defaults to (224, 224, 3).
+            hidden_dims: the embedding dims per hierarchical layer, used as
+                the levels of the feature pyramid.
+            patch_sizes: list of integers, the patch_size to apply for each layer.
+            strides: list of integers, stride to apply for each layer.
Examples: @@ -74,11 +90,7 @@ def __init__( model.fit(images, labels, epochs=3) ``` """ - drop_path_rate = 0.1 - dpr = [x for x in np.linspace(0.0, drop_path_rate, sum(depths))] - blockwise_num_heads = [1, 2, 5, 8] - blockwise_sr_ratios = [8, 4, 2, 1] - num_stages = 4 + dpr = [x for x in np.linspace(0.0, end_value, sum(depths))] # === Layers === cur = 0 @@ -86,18 +98,18 @@ def __init__( transformer_blocks = [] layer_norms = [] - for i in range(num_stages): + for i in range(num_layers): patch_embed_layer = OverlappingPatchingAndEmbedding( - project_dim=embedding_dims[i], - patch_size=7 if i == 0 else 3, - stride=4 if i == 0 else 2, + project_dim=hidden_dims[i], + patch_size=patch_sizes[i], + stride=strides[i], name=f"patch_and_embed_{i}", ) patch_embedding_layers.append(patch_embed_layer) transformer_block = [ HierarchicalTransformerEncoder( - project_dim=embedding_dims[i], + project_dim=hidden_dims[i], num_heads=blockwise_num_heads[i], sr_ratio=blockwise_sr_ratios[i], drop_prob=dpr[cur + k], @@ -110,17 +122,17 @@ def __init__( layer_norms.append(keras.layers.LayerNormalization()) # === Functional Model === - image_input = keras.layers.Input(shape=input_image_shape) + image_input = keras.layers.Input(shape=image_shape) x = image_input if include_rescaling: x = keras.layers.Rescaling(scale=1 / 255)(x) pyramid_outputs = {} - for i in range(num_stages): + for i in range(num_layers): # Compute new height/width after the `proj` # call in `OverlappingPatchingAndEmbedding` - stride = 4 if i == 0 else 2 + stride = strides[i] new_height, new_width = ( int(ops.shape(x)[1] / stride), int(ops.shape(x)[2] / stride), @@ -140,9 +152,15 @@ def __init__( # === Config === self.depths = depths self.include_rescaling = include_rescaling - self.input_image_shape = input_image_shape - self.embedding_dims = embedding_dims + self.image_shape = image_shape + self.hidden_dims = hidden_dims self.pyramid_outputs = pyramid_outputs + self.num_layers = num_layers + self.blockwise_num_heads = blockwise_num_heads + self.blockwise_sr_ratios = blockwise_sr_ratios + self.end_value = end_value + self.patch_sizes = patch_sizes + self.strides = strides def get_config(self): config = super().get_config() @@ -150,8 +168,14 @@ def get_config(self): { "depths": self.depths, "include_rescaling": self.include_rescaling, - "embedding_dims": self.embedding_dims, - "input_image_shape": self.input_image_shape, + "hidden_dims": self.hidden_dims, + "image_shape": self.image_shape, + "num_layers": self.num_layers, + "blockwise_num_heads": self.blockwise_num_heads, + "blockwise_sr_ratios": self.blockwise_sr_ratios, + "end_value": self.end_value, + "patch_sizes": self.patch_sizes, + "strides": self.strides, } ) return config diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py index a01125e0a6..27ca8ade4a 100644 --- a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py @@ -27,8 +27,14 @@ def setUp(self): self.init_kwargs = { "depths": [2, 2, 2, 2], "include_rescaling": True, - "input_image_shape": (64, 64, 3), - "embedding_dims": [32, 64, 160, 256], + "image_shape": (64, 64, 3), + "hidden_dims": [32, 64, 160, 256], + "num_layers": 4, + "blockwise_num_heads": [1, 2, 5, 8], + "blockwise_sr_ratios": [8, 4, 2, 1], + "end_value": 0.1, + "patch_sizes": [7, 3, 3, 3], + "strides": [4, 2, 2, 2], } self.input_size = 32 self.input_data = np.ones((2, 64, 64, 3), 
dtype="float32") diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_classifier.py b/keras_nlp/src/models/mix_transformer/mix_transformer_classifier.py index 2dbf19a596..a9a51b63ba 100644 --- a/keras_nlp/src/models/mix_transformer/mix_transformer_classifier.py +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_classifier.py @@ -20,9 +20,9 @@ ) -@keras_nlp_export("keras_nlp.models.MixTransformerImageClassifier") -class MixTransformerImageClassifier(ImageClassifier): - """MixTransformerImageClassifier image classifier model. +@keras_nlp_export("keras_nlp.models.MiTImageClassifier") +class MiTImageClassifier(ImageClassifier): + """MiTImageClassifier image classifier model. Args: backbone: A `keras_nlp.models.MiTBackbone` instance. @@ -42,7 +42,7 @@ class MixTransformerImageClassifier(ImageClassifier): ```python # Load preset and train images = np.ones((2, 224, 224, 3), dtype="float32") - classifier = keras_nlp.models.MixTransformerImageClassifier.from_preset( + classifier = keras_nlp.models.MiTImageClassifier.from_preset( "mit_b0_imagenet") classifier.predict(images) ``` @@ -59,7 +59,7 @@ class MixTransformerImageClassifier(ImageClassifier): Call `fit()` with custom loss, optimizer and backbone. ```python - classifier = keras_nlp.models.MixTransformerImageClassifier.from_preset( + classifier = keras_nlp.models.MiTImageClassifier.from_preset( "mit_b0_imagenet") classifier.compile( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), @@ -78,9 +78,9 @@ class MixTransformerImageClassifier(ImageClassifier): stackwise_depth=[3, 9, 9, 3], include_rescaling=False, block_type="basic_block", - input_image_shape = (224, 224, 3), + image_shape = (224, 224, 3), ) - classifier = keras_nlp.models.MixTransformerImageClassifier( + classifier = keras_nlp.models.MiTImageClassifier( backbone=backbone, num_classes=4, ) diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py b/keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py index 0b2a909403..7d264dda11 100644 --- a/keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py @@ -18,12 +18,12 @@ MiTBackbone, ) from keras_nlp.src.models.mix_transformer.mix_transformer_classifier import ( - MixTransformerImageClassifier, + MiTImageClassifier, ) from keras_nlp.src.tests.test_case import TestCase -class MixTransformerImageClassifierTest(TestCase): +class MiTImageClassifierTest(TestCase): def setUp(self): # Setup model. 
self.images = np.ones((2, 64, 64, 3), dtype="float32") @@ -31,8 +31,14 @@ def setUp(self): self.backbone = MiTBackbone( depths=[2, 2, 2, 2], include_rescaling=True, - input_image_shape=(64, 64, 3), - embedding_dims=[32, 64, 160, 256], + image_shape=(64, 64, 3), + hidden_dims=[32, 64, 160, 256], + num_layers=4, + blockwise_num_heads=[1, 2, 5, 8], + blockwise_sr_ratios=[8, 4, 2, 1], + end_value=0.1, + patch_sizes=[7, 3, 3, 3], + strides=[4, 2, 2, 2], ) self.init_kwargs = { "backbone": self.backbone, @@ -49,7 +55,7 @@ def test_classifier_basics(self): reason="TODO: enable after preprocessor flow is figured out" ) self.run_task_test( - cls=MixTransformerImageClassifier, + cls=MiTImageClassifier, init_kwargs=self.init_kwargs, train_data=self.train_data, expected_output_shape=(2, 2), @@ -58,7 +64,7 @@ def test_classifier_basics(self): @pytest.mark.large def test_saved_model(self): self.run_model_saving_test( - cls=MixTransformerImageClassifier, + cls=MiTImageClassifier, init_kwargs=self.init_kwargs, input_data=self.images, ) diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_layers.py b/keras_nlp/src/models/mix_transformer/mix_transformer_layers.py index 55e3bfc164..53d99fe484 100644 --- a/keras_nlp/src/models/mix_transformer/mix_transformer_layers.py +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_layers.py @@ -26,9 +26,6 @@ def __init__(self, project_dim=32, patch_size=7, stride=4, **kwargs): affect the sequence length. It's fully derived from the `stride` parameter. Additionally, no positional embedding is done as part of the layer - only a projection using a `Conv2D` layer. - [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) (CVPR 2021) # noqa: E501 - , [Official PyTorch implementation](https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py) # noqa: E501 - and [Ported from the TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/blob/main/deepvision/layers/hierarchical_transformer_encoder.py) # noqa: E501 Args: project_dim: integer, the dimensionality of the projection. @@ -78,9 +75,6 @@ class HierarchicalTransformerEncoder(keras.layers.Layer): The layer uses `SegFormerMultiheadAttention` as a `MultiHeadAttention` alternative for computational efficiency, and is meant to be used within the SegFormer architecture. 
- [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) (CVPR 2021) # noqa: E501 - , [Official PyTorch implementation](https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py) # noqa: E501 - and [Ported from the TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/blob/main/deepvision/layers/hierarchical_transformer_encoder.py) # noqa: E501 Args: project_dim: integer, the dimensionality of the projection of the @@ -118,7 +112,7 @@ def __init__( ) self.drop_path = DropPath(drop_prob) self.norm2 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon) - self.mlp = self.MixFFN( + self.mlp = MixFFN( channels=project_dim, mid_channels=int(project_dim * 4), ) @@ -145,28 +139,29 @@ def get_config(self): ) return config - class MixFFN(keras.layers.Layer): - def __init__(self, channels, mid_channels): - super().__init__() - self.fc1 = keras.layers.Dense(mid_channels) - self.dwconv = keras.layers.DepthwiseConv2D( - kernel_size=3, - strides=1, - padding="same", - ) - self.fc2 = keras.layers.Dense(channels) - - def call(self, x): - x = self.fc1(x) - shape = ops.shape(x) - H, W = int(math.sqrt(shape[1])), int(math.sqrt(shape[1])) - B, C = shape[0], shape[2] - x = ops.reshape(x, (B, H, W, C)) - x = self.dwconv(x) - x = ops.reshape(x, (B, -1, C)) - x = ops.nn.gelu(x) - x = self.fc2(x) - return x + +class MixFFN(keras.layers.Layer): + def __init__(self, channels, mid_channels): + super().__init__() + self.fc1 = keras.layers.Dense(mid_channels) + self.dwconv = keras.layers.DepthwiseConv2D( + kernel_size=3, + strides=1, + padding="same", + ) + self.fc2 = keras.layers.Dense(channels) + + def call(self, x): + x = self.fc1(x) + shape = ops.shape(x) + H, W = int(math.sqrt(shape[1])), int(math.sqrt(shape[1])) + B, C = shape[0], shape[2] + x = ops.reshape(x, (B, H, W, C)) + x = self.dwconv(x) + x = ops.reshape(x, (B, -1, C)) + x = ops.nn.gelu(x) + x = self.fc2(x) + return x class SegFormerMultiheadAttention(keras.layers.Layer): @@ -180,10 +175,6 @@ def __init__(self, project_dim, num_heads, sr_ratio): with a given ratio, to reduce the sequence length before performing key and value projections, reducing the O(n^2) complexity to O(n^2/R) where R is the sequence reduction ratio. - References [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) (CVPR 2021) # noqa: E501 - , [NVlabs' official implementation](https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py) # noqa: E501 - , [@sithu31296's reimplementation](https://github.com/sithu31296/semantic-segmentation/blob/main/semseg/models/backbones/mit.py) # noqa: E501 - and [Ported from the TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/blob/main/deepvision/layers/efficient_attention.py) # noqa: E501 Args: project_dim: integer, the dimensionality of the projection @@ -275,10 +266,8 @@ class DropPath(keras.layers.Layer): DropPath randomly drops samples during training with a probability of `rate`. Note that this layer drops individual - samples within a batch and not the entire batch. DropPath randomly drops - some individual samples from a batch, whereas StochasticDepth + samples within a batch and not the entire batch, whereas StochasticDepth randomly drops the entire batch. - [FractalNet](https://arxiv.org/abs/1605.07648v4). 
Args: rate: float, the probability of the residual branch being dropped. From 85bda0821e530eb62788e993284fd0b1c6e81758 Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Tue, 20 Aug 2024 05:14:53 +0000 Subject: [PATCH 6/6] modify testcase for 2 layers --- .../mix_transformer_backbone_test.py | 28 ++++++++++--------- .../mix_transformer_classifier_test.py | 16 +++++------ 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py index 27ca8ade4a..4f1955297f 100644 --- a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py @@ -25,26 +25,28 @@ class MiTBackboneTest(TestCase): def setUp(self): self.init_kwargs = { - "depths": [2, 2, 2, 2], + "depths": [2, 2], "include_rescaling": True, - "image_shape": (64, 64, 3), - "hidden_dims": [32, 64, 160, 256], - "num_layers": 4, - "blockwise_num_heads": [1, 2, 5, 8], - "blockwise_sr_ratios": [8, 4, 2, 1], + "image_shape": (16, 16, 3), + "hidden_dims": [4, 8], + "num_layers": 2, + "blockwise_num_heads": [1, 2], + "blockwise_sr_ratios": [8, 4], "end_value": 0.1, - "patch_sizes": [7, 3, 3, 3], - "strides": [4, 2, 2, 2], + "patch_sizes": [7, 3], + "strides": [4, 2], } - self.input_size = 32 - self.input_data = np.ones((2, 64, 64, 3), dtype="float32") + self.input_size = 16 + self.input_data = np.ones( + (2, self.input_size, self.input_size, 3), dtype="float32" + ) def test_backbone_basics(self): self.run_backbone_test( cls=MiTBackbone, init_kwargs=self.init_kwargs, input_data=self.input_data, - expected_output_shape=(2, 2, 2, 256), + expected_output_shape=(2, 2, 2, 8), run_quantization_check=False, run_mixed_precision_check=False, ) @@ -59,9 +61,9 @@ def test_pyramid_output_format(self): self.assertEqual( list(output_data.keys()), list(backbone.pyramid_outputs.keys()) ) - self.assertEqual(list(output_data.keys()), ["P1", "P2", "P3", "P4"]) + self.assertEqual(list(output_data.keys()), ["P1", "P2"]) for k, v in output_data.items(): - size = self.input_size // (2 ** int(k[1:])) + size = self.input_size // (2 ** (int(k[1:]) + 1)) self.assertEqual(tuple(v.shape[:3]), (2, size, size)) @pytest.mark.large diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py b/keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py index 7d264dda11..57b0671be2 100644 --- a/keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py @@ -26,19 +26,19 @@ class MiTImageClassifierTest(TestCase): def setUp(self): # Setup model. - self.images = np.ones((2, 64, 64, 3), dtype="float32") + self.images = np.ones((2, 16, 16, 3), dtype="float32") self.labels = [0, 3] self.backbone = MiTBackbone( depths=[2, 2, 2, 2], include_rescaling=True, - image_shape=(64, 64, 3), - hidden_dims=[32, 64, 160, 256], - num_layers=4, - blockwise_num_heads=[1, 2, 5, 8], - blockwise_sr_ratios=[8, 4, 2, 1], + image_shape=(16, 16, 3), + hidden_dims=[4, 8], + num_layers=2, + blockwise_num_heads=[1, 2], + blockwise_sr_ratios=[8, 4], end_value=0.1, - patch_sizes=[7, 3, 3, 3], - strides=[4, 2, 2, 2], + patch_sizes=[7, 3], + strides=[4, 2], ) self.init_kwargs = { "backbone": self.backbone,