diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 6f44e0ca08..f389052f8e 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -63,6 +63,7 @@ SegFormerImageConverter, ) from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter +from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter from keras_hub.src.models.whisper.whisper_audio_converter import ( WhisperAudioConverter, ) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index ccb47b99a2..313b5e1090 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -330,6 +330,11 @@ from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import ( VGGImageClassifierPreprocessor, ) +from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier +from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ( + ViTImageClassifierPreprocessor, +) from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer diff --git a/keras_hub/src/models/vit/__init__.py b/keras_hub/src/models/vit/__init__.py new file mode 100644 index 0000000000..e4b42de07d --- /dev/null +++ b/keras_hub/src/models/vit/__init__.py @@ -0,0 +1,5 @@ +from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.models.vit.vit_presets import backbone_presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, ViTBackbone) diff --git a/keras_hub/src/models/vit/vit_backbone.py b/keras_hub/src/models/vit/vit_backbone.py new file mode 100644 index 0000000000..c34ab7d498 --- /dev/null +++ b/keras_hub/src/models/vit/vit_backbone.py @@ -0,0 +1,152 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.vit.vit_layers import ViTEncoder +from keras_hub.src.models.vit.vit_layers import ViTPatchingAndEmbedding +from keras_hub.src.utils.keras_utils import standardize_data_format + + +@keras_hub_export("keras_hub.models.ViTBackbone") +class ViTBackbone(Backbone): + """Vision Transformer (ViT) backbone. + + This backbone implements the Vision Transformer architecture as described in + [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929). + It transforms the input image into a sequence of patches, embeds them, and + then processes them through a series of Transformer encoder layers. + + Args: + image_shape: A tuple or list of 3 integers representing the shape of the + input image `(height, width, channels)`, `height` and `width` must + be equal. + patch_size: int. The size of each image patch, the input image will be + divided into patches of shape `(patch_size, patch_size)`. + num_layers: int. The number of transformer encoder layers. + num_heads: int. specifying the number of attention heads in each + Transformer encoder layer. + hidden_dim: int. The dimensionality of the hidden representations. + mlp_dim: int. The dimensionality of the intermediate MLP layer in + each Transformer encoder layer. + dropout_rate: float. The dropout rate for the Transformer encoder + layers. + attention_dropout: float. 
The dropout rate for the attention mechanism
+            in each Transformer encoder layer.
+        layer_norm_epsilon: float. Value used for numerical stability in
+            layer normalization.
+        use_mha_bias: bool. Whether to use bias in the multi-head
+            attention layers.
+        use_mlp_bias: bool. Whether to use bias in the MLP layers.
+        data_format: str. `"channels_last"` or `"channels_first"`, specifying
+            the data format for the input image. If `None`, defaults to
+            `"channels_last"`.
+        dtype: The dtype of the layer weights. Defaults to `None`.
+        **kwargs: Additional keyword arguments to be passed to the parent
+            `Backbone` class.
+    """
+
+    def __init__(
+        self,
+        image_shape,
+        patch_size,
+        num_layers,
+        num_heads,
+        hidden_dim,
+        mlp_dim,
+        dropout_rate=0.0,
+        attention_dropout=0.0,
+        layer_norm_epsilon=1e-6,
+        use_mha_bias=True,
+        use_mlp_bias=True,
+        data_format=None,
+        dtype=None,
+        **kwargs,
+    ):
+        # === Layers ===
+        data_format = standardize_data_format(data_format)
+        h_axis, w_axis, channels_axis = (
+            (-3, -2, -1) if data_format == "channels_last" else (-2, -1, -3)
+        )
+        # Check that the input image is well specified.
+        if image_shape[h_axis] is None or image_shape[w_axis] is None:
+            raise ValueError(
+                f"Image shape must have defined height and width. Found `None` "
+                f"at index {h_axis} (height) or {w_axis} (width). "
+                f"Image shape: {image_shape}"
+            )
+        if image_shape[h_axis] != image_shape[w_axis]:
+            raise ValueError(
+                f"Image height and width must be equal. Found height: "
+                f"{image_shape[h_axis]}, width: {image_shape[w_axis]} at "
+                f"indices {h_axis} and {w_axis} respectively. Image shape: "
+                f"{image_shape}"
+            )
+
+        num_channels = image_shape[channels_axis]
+
+        # === Functional Model ===
+        inputs = keras.layers.Input(shape=image_shape)
+
+        x = ViTPatchingAndEmbedding(
+            image_size=image_shape[h_axis],
+            patch_size=patch_size,
+            hidden_dim=hidden_dim,
+            num_channels=num_channels,
+            data_format=data_format,
+            dtype=dtype,
+            name="vit_patching_and_embedding",
+        )(inputs)
+
+        output = ViTEncoder(
+            num_layers=num_layers,
+            num_heads=num_heads,
+            hidden_dim=hidden_dim,
+            mlp_dim=mlp_dim,
+            dropout_rate=dropout_rate,
+            attention_dropout=attention_dropout,
+            layer_norm_epsilon=layer_norm_epsilon,
+            use_mha_bias=use_mha_bias,
+            use_mlp_bias=use_mlp_bias,
+            dtype=dtype,
+            name="vit_encoder",
+        )(x)
+
+        super().__init__(
+            inputs=inputs,
+            outputs=output,
+            dtype=dtype,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.image_shape = image_shape
+        self.patch_size = patch_size
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.mlp_dim = mlp_dim
+        self.dropout_rate = dropout_rate
+        self.attention_dropout = attention_dropout
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.use_mha_bias = use_mha_bias
+        self.use_mlp_bias = use_mlp_bias
+        self.data_format = data_format
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "image_shape": self.image_shape,
+                "patch_size": self.patch_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "mlp_dim": self.mlp_dim,
+                "dropout_rate": self.dropout_rate,
+                "attention_dropout": self.attention_dropout,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "use_mha_bias": self.use_mha_bias,
+                "use_mlp_bias": self.use_mlp_bias,
+            }
+        )
+        return config
diff --git a/keras_hub/src/models/vit/vit_backbone_test.py b/keras_hub/src/models/vit/vit_backbone_test.py
new file mode 100644
index 0000000000..0ab0b389ca
--- /dev/null
+++ b/keras_hub/src/models/vit/vit_backbone_test.py
@@ -0,0 +1,37 @@
+import pytest
+from keras import ops
+
+from keras_hub.src.models.vit.vit_backbone import ViTBackbone
+from keras_hub.src.tests.test_case import TestCase
+
+
+class ViTBackboneTest(TestCase):
+    def setUp(self):
+        self.init_kwargs = {
+            "image_shape": (28, 28, 3),
+            "patch_size": 4,
+            "num_layers": 3,
+            "hidden_dim": 48,
+            "num_heads": 6,
+            "mlp_dim": 48 * 4,
+            "use_mha_bias": True,
+        }
+        self.input_size = 28
+        self.input_data = ops.ones((2, self.input_size, self.input_size, 3))
+
+    def test_backbone_basics(self):
+        self.run_backbone_test(
+            cls=ViTBackbone,
+            init_kwargs={**self.init_kwargs},
+            input_data=self.input_data,
+            expected_output_shape=(2, 50, 48),
+            run_quantization_check=False,
+        )
+
+    @pytest.mark.large
+    def test_saved_model(self):
+        self.run_model_saving_test(
+            cls=ViTBackbone,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+        )
diff --git a/keras_hub/src/models/vit/vit_image_classifier.py b/keras_hub/src/models/vit/vit_image_classifier.py
new file mode 100644
index 0000000000..6e8746d6b6
--- /dev/null
+++ b/keras_hub/src/models/vit/vit_image_classifier.py
@@ -0,0 +1,187 @@
+import keras
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.image_classifier import ImageClassifier
+from keras_hub.src.models.task import Task
+from keras_hub.src.models.vit.vit_backbone import ViTBackbone
+from keras_hub.src.models.vit.vit_image_classifier_preprocessor import (
+    ViTImageClassifierPreprocessor,
+)
+
+
+@keras_hub_export("keras_hub.models.ViTImageClassifier")
+class ViTImageClassifier(ImageClassifier):
+    """ViT image classification task.
+
+    `ViTImageClassifier` tasks wrap a `keras_hub.models.ViTBackbone` and
+    a `keras_hub.models.Preprocessor` to create a model that can be used for
+    image classification. `ViTImageClassifier` tasks take an additional
+    `num_classes` argument, controlling the number of predicted output classes.
+
+    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
+    where `x` is an image tensor and `y` is an integer from `[0, num_classes)`.
+
+    Note that unlike `keras_hub.models.ImageClassifier`, `ViTImageClassifier`
+    by default pools the class (`cls`) token, i.e. the first token of the
+    sequence returned by the backbone.
+
+    Args:
+        backbone: A `keras_hub.models.ViTBackbone` instance or a `keras.Model`.
+        num_classes: int. The number of classes to predict.
+        preprocessor: `None`, a `keras_hub.models.Preprocessor` instance,
+            a `keras.Layer` instance, or a callable. If `None` no preprocessing
+            will be applied to the inputs.
+        pooling: String specifying the classification strategy. The choice
+            impacts the dimensionality and nature of the feature vector used
+            for classification.
+            `"token"`: A single vector (class token) representing the
+                overall image features.
+            `"gap"`: A single vector representing the average features
+                across the spatial dimensions.
+        intermediate_dim: Optional dimensionality of the intermediate
+            representation layer before the final classification layer.
+            If `None`, the output of the transformer is directly used.
+            Defaults to `None`.
+        activation: `None`, str, or callable. The activation function to use on
+            the `Dense` layer. Set `activation=None` to return the output
+            logits. Defaults to `None`.
+        dropout: float. The dropout rate applied before the final
+            classification layer. Defaults to `0.0`.
+        head_dtype: `None`, str, or `keras.mixed_precision.DTypePolicy`. The
+            dtype to use for the classification head's computations and
+            weights.
+
+    Examples:
+
+    Call `predict()` to run inference.
+    ```python
+    # Load preset and run inference.
+    images = np.random.randint(0, 256, size=(2, 224, 224, 3))
+    classifier = keras_hub.models.ViTImageClassifier.from_preset(
+        "vit_base_patch16_224_imagenet"
+    )
+    classifier.predict(images)
+    ```
+
+    Call `fit()` on a single batch.
+    ```python
+    # Load preset and train.
+    images = np.random.randint(0, 256, size=(2, 224, 224, 3))
+    labels = [0, 3]
+    classifier = keras_hub.models.ViTImageClassifier.from_preset(
+        "vit_base_patch16_224_imagenet"
+    )
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+
+    Call `fit()` with custom loss, optimizer and backbone.
+    ```python
+    classifier = keras_hub.models.ViTImageClassifier.from_preset(
+        "vit_base_patch16_224_imagenet"
+    )
+    classifier.compile(
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        optimizer=keras.optimizers.Adam(5e-5),
+    )
+    classifier.backbone.trainable = False
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+
+    Custom backbone.
+    ```python
+    images = np.random.randint(0, 256, size=(2, 224, 224, 3))
+    labels = [0, 3]
+    backbone = keras_hub.models.ViTBackbone(
+        image_shape=(224, 224, 3),
+        patch_size=16,
+        num_layers=6,
+        num_heads=3,
+        hidden_dim=768,
+        mlp_dim=2048,
+    )
+    classifier = keras_hub.models.ViTImageClassifier(
+        backbone=backbone,
+        num_classes=4,
+    )
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+    """
+
+    backbone_cls = ViTBackbone
+    preprocessor_cls = ViTImageClassifierPreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        num_classes,
+        preprocessor=None,
+        pooling="token",
+        intermediate_dim=None,
+        activation=None,
+        dropout=0.0,
+        head_dtype=None,
+        **kwargs,
+    ):
+        head_dtype = head_dtype or backbone.dtype_policy
+
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        if intermediate_dim is not None:
+            self.intermediate_layer = keras.layers.Dense(
+                intermediate_dim,
+                activation="tanh",
+                dtype=head_dtype,
+                name="pre_logits",
+            )
+
+        self.output_dropout = keras.layers.Dropout(
+            rate=dropout,
+            dtype=head_dtype,
+            name="output_dropout",
+        )
+        self.output_dense = keras.layers.Dense(
+            num_classes,
+            activation=activation,
+            dtype=head_dtype,
+            name="predictions",
+        )
+
+        # === Functional Model ===
+        inputs = self.backbone.input
+        x = self.backbone(inputs)
+        if pooling == "token":
+            x = x[:, 0]
+        elif pooling == "gap":
+            ndim = len(ops.shape(x))
+            x = ops.mean(x, axis=list(range(1, ndim - 1)))  # (1,) or (1, 2)
+
+        if intermediate_dim is not None:
+            x = self.intermediate_layer(x)
+
+        x = self.output_dropout(x)
+        outputs = self.output_dense(x)
+
+        # Skip the parent class functional model.
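+        # `ImageClassifier.__init__` would build its own pooling and head, so
+        # wire the custom head above through `Task.__init__` directly.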
+ Task.__init__( + self, + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === config === + self.num_classes = num_classes + self.pooling = pooling + self.intermediate_dim = intermediate_dim + self.activation = activation + self.dropout = dropout + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "pooling": self.pooling, + "intermediate_dim": self.intermediate_dim, + "activation": self.activation, + "dropout": self.dropout, + } + ) + return config diff --git a/keras_hub/src/models/vit/vit_image_classifier_preprocessor.py b/keras_hub/src/models/vit/vit_image_classifier_preprocessor.py new file mode 100644 index 0000000000..7e50918eb6 --- /dev/null +++ b/keras_hub/src/models/vit/vit_image_classifier_preprocessor.py @@ -0,0 +1,12 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.image_classifier_preprocessor import ( + ImageClassifierPreprocessor, +) +from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter + + +@keras_hub_export("keras_hub.models.ViTImageClassifierPreprocessor") +class ViTImageClassifierPreprocessor(ImageClassifierPreprocessor): + backbone_cls = ViTBackbone + image_converter_cls = ViTImageConverter diff --git a/keras_hub/src/models/vit/vit_image_classifier_test.py b/keras_hub/src/models/vit/vit_image_classifier_test.py new file mode 100644 index 0000000000..29e3d66922 --- /dev/null +++ b/keras_hub/src/models/vit/vit_image_classifier_test.py @@ -0,0 +1,57 @@ +import numpy as np +import pytest + +from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier +from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ( + ViTImageClassifierPreprocessor, +) +from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter +from keras_hub.src.tests.test_case import TestCase + + +class ViTImageClassifierTest(TestCase): + def setUp(self): + self.images = np.ones((2, 28, 28, 3)) + self.labels = [0, 1] + self.backbone = ViTBackbone( + image_shape=(28, 28, 3), + patch_size=4, + num_layers=3, + num_heads=6, + hidden_dim=48, + mlp_dim=48 * 4, + ) + image_converter = ViTImageConverter( + image_size=(28, 28), + scale=1 / 255.0, + ) + preprocessor = ViTImageClassifierPreprocessor( + image_converter=image_converter + ) + self.init_kwargs = { + "backbone": self.backbone, + "num_classes": 2, + "preprocessor": preprocessor, + } + self.train_data = (self.images, self.labels) + + def test_classifier_basics(self): + self.run_task_test( + cls=ViTImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) + + def test_head_dtype(self): + model = ViTImageClassifier(**self.init_kwargs, head_dtype="bfloat16") + self.assertEqual(model.output_dense.compute_dtype, "bfloat16") + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=ViTImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_hub/src/models/vit/vit_image_converter.py b/keras_hub/src/models/vit/vit_image_converter.py new file mode 100644 index 0000000000..b1699640ce --- /dev/null +++ b/keras_hub/src/models/vit/vit_image_converter.py @@ -0,0 +1,73 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from 
keras_hub.src.models.vit.vit_backbone import ViTBackbone
+from keras_hub.src.utils.tensor_utils import preprocessing_function
+
+
+@keras_hub_export("keras_hub.layers.ViTImageConverter")
+class ViTImageConverter(ImageConverter):
+    """Converts images to the format expected by a ViT model.
+
+    This layer performs image normalization using mean and standard
+    deviation values. By default, it uses the same normalization as the
+    "google/vit-large-patch16-224" model on Hugging Face:
+    `norm_mean=[0.5, 0.5, 0.5]` and `norm_std=[0.5, 0.5, 0.5]`
+    ([reference](https://huggingface.co/google/vit-large-patch16-224/blob/main/preprocessor_config.json)).
+    These defaults are suitable for models pretrained using this normalization.
+
+    Args:
+        norm_mean: list or tuple of floats. Mean values for image
+            normalization. Defaults to `[0.5, 0.5, 0.5]`.
+        norm_std: list or tuple of floats. Standard deviation values for
+            image normalization. Defaults to `[0.5, 0.5, 0.5]`.
+        **kwargs: Additional keyword arguments passed to
+            `keras_hub.layers.ImageConverter`.
+
+    Examples:
+    ```python
+    import numpy as np
+    from keras_hub.layers import ViTImageConverter
+
+    # Example image (replace with your actual image data)
+    image = np.random.rand(1, 224, 224, 3)  # Example: (B, H, W, C)
+
+    # Create a ViTImageConverter instance
+    converter = ViTImageConverter(
+        image_size=(28, 28),
+        scale=1 / 255.0,
+    )
+    # Preprocess the image
+    preprocessed_image = converter(image)
+    ```
+    """
+
+    backbone_cls = ViTBackbone
+
+    def __init__(
+        self, norm_mean=[0.5, 0.5, 0.5], norm_std=[0.5, 0.5, 0.5], **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.norm_mean = norm_mean
+        self.norm_std = norm_std
+
+    @preprocessing_function
+    def call(self, inputs):
+        x = super().call(inputs)
+        # Normalize using the configured mean and std.
+        if self.norm_mean:
+            x = x - self._expand_non_channel_dims(self.norm_mean, x)
+        if self.norm_std:
+            x = x / self._expand_non_channel_dims(self.norm_std, x)
+
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "norm_mean": self.norm_mean,
+                "norm_std": self.norm_std,
+            }
+        )
+        return config
diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py
new file mode 100644
index 0000000000..8cdc52ca71
--- /dev/null
+++ b/keras_hub/src/models/vit/vit_layers.py
@@ -0,0 +1,391 @@
+import keras
+from keras import ops
+
+from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+class MLP(keras.layers.Layer):
+    """Multi-Layer Perceptron (MLP) block.
+
+    Args:
+        hidden_dim: int. Dimensionality of the hidden representations.
+        mlp_dim: int. Dimensionality of the intermediate MLP layer.
+        use_bias: bool. Whether to use bias in the dense layers. Defaults to
+            `True`.
+        dropout_rate: float. Dropout rate. Between 0 and 1. Defaults to `0.0`.
+ **kwargs: Additional keyword arguments passed to `keras.layers.Layer` + """ + + def __init__( + self, + hidden_dim, + mlp_dim, + use_bias=True, + dropout_rate=0.0, + **kwargs, + ): + super().__init__(**kwargs) + + # === Config === + self.hidden_dim = hidden_dim + self.mlp_dim = mlp_dim + self.use_bias = use_bias + self.dropout_rate = dropout_rate + + def build(self, input_shape): + self.dense_1 = keras.layers.Dense( + units=self.mlp_dim, + use_bias=self.use_bias, + activation="gelu", + bias_initializer=( + keras.initializers.RandomNormal(stddev=1e-6) + if self.use_bias + else None + ), + dtype=self.dtype_policy, + name="dense_1", + ) + self.dense_1.build(input_shape) + self.dense_2 = keras.layers.Dense( + units=self.hidden_dim, + use_bias=self.use_bias, + bias_initializer=( + keras.initializers.RandomNormal(stddev=1e-6) + if self.use_bias + else None + ), + dtype=self.dtype_policy, + name="dense_2", + ) + self.dense_2.build((None, None, self.mlp_dim)) + self.dropout = keras.layers.Dropout( + self.dropout_rate, dtype=self.dtype_policy, name="dropout" + ) + self.built = True + + def call(self, inputs): + x = self.dense_1(inputs) + x = self.dense_2(x) + out = self.dropout(x) + return out + + +class ViTPatchingAndEmbedding(keras.layers.Layer): + """Patches the image and embeds the patches. + + Args: + image_size: int. Size of the input image (height or width). + Assumed to be square. + patch_size: int. Size of each image patch. + hidden_dim: int. Dimensionality of the patch embeddings. + num_channels: int. Number of channels in the input image. Defaults to + `3`. + data_format: str. `"channels_last"` or `"channels_first"`. Defaults to + `None` (which uses `"channels_last"`). + **kwargs: Additional keyword arguments passed to `keras.layers.Layer` + """ + + def __init__( + self, + image_size, + patch_size, + hidden_dim, + num_channels=3, + data_format=None, + **kwargs, + ): + super().__init__(**kwargs) + num_patches = (image_size // patch_size) ** 2 + num_positions = num_patches + 1 + + # === Config === + self.image_size = image_size + self.patch_size = patch_size + self.hidden_dim = hidden_dim + self.num_channels = num_channels + self.num_patches = num_patches + self.num_positions = num_positions + self.data_format = standardize_data_format(data_format) + + def build(self, input_shape): + self.class_token = self.add_weight( + shape=( + 1, + 1, + self.hidden_dim, + ), + initializer="random_normal", + dtype=self.variable_dtype, + name="class_token", + ) + self.patch_embedding = keras.layers.Conv2D( + filters=self.hidden_dim, + kernel_size=self.patch_size, + strides=self.patch_size, + padding="valid", + activation=None, + dtype=self.dtype_policy, + data_format=self.data_format, + name="patch_embedding", + ) + self.patch_embedding.build(input_shape) + self.position_embedding = keras.layers.Embedding( + self.num_positions, + self.hidden_dim, + dtype=self.dtype_policy, + embeddings_initializer=keras.initializers.RandomNormal(stddev=0.02), + name="position_embedding", + ) + self.position_embedding.build((1, self.num_positions)) + self.position_ids = keras.ops.expand_dims( + keras.ops.arange(self.num_positions), axis=0 + ) + self.built = True + + def call(self, inputs): + patch_embeddings = self.patch_embedding(inputs) + if self.data_format == "channels_first": + patch_embeddings = ops.transpose( + patch_embeddings, axes=(0, 2, 3, 1) + ) + embeddings_shape = ops.shape(patch_embeddings) + patch_embeddings = ops.reshape( + patch_embeddings, [embeddings_shape[0], -1, embeddings_shape[-1]] + ) + 
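+        # Broadcast the learnable class token across the batch, prepend it to
+        # the patch sequence, and add the learned position embeddings.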
class_token = ops.tile(self.class_token, (embeddings_shape[0], 1, 1)) + position_embeddings = self.position_embedding(self.position_ids) + embeddings = ops.concatenate([class_token, patch_embeddings], axis=1) + return ops.add(embeddings, position_embeddings) + + def compute_output_shape(self, input_shape): + return ( + input_shape[0], + self.num_positions, + self.hidden_dim, + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "image_size": self.image_size, + "patch_size": self.patch_size, + "hidden_dim": self.hidden_dim, + "num_channels": self.num_channels, + "num_patches": self.num_patches, + "num_positions": self.num_positions, + } + ) + return config + + +class ViTEncoderBlock(keras.layers.Layer): + """Transformer encoder block. + + Args: + num_heads: int. Number of attention heads. + hidden_dim: int. Dimensionality of the hidden representations. + mlp_dim: int. Dimensionality of the intermediate MLP layer. + use_mha_bias: bool. Whether to use bias in the multi-head attention + layer. Defaults to `True`. + use_mlp_bias: bool. Whether to use bias in the MLP layer. Defaults to + `True`. + dropout_rate: float. Dropout rate. Between 0 and 1. Defaults to `0.0`. + attention_dropout: float. Dropout rate for the attention mechanism. + Between 0 and 1. Defaults to `0.0`. + layer_norm_epsilon: float. Small float value for layer normalization + stability. Defaults to `1e-6`. + **kwargs: Additional keyword arguments passed to `keras.layers.Layer` + """ + + def __init__( + self, + num_heads, + hidden_dim, + mlp_dim, + use_mha_bias=True, + use_mlp_bias=True, + dropout_rate=0.0, + attention_dropout=0.0, + layer_norm_epsilon=1e-6, + **kwargs, + ): + super().__init__(**kwargs) + + key_dim = hidden_dim // num_heads + + # === Config === + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.key_dim = key_dim + self.mlp_dim = mlp_dim + self.use_mha_bias = use_mha_bias + self.use_mlp_bias = use_mlp_bias + self.dropout_rate = dropout_rate + self.attention_dropout = attention_dropout + self.layer_norm_epsilon = layer_norm_epsilon + + def build(self, input_shape): + # Attention block + self.layer_norm_1 = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, + name="ln_1", + dtype=self.dtype_policy, + ) + self.layer_norm_1.build(input_shape) + self.mha = keras.layers.MultiHeadAttention( + num_heads=self.num_heads, + key_dim=self.key_dim, + use_bias=self.use_mha_bias, + dropout=self.attention_dropout, + name="mha", + dtype=self.dtype_policy, + ) + self.mha.build(input_shape, input_shape) + self.dropout = keras.layers.Dropout( + self.dropout_rate, dtype=self.dtype_policy, name="dropout" + ) + + # MLP block + self.layer_norm_2 = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, + name="ln_2", + dtype=self.dtype_policy, + ) + self.layer_norm_2.build((None, None, self.hidden_dim)) + self.mlp = MLP( + hidden_dim=self.hidden_dim, + mlp_dim=self.mlp_dim, + use_bias=self.use_mlp_bias, + name="mlp", + dtype=self.dtype_policy, + ) + self.mlp.build((None, None, self.hidden_dim)) + self.built = True + + def call(self, inputs): + x = self.layer_norm_1(inputs) + x = self.mha(x, x) + x = self.dropout(x) + x = x + inputs + + y = self.layer_norm_2(x) + y = self.mlp(y) + + return x + y + + def get_config(self): + config = super().get_config() + config.update( + { + "num_heads": self.num_heads, + "hidden_dim": self.hidden_dim, + "key_dim": self.key_dim, + "mlp_dim": self.mlp_dim, + "use_mha_bias": self.use_mha_bias, + "use_mlp_bias": self.use_mlp_bias, + 
"dropout_rate": self.dropout_rate, + "attention_dropout": self.attention_dropout, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config + + +class ViTEncoder(keras.layers.Layer): + """Vision Transformer (ViT) encoder. + + Args: + num_layers: int. Number of Transformer encoder blocks. + num_heads: int. Number of attention heads. + hidden_dim: int. Dimensionality of the hidden representations. + mlp_dim: int. Dimensionality of the intermediate MLP layer. + use_mha_bias: bool. Whether to use bias in the multi-head attention + layers. Defaults to `True`. + use_mlp_bias: bool. Whether to use bias in the MLP layers. Defaults to + `True`. + dropout_rate: float. Dropout rate. Between 0 and 1. Defaults to `0.0`. + attention_dropout: float. Dropout rate for the attention mechanism. + Between 0 and 1. Defaults to `0.0`. + layer_norm_epsilon: float. Small float value for layer normalization + tability. Defaults to `1e-6`. + **kwargs: Additional keyword arguments passed to `keras.layers.Layer` + """ + + def __init__( + self, + num_layers, + num_heads, + hidden_dim, + mlp_dim, + use_mha_bias=True, + use_mlp_bias=True, + dropout_rate=0.0, + attention_dropout=0.0, + layer_norm_epsilon=1e-6, + **kwargs, + ): + super().__init__(**kwargs) + + # === config === + self.num_layers = num_layers + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.mlp_dim = mlp_dim + self.use_mha_bias = use_mha_bias + self.use_mlp_bias = use_mlp_bias + self.dropout_rate = dropout_rate + self.attention_dropout = attention_dropout + self.layer_norm_epsilon = layer_norm_epsilon + + def build(self, input_shape): + self.encoder_layers = [] + for i in range(self.num_layers): + encoder_block = ViTEncoderBlock( + num_heads=self.num_heads, + hidden_dim=self.hidden_dim, + mlp_dim=self.mlp_dim, + dropout_rate=self.dropout_rate, + use_mha_bias=self.use_mha_bias, + use_mlp_bias=self.use_mlp_bias, + attention_dropout=self.attention_dropout, + layer_norm_epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name=f"tranformer_block_{i+1}", + ) + encoder_block.build((None, None, self.hidden_dim)) + self.encoder_layers.append(encoder_block) + self.dropout = keras.layers.Dropout( + self.dropout_rate, dtype=self.dtype_policy, name="dropout" + ) + self.layer_norm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="ln", + ) + self.layer_norm.build((None, None, self.hidden_dim)) + self.built = True + + def call(self, inputs): + x = self.dropout(inputs) + for i in range(self.num_layers): + x = self.encoder_layers[i](x) + x = self.layer_norm(x) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "num_layers": self.num_layers, + "num_heads": self.num_heads, + "hidden_dim": self.hidden_dim, + "mlp_dim": self.mlp_dim, + "use_mha_bias": self.use_mha_bias, + "use_mlp_bias": self.use_mlp_bias, + "dropout_rate": self.dropout_rate, + "attention_dropout": self.attention_dropout, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config diff --git a/keras_hub/src/models/vit/vit_presets.py b/keras_hub/src/models/vit/vit_presets.py new file mode 100644 index 0000000000..16a6f694e4 --- /dev/null +++ b/keras_hub/src/models/vit/vit_presets.py @@ -0,0 +1,49 @@ +"""ViT model preset configurations.""" + +# Metadata for loading pretrained model weights. 
+backbone_presets = { + "vit_base_patch16_224_imagenet": { + "metadata": { + "description": ( + "ViT-B16 model pre-trained on the ImageNet 1k dataset with " + "image resolution of 224x224 " + ), + "params": 85798656, + "path": "vit", + }, + "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_224_imagenet/1", + }, + "vit_base_patch16_384_imagenet": { + "metadata": { + "description": ( + "ViT-B16 model pre-trained on the ImageNet 1k dataset with " + "image resolution of 384x384 " + ), + "params": 86090496, + "path": "vit", + }, + "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_384_imagenet/1", + }, + "vit_large_patch16_224_imagenet": { + "metadata": { + "description": ( + "ViT-L16 model pre-trained on the ImageNet 1k dataset with " + "image resolution of 224x224 " + ), + "params": 303301632, + "path": "vit", + }, + "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_224_imagenet/1", + }, + "vit_large_patch16_384_imagenet": { + "metadata": { + "description": ( + "ViT-L16 model pre-trained on the ImageNet 1k dataset with " + "image resolution of 384x384 " + ), + "params": 303690752, + "path": "vit", + }, + "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_384_imagenet/1", + }, +} diff --git a/keras_hub/src/utils/transformers/convert_vit.py b/keras_hub/src/utils/transformers/convert_vit.py new file mode 100644 index 0000000000..9ce3b3d3ad --- /dev/null +++ b/keras_hub/src/utils/transformers/convert_vit.py @@ -0,0 +1,151 @@ +import numpy as np + +from keras_hub.src.models.vit.vit_backbone import ViTBackbone + +backbone_cls = ViTBackbone + + +def convert_backbone_config(transformers_config): + image_size = transformers_config["image_size"] + return { + "image_shape": (image_size, image_size, 3), + "patch_size": transformers_config["patch_size"], + "num_layers": transformers_config["num_hidden_layers"], + "num_heads": transformers_config["num_attention_heads"], + "hidden_dim": transformers_config["hidden_size"], + "mlp_dim": transformers_config["intermediate_size"], + "dropout_rate": transformers_config["hidden_dropout_prob"], + "attention_dropout": transformers_config[ + "attention_probs_dropout_prob" + ], + "use_mha_bias": transformers_config["qkv_bias"], + } + + +def convert_weights(backbone, loader, transformers_config): + + def port_ln(keras_variable, weight_key): + loader.port_weight(keras_variable.gamma, f"{weight_key}.weight") + loader.port_weight(keras_variable.beta, f"{weight_key}.bias") + + def port_dense(keras_variable, weight_key): + loader.port_weight( + keras_variable.kernel, + f"{weight_key}.weight", + hook_fn=lambda x, _: x.T, + ) + if keras_variable.bias is not None: + loader.port_weight(keras_variable.bias, f"{weight_key}.bias") + + def port_mha(keras_variable, weight_key, num_heads, hidden_dim): + # query + loader.port_weight( + keras_variable.query_dense.kernel, + f"{weight_key}.attention.query.weight", + hook_fn=lambda x, _: np.reshape( + x.T, (hidden_dim, num_heads, hidden_dim // num_heads) + ), + ) + loader.port_weight( + keras_variable.query_dense.bias, + f"{weight_key}.attention.query.bias", + hook_fn=lambda x, _: np.reshape( + x, (num_heads, hidden_dim // num_heads) + ), + ) + # key + loader.port_weight( + keras_variable.key_dense.kernel, + f"{weight_key}.attention.key.weight", + hook_fn=lambda x, _: np.reshape( + x.T, (hidden_dim, num_heads, hidden_dim // num_heads) + ), + ) + loader.port_weight( + keras_variable.key_dense.bias, + f"{weight_key}.attention.key.bias", + hook_fn=lambda x, _: np.reshape( + x, (num_heads, hidden_dim // 
num_heads) + ), + ) + # value + loader.port_weight( + keras_variable.value_dense.kernel, + f"{weight_key}.attention.value.weight", + hook_fn=lambda x, _: np.reshape( + x.T, (hidden_dim, num_heads, hidden_dim // num_heads) + ), + ) + loader.port_weight( + keras_variable.value_dense.bias, + f"{weight_key}.attention.value.bias", + hook_fn=lambda x, _: np.reshape( + x, (num_heads, hidden_dim // num_heads) + ), + ) + # output + loader.port_weight( + keras_variable.output_dense.kernel, + f"{weight_key}.output.dense.weight", + hook_fn=lambda x, _: np.reshape( + x.T, (num_heads, hidden_dim // num_heads, hidden_dim) + ), + ) + loader.port_weight( + keras_variable.output_dense.bias, f"{weight_key}.output.dense.bias" + ) + + loader.port_weight( + keras_variable=backbone.layers[1].patch_embedding.kernel, + hf_weight_key="vit.embeddings.patch_embeddings.projection.weight", + hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), + ) + + loader.port_weight( + backbone.layers[1].patch_embedding.bias, + "vit.embeddings.patch_embeddings.projection.bias", + ) + + loader.port_weight( + backbone.layers[1].class_token, + "vit.embeddings.cls_token", + ) + + loader.port_weight( + backbone.layers[1].position_embedding.embeddings, + "vit.embeddings.position_embeddings", + hook_fn=lambda x, _: x[0], + ) + encoder_layers = backbone.layers[2].encoder_layers + for i, encoder_block in enumerate(encoder_layers): + prefix = "vit.encoder.layer" + num_heads = encoder_block.num_heads + hidden_dim = encoder_block.hidden_dim + + port_mha( + encoder_block.mha, + f"{prefix}.{i}.attention", + num_heads, + hidden_dim, + ) + port_ln(encoder_block.layer_norm_1, f"{prefix}.{i}.layernorm_before") + port_ln(encoder_block.layer_norm_2, f"{prefix}.{i}.layernorm_after") + + port_dense( + encoder_block.mlp.dense_1, f"{prefix}.{i}.intermediate.dense" + ) + port_dense(encoder_block.mlp.dense_2, f"{prefix}.{i}.output.dense") + port_ln(backbone.layers[2].layer_norm, "vit.layernorm") + + +def convert_head(task, loader, transformers_config): + prefix = "classifier." 
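+    # The HF head stores its kernel as (num_classes, hidden_dim); transpose
+    # it to match the (hidden_dim, num_classes) kernel of the Keras `Dense`.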
+ loader.port_weight( + task.output_dense.kernel, + hf_weight_key=prefix + "weight", + hook_fn=lambda x, _: x.T, + ) + loader.port_weight( + task.output_dense.bias, + hf_weight_key=prefix + "bias", + ) diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py index b285a3c090..a3c46f4cf8 100644 --- a/keras_hub/src/utils/transformers/preset_loader.py +++ b/keras_hub/src/utils/transformers/preset_loader.py @@ -1,5 +1,6 @@ """Convert huggingface models to KerasHub.""" +from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.utils.preset_utils import PresetLoader from keras_hub.src.utils.preset_utils import jax_memory_cleanup from keras_hub.src.utils.transformers import convert_albert @@ -11,6 +12,7 @@ from keras_hub.src.utils.transformers import convert_llama3 from keras_hub.src.utils.transformers import convert_mistral from keras_hub.src.utils.transformers import convert_pali_gemma +from keras_hub.src.utils.transformers import convert_vit from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader @@ -37,6 +39,8 @@ def __init__(self, preset, config): self.converter = convert_mistral elif model_type == "paligemma": self.converter = convert_pali_gemma + elif model_type == "vit": + self.converter = convert_vit else: raise ValueError( "KerasHub has no converter for huggingface/transformers models " @@ -55,6 +59,25 @@ def load_backbone(self, cls, load_weights, **kwargs): self.converter.convert_weights(backbone, loader, self.config) return backbone + def load_task(self, cls, load_weights, load_task_weights, **kwargs): + architecture = self.config["architectures"][0] + if ( + not load_task_weights + or not issubclass(cls, ImageClassifier) + or architecture == "ViTModel" + ): + return super().load_task( + cls, load_weights, load_task_weights, **kwargs + ) + # Support loading the classification head for classifier models. + if architecture == "ViTForImageClassification": + kwargs["num_classes"] = len(self.config["id2label"]) + task = super().load_task(cls, load_weights, load_task_weights, **kwargs) + if load_task_weights: + with SafetensorLoader(self.preset, prefix="") as loader: + self.converter.convert_head(task, loader, self.config) + return task + def load_tokenizer(self, cls, config_name="tokenizer.json", **kwargs): return self.converter.convert_tokenizer(cls, self.preset, **kwargs) diff --git a/tools/checkpoint_conversion/convert_vit_checkpoints.py b/tools/checkpoint_conversion/convert_vit_checkpoints.py new file mode 100644 index 0000000000..0777535229 --- /dev/null +++ b/tools/checkpoint_conversion/convert_vit_checkpoints.py @@ -0,0 +1,370 @@ +"""Convert ViT checkpoints. 
+ +export KAGGLE_USERNAME=XXX +export KAGGLE_KEY=XXX + +python tools/checkpoint_conversion/convert_vit_checkpoints.py \ + --preset vit_base_patch16_224 +""" + +import os +import shutil + +import keras +import numpy as np +import torch +from absl import app +from absl import flags +from PIL import Image +from transformers import ViTForImageClassification +from transformers import ViTImageProcessor + +import keras_hub +from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier +from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ( + ViTImageClassifierPreprocessor, +) +from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter + +FLAGS = flags.FLAGS + +PRESET_MAP = { + "vit_base_patch16_224": "google/vit-base-patch16-224", + "vit_base_patch16_384": "google/vit-base-patch16-384", + "vit_base_patch32_384": "google/vit-base-patch32-384", + "vit_large_patch16_224": "google/vit-large-patch16-224", + "vit_large_patch16_384": "google/vit-large-patch16-384", + "vit_large_patch32_384": "google/vit-large-patch32-384", + "vit_base_patch16_224_in21k": "google/vit-base-patch16-224-in21k", + "vit_base_patch32_224_in21k": "google/vit-base-patch32-224-in21k", + "vit_large_patch16_224_in21k": "google/vit-large-patch16-224-in21k", + "vit_large_patch32_224_in21k": "google/vit-large-patch32-224-in21k", + "vit_huge_patch14_224_in21k": "google/vit-huge-patch14-224-in21k", +} + +flags.DEFINE_string( + "preset", + None, + f'Must be one of {",".join(PRESET_MAP.keys())}', + required=True, +) +flags.DEFINE_string( + "upload_uri", + None, + 'Could be "kaggle://keras/{variant}/keras/{preset}"', + required=False, +) + + +def convert_model(hf_model): + config = hf_model.config.to_dict() + image_size = config["image_size"] + backbone = ViTBackbone( + image_shape=(image_size, image_size, 3), + patch_size=config["patch_size"], + num_layers=config["num_hidden_layers"], + num_heads=config["num_attention_heads"], + hidden_dim=config["hidden_size"], + mlp_dim=config["intermediate_size"], + dropout_rate=config["hidden_dropout_prob"], + attention_dropout=config["attention_probs_dropout_prob"], + use_mha_bias=config["qkv_bias"], + ) + + return backbone, config + + +def convert_backbone_weights(backbone, hf_model): + state_dict = hf_model.state_dict() + state_dict.update(hf_model.named_buffers()) + + # Helper functions. 
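+    # Copies a tensor from the torch state dict into a Keras variable,
+    # applying an optional reshape/transpose hook first.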
+ def port_weights(keras_variable, weight_key, hook_fn=None): + torch_tensor = state_dict[weight_key].cpu().numpy() + if hook_fn: + torch_tensor = hook_fn(torch_tensor, list(keras_variable.shape)) + keras_variable.assign(torch_tensor) + + def port_ln(keras_variable, weight_key): + port_weights(keras_variable.gamma, f"{weight_key}.weight") + port_weights(keras_variable.beta, f"{weight_key}.bias") + + def port_dense(keras_variable, weight_key): + port_weights( + keras_variable.kernel, + f"{weight_key}.weight", + hook_fn=lambda x, _: x.T, + ) + if keras_variable.bias is not None: + port_weights(keras_variable.bias, f"{weight_key}.bias") + + def port_mha(keras_variable, weight_key, num_heads, hidden_dim): + # query + port_weights( + keras_variable.query_dense.kernel, + f"{weight_key}.attention.query.weight", + hook_fn=lambda x, _: np.reshape( + x.T, (hidden_dim, num_heads, hidden_dim // num_heads) + ), + ) + port_weights( + keras_variable.query_dense.bias, + f"{weight_key}.attention.query.bias", + hook_fn=lambda x, _: np.reshape( + x, (num_heads, hidden_dim // num_heads) + ), + ) + # key + port_weights( + keras_variable.key_dense.kernel, + f"{weight_key}.attention.key.weight", + hook_fn=lambda x, _: np.reshape( + x.T, (hidden_dim, num_heads, hidden_dim // num_heads) + ), + ) + port_weights( + keras_variable.key_dense.bias, + f"{weight_key}.attention.key.bias", + hook_fn=lambda x, _: np.reshape( + x, (num_heads, hidden_dim // num_heads) + ), + ) + # value + port_weights( + keras_variable.value_dense.kernel, + f"{weight_key}.attention.value.weight", + hook_fn=lambda x, _: np.reshape( + x.T, (hidden_dim, num_heads, hidden_dim // num_heads) + ), + ) + port_weights( + keras_variable.value_dense.bias, + f"{weight_key}.attention.value.bias", + hook_fn=lambda x, _: np.reshape( + x, (num_heads, hidden_dim // num_heads) + ), + ) + # output + port_weights( + keras_variable.output_dense.kernel, + f"{weight_key}.output.dense.weight", + hook_fn=lambda x, _: np.reshape( + x.T, (num_heads, hidden_dim // num_heads, hidden_dim) + ), + ) + port_weights( + keras_variable.output_dense.bias, f"{weight_key}.output.dense.bias" + ) + + port_weights( + backbone.layers[1].patch_embedding.kernel, + "vit.embeddings.patch_embeddings.projection.weight", + hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), + ) + + port_weights( + backbone.layers[1].patch_embedding.bias, + "vit.embeddings.patch_embeddings.projection.bias", + ) + + port_weights( + backbone.layers[1].class_token, + "vit.embeddings.cls_token", + ) + + port_weights( + backbone.layers[1].position_embedding.embeddings, + "vit.embeddings.position_embeddings", + hook_fn=lambda x, _: x[0], + ) + encoder_layers = backbone.layers[2].encoder_layers + for i, encoder_block in enumerate(encoder_layers): + prefix = "vit.encoder.layer" + num_heads = encoder_block.num_heads + hidden_dim = encoder_block.hidden_dim + + port_mha( + encoder_block.mha, + f"{prefix}.{i}.attention", + num_heads, + hidden_dim, + ) + port_ln(encoder_block.layer_norm_1, f"{prefix}.{i}.layernorm_before") + port_ln(encoder_block.layer_norm_2, f"{prefix}.{i}.layernorm_after") + + port_dense( + encoder_block.mlp.dense_1, f"{prefix}.{i}.intermediate.dense" + ) + port_dense(encoder_block.mlp.dense_2, f"{prefix}.{i}.output.dense") + + port_ln(backbone.layers[2].layer_norm, "vit.layernorm") + # port_dense(keras_hub_model.output_dense, "classifier") + + +def convert_head_weights(keras_model, hf_model): + state_dict = hf_model.state_dict() + state_dict.update(hf_model.named_buffers()) + + def 
port_weights(keras_variable, weight_key, hook_fn=None): + torch_tensor = state_dict[weight_key].cpu().numpy() + if hook_fn: + torch_tensor = hook_fn(torch_tensor, list(keras_variable.shape)) + keras_variable.assign(torch_tensor) + + prefix = "classifier." + + port_weights( + keras_model.output_dense.kernel, + prefix + "weight", + hook_fn=lambda x, _: x.T, + ) + port_weights( + keras_model.output_dense.bias, + prefix + "bias", + ) + + +def convert_image_converter(hf_image_processor): + config = hf_image_processor.to_dict() + image_size = (config["size"]["height"], config["size"]["width"]) + std = config["image_std"] + mean = config["image_mean"] + return ViTImageConverter( + image_size=image_size, + scale=config["rescale_factor"], + norm_mean=mean, + norm_std=std, + interpolation="bilinear", # ViT defaults to bilinear resampling. + ) + + +def validate_output( + keras_model, + keras_image_converter, + hf_model, + hf_image_processor, + head_weights=False, +): + file = keras.utils.get_file( + origin=("http://images.cocodataset.org/val2017/000000039769.jpg") + ) + image = Image.open(file) + + # Preprocess with hf. + hf_inputs = hf_image_processor( + image, + return_tensors="pt", + ) + hf_preprocessed = hf_inputs["pixel_values"].detach().cpu().numpy() + + # Preprocess with keras. + images = np.expand_dims(np.array(image).astype("float32"), axis=0) + images = np.concatenate([images, images], axis=0) + images = keras_image_converter(images) + keras_preprocessed = keras.ops.convert_to_numpy(images) + + # Call with hf. Use the keras preprocessed image so we can keep modeling + # and preprocessing comparisons independent. + hf_inputs["pixel_values"] = torch.from_numpy( + keras.ops.convert_to_numpy( + keras.ops.transpose(keras_preprocessed, (0, 3, 1, 2)) + ) + ) + hf_outputs = hf_model(**hf_inputs) + if head_weights: + hf_vision_logits = hf_outputs.logits.detach().cpu().numpy() + + else: + hf_vision_logits = hf_outputs.last_hidden_state.detach().cpu().numpy() + + # Call with keras. + keras_outputs = keras_model(keras_preprocessed) + keras_vision_logits = keras.ops.convert_to_numpy(keras_outputs) + + print("🔶 Keras output:", keras_vision_logits[0, :10]) + print("🔶 HF output:", hf_vision_logits[0, :10]) + if head_weights: + print( + "🔶 HF top 5 ImageNet outputs:", + keras_hub.utils.decode_imagenet_predictions(hf_vision_logits), + ) + print( + "🔶 Keras top 5 ImageNet outputs:", + keras_hub.utils.decode_imagenet_predictions(keras_outputs), + ) + modeling_diff = np.mean(np.abs(keras_vision_logits - hf_vision_logits)) + print("🔶 Modeling difference:", modeling_diff) + preprocessing_diff = np.mean( + np.abs(keras_preprocessed - np.transpose(hf_preprocessed, (0, 2, 3, 1))) + ) + print("🔶 Preprocessing difference:", preprocessing_diff) + + +def main(_): + if FLAGS.preset not in PRESET_MAP.keys(): + raise ValueError( + f"Invalid preset {FLAGS.preset}. Must be one " + f"of {','.join(PRESET_MAP.keys())}" + ) + preset = FLAGS.preset + hf_preset = PRESET_MAP[preset] + if os.path.exists(preset): + shutil.rmtree(preset) + os.makedirs(preset) + + print(f"🏃 Coverting {preset}") + + # Load huggingface model. 
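+    # `eval()` puts the HF model in inference mode so dropout is disabled
+    # when its outputs are compared against the Keras model below.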
+ hf_model = ViTForImageClassification.from_pretrained(hf_preset) + hf_preprocessor = ViTImageProcessor.from_pretrained(hf_preset) + hf_model.eval() + + keras_backbone, hf_config = convert_model(hf_model) + keras_image_converter = convert_image_converter(hf_preprocessor) + keras_image_preprocessor = ViTImageClassifierPreprocessor( + image_converter=keras_image_converter + ) + print("✅ KerasHub model loaded.") + + convert_backbone_weights(keras_backbone, hf_model) + print("✅ Backbone weights converted.") + + if hf_config["architectures"][0] == "ViTForImageClassification": + keras_model = ViTImageClassifier( + backbone=keras_backbone, num_classes=len(hf_config["id2label"]) + ) + convert_head_weights(keras_model, hf_model) + print("✅ Head weights converted.") + validate_output( + keras_model, + keras_image_converter, + hf_model, + hf_preprocessor, + head_weights=True, + ) + print("✅ Output validated.") + keras_model.preprocessor = keras_image_preprocessor + keras_model.save_to_preset(f"./{preset}") + else: + hf_model = hf_model.vit + validate_output( + keras_backbone, + keras_image_converter, + hf_model, + hf_preprocessor, + ) + print("✅ Output validated.") + keras_backbone.save_to_preset(f"./{preset}") + keras_image_preprocessor.save_to_preset(f"./{preset}") + + print(f"🏁 Preset saved to ./{preset}.") + + upload_uri = FLAGS.upload_uri + if upload_uri: + keras_hub.upload_preset(uri=upload_uri, preset=f"./{preset}") + print(f"🏁 Preset uploaded to {upload_uri}") + + +if __name__ == "__main__": + app.run(main)
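Once a preset has been converted and uploaded, it can be loaded back through the API added in this change. A minimal usage sketch (assuming the `vit_base_patch16_224_imagenet` preset registered in `vit_presets.py` is available):

```python
import numpy as np

import keras_hub

# Loads the backbone, classification head, and preprocessor saved by the
# conversion script above.
classifier = keras_hub.models.ViTImageClassifier.from_preset(
    "vit_base_patch16_224_imagenet"
)
images = np.random.randint(0, 256, size=(2, 224, 224, 3))
# The attached ViTImageClassifierPreprocessor handles resizing, rescaling,
# and the 0.5/0.5 normalization before the backbone is called.
scores = classifier.predict(images)  # one score per ImageNet-1k class
```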