diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py
index ab04d8eae0..30736594d0 100644
--- a/keras_nlp/models/__init__.py
+++ b/keras_nlp/models/__init__.py
@@ -35,6 +35,7 @@
 )
 from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor
 from keras_nlp.models.bert.bert_tokenizer import BertTokenizer
+from keras_nlp.models.bloom.bloom_backbone import BloomBackbone
 from keras_nlp.models.deberta_v3.deberta_v3_backbone import DebertaV3Backbone
 from keras_nlp.models.deberta_v3.deberta_v3_classifier import (
     DebertaV3Classifier,
 )
diff --git a/keras_nlp/models/bloom/__init__.py b/keras_nlp/models/bloom/__init__.py
new file mode 100644
index 0000000000..ba0c2545e4
--- /dev/null
+++ b/keras_nlp/models/bloom/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/keras_nlp/models/bloom/bloom_attention.py b/keras_nlp/models/bloom/bloom_attention.py
new file mode 100644
index 0000000000..7af2e7a34d
--- /dev/null
+++ b/keras_nlp/models/bloom/bloom_attention.py
@@ -0,0 +1,212 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+from keras_nlp.backend import keras
+from keras_nlp.backend import ops
+from keras_nlp.utils.keras_utils import clone_initializer
+
+
+class BloomAttention(keras.layers.Layer):
+    def __init__(
+        self,
+        num_heads,
+        dropout=0.0,
+        kernel_initializer="glorot_uniform",
+        bias_initializer="zeros",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.kernel_initializer = keras.initializers.get(kernel_initializer)
+        self.bias_initializer = keras.initializers.get(bias_initializer)
+
+    def build(self, inputs_shape):
+        batch_size, seq_length, hidden_dim = inputs_shape
+
+        self.head_dim = hidden_dim // self.num_heads
+
+        # Layer-wise attention scaling.
+        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
+
+        self._query_dense = keras.layers.EinsumDense(
+            equation="btm,mnh->btnh",
+            output_shape=(None, self.num_heads, self.head_dim),
+            bias_axes="nh",
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
+            name="query_dense",
+        )
+        self._query_dense.build(inputs_shape)
+
+        self._key_dense = keras.layers.EinsumDense(
+            equation="bsm,mnh->bsnh",
+            output_shape=(None, self.num_heads, self.head_dim),
+            bias_axes="nh",
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
+            name="key_dense",
+        )
+        self._key_dense.build(inputs_shape)
+
+        self._value_dense = keras.layers.EinsumDense(
+            equation="bsm,mnh->bsnh",
+            output_shape=(None, self.num_heads, self.head_dim),
+            bias_axes="nh",
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
+            name="value_dense",
+        )
+        self._value_dense.build(inputs_shape)
+
+        self._output_dense = keras.layers.Dense(
+            hidden_dim,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
+            name="output_dense",
+        )
+        self._output_dense.build(inputs_shape)
+
+        self._dropout_layer = keras.layers.Dropout(
+            rate=self.dropout, dtype=self.dtype_policy, name="dropout"
+        )
+        self._softmax = keras.layers.Softmax(
+            dtype=self.dtype_policy, name="softmax"
+        )
+
+        self.built = True
+
+    @staticmethod
+    def _build_alibi_tensor(num_heads, seq_length, alibi_bias_max=8):
+        # This function is adapted from fairseq:
+        # https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
+        # The ALiBi slopes form a geometric sequence, e.g. for `num_heads=4`
+        # and `alibi_bias_max=8` they are [1/4, 1/16, 1/64, 1/256].
+        def get_slopes(n):
+            def get_slopes_power_of_2(n):
+                start = 2 ** (
+                    -(2 ** -(math.log2(n) - math.log2(alibi_bias_max)))
+                )
+                ratio = start
+                return [start * ratio**i for i in range(n)]
+
+            if math.log2(n).is_integer():
+                return get_slopes_power_of_2(n)
+            else:
+                closest_power_of_2 = 2 ** math.floor(math.log2(n))
+                return (
+                    get_slopes_power_of_2(closest_power_of_2)
+                    + get_slopes(2 * closest_power_of_2)[0::2][
+                        : n - closest_power_of_2
+                    ]
+                )
+
+        slopes = ops.convert_to_tensor(get_slopes(num_heads), dtype=float)
+        slopes = ops.expand_dims(slopes, 1)
+
+        alibi = slopes * ops.expand_dims(ops.arange(seq_length, dtype=float), 0)
+        alibi = ops.expand_dims(alibi, 1)
+        alibi = ops.expand_dims(alibi, 0)
+
+        return alibi
+
+    def call(
+        self,
+        hidden_states,
+        attention_mask=None,
+        cache=None,
+        cache_update_index=None,
+    ):
+        batch_size, seq_length, hidden_dim = ops.shape(hidden_states)
+
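+        # Attention flow: project the inputs to per-head query/key/value,
+        # add the ALiBi biases to the scaled attention scores, then merge
+        # the heads back to `hidden_dim`.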
+        query = self._query_dense(hidden_states)
+        key = self._key_dense(hidden_states)
+        value = self._value_dense(hidden_states)
+
+        if cache is not None:
+            key_cache = cache[:, 0, ...]
+            value_cache = cache[:, 1, ...]
+            if cache_update_index is None:
+                key = key_cache
+                value = value_cache
+            else:
+                start = [0, cache_update_index, 0, 0]
+                key = ops.slice_update(key_cache, start, key)
+                value = ops.slice_update(value_cache, start, value)
+                cache = ops.stack((key, value), axis=1)
+        else:
+            if cache_update_index is not None:
+                raise ValueError(
+                    "`cache_update_index` should not be set if `cache` is "
+                    f"`None`. Received: cache={cache}, "
+                    f"cache_update_index={cache_update_index}"
+                )
+
+        # query (batch_size, num_heads, query_length, head_dim)
+        query = ops.transpose(query, [0, 2, 1, 3])
+        # value (batch_size, num_heads, kv_length, head_dim)
+        value = ops.transpose(value, [0, 2, 1, 3])
+        # key (batch_size, num_heads, head_dim, kv_length)
+        key = ops.transpose(key, [0, 2, 3, 1])
+
+        alibi = self._build_alibi_tensor(
+            num_heads=self.num_heads, seq_length=seq_length
+        )
+
+        scores = (
+            ops.matmul(query, key) * self.inv_norm_factor + alibi
+        )  # [batch_size, num_heads, query_length, kv_length]
+
+        scores = self._softmax(scores, ops.expand_dims(attention_mask, 1))
+
+        scores = self._dropout_layer(scores)
+
+        attention_output = ops.matmul(
+            scores, value
+        )  # [batch_size, num_heads, query_length, head_dim]
+
+        attention_output = ops.transpose(
+            attention_output, [0, 2, 1, 3]
+        )  # [batch_size, query_length, num_heads, head_dim]
+        attention_output = ops.reshape(
+            attention_output,
+            [batch_size, seq_length, self.num_heads * self.head_dim],
+        )  # [batch_size, query_length, hidden_dim]
+
+        attention_output = self._output_dense(attention_output)
+        attention_output = self._dropout_layer(attention_output)
+
+        if cache is not None:
+            return attention_output, cache
+
+        return attention_output
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_heads": self.num_heads,
+                "dropout": self.dropout,
+                "kernel_initializer": keras.initializers.serialize(
+                    self.kernel_initializer
+                ),
+                "bias_initializer": keras.initializers.serialize(
+                    self.bias_initializer
+                ),
+            }
+        )
+        return config
diff --git a/keras_nlp/models/bloom/bloom_backbone.py b/keras_nlp/models/bloom/bloom_backbone.py
new file mode 100644
index 0000000000..e3d66998bc
--- /dev/null
+++ b/keras_nlp/models/bloom/bloom_backbone.py
@@ -0,0 +1,153 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.backend import keras
+from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding
+from keras_nlp.models.backbone import Backbone
+from keras_nlp.models.bloom.bloom_decoder import BloomDecoder
+
+
+def _bloom_kernel_initializer(stddev=0.02):
+    return keras.initializers.RandomNormal(stddev=stddev)
+
+
+@keras_nlp_export("keras_nlp.models.BloomBackbone")
+class BloomBackbone(Backbone):
+    """A BLOOM decoder network.
+
+    This network implements a Transformer-based decoder network, BigScience
+    Large Open-science Open-access Multilingual (BLOOM), as described in
+    ["BLOOM: A 176B-Parameter Open-Access Multilingual Language Model"](https://arxiv.org/pdf/2211.05100.pdf).
+
+    The default constructor gives a fully customizable, randomly initialized
+    BLOOM model with any number of layers, heads, and embedding dimensions. To
+    load preset architectures and weights, use the `from_preset()` constructor.
+
+    Disclaimer: Pre-trained models are provided on an "as is" basis, without
+    warranties or conditions of any kind.
+
+    Args:
+        vocabulary_size: int. The size of the token vocabulary.
+        num_layers: int. The number of transformer layers.
+        num_heads: int. The number of attention heads for each transformer.
+            The hidden size must be divisible by the number of attention heads.
+        hidden_dim: int. The dimensionality of the embeddings and hidden states.
+        intermediate_dim: int. The output dimension of the first Dense layer in
+            the MLP network of each transformer.
+        dropout: float. Dropout probability for the Transformer decoder.
+        layer_norm_epsilon: float. Epsilon for the layer normalization layers in
+            the transformer decoder.
+        max_sequence_length: int. The maximum sequence length that this decoder
+            can consume.
+
+    Examples:
+    ```python
+    input_data = {
+        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
+    }
+
+    # Randomly initialized BLOOM decoder with a custom config.
+    model = keras_nlp.models.BloomBackbone(
+        vocabulary_size=10,
+        num_layers=2,
+        num_heads=2,
+        hidden_dim=32,
+        intermediate_dim=32*4,
+        dropout=0.0,
+        layer_norm_epsilon=1e-5,
+        max_sequence_length=128,
+    )
+    model(input_data)
+    ```
+
+    """
+
+    def __init__(
+        self,
+        vocabulary_size,
+        num_layers,
+        num_heads,
+        hidden_dim,
+        intermediate_dim,
+        dropout=0.0,
+        layer_norm_epsilon=1e-5,
+        max_sequence_length=512,
+        **kwargs,
+    ):
+        token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
+        padding_mask = keras.Input(
+            shape=(None,), dtype="int32", name="padding_mask"
+        )
+
+        # Embed tokens.
+        token_embedding_layer = ReversibleEmbedding(
+            input_dim=vocabulary_size,
+            output_dim=hidden_dim,
+            embeddings_initializer=_bloom_kernel_initializer(stddev=0.02),
+            tie_weights=False,
+            name="token_embedding",
+        )
+        token_embedding = token_embedding_layer(token_ids)
+
+        x = keras.layers.LayerNormalization(
+            epsilon=layer_norm_epsilon, name="token_embedding_layernorm"
+        )(token_embedding)
+
+        for i in range(num_layers):
+            x = BloomDecoder(
+                num_heads=num_heads,
+                intermediate_dim=intermediate_dim,
+                dropout=dropout,
+                layer_norm_epsilon=layer_norm_epsilon,
+                name=f"transformer_layer_{i}",
+            )(x, decoder_padding_mask=padding_mask)
+
+        sequence_output = keras.layers.LayerNormalization(
+            epsilon=layer_norm_epsilon, name="final_layernorm"
+        )(x)
+
+        super().__init__(
+            inputs={
+                "token_ids": token_ids,
+                "padding_mask": padding_mask,
+            },
+            outputs=sequence_output,
+            **kwargs,
+        )
+        self.vocabulary_size = vocabulary_size
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.dropout = dropout
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.max_sequence_length = max_sequence_length
+        self.token_embedding = token_embedding_layer
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "max_sequence_length": self.max_sequence_length,
+            }
+        )
+        return config
diff --git a/keras_nlp/models/bloom/bloom_backbone_test.py b/keras_nlp/models/bloom/bloom_backbone_test.py
new file mode 100644
index 0000000000..99cb9d357e
--- /dev/null
+++ b/keras_nlp/models/bloom/bloom_backbone_test.py
@@ -0,0 +1,51 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from keras_nlp.backend import ops
+from keras_nlp.models.bloom.bloom_backbone import BloomBackbone
+from keras_nlp.tests.test_case import TestCase
+
+
+class BloomTest(TestCase):
+    def setUp(self):
+        self.init_kwargs = {
+            "vocabulary_size": 10,
+            "num_layers": 2,
+            "num_heads": 4,
+            "hidden_dim": 8,
+            "intermediate_dim": 32,
+            "max_sequence_length": 10,
+        }
+        self.input_data = {
+            "token_ids": ops.ones((2, 5), dtype="int32"),
+            "padding_mask": ops.ones((2, 5), dtype="int32"),
+        }
+
+    def test_backbone_basics(self):
+        self.run_backbone_test(
+            cls=BloomBackbone,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+            expected_output_shape=(2, 5, 8),
+        )
+
+    @pytest.mark.large
+    def test_saved_model(self):
+        self.run_model_saving_test(
+            cls=BloomBackbone,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+        )
diff --git a/keras_nlp/models/bloom/bloom_decoder.py b/keras_nlp/models/bloom/bloom_decoder.py
new file mode 100644
index 0000000000..b3f8b80da7
--- /dev/null
+++ b/keras_nlp/models/bloom/bloom_decoder.py
@@ -0,0 +1,204 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from keras_nlp.backend import keras
+from keras_nlp.backend import ops
+from keras_nlp.layers.modeling.transformer_layer_utils import (
+    compute_causal_mask,
+)
+from keras_nlp.layers.modeling.transformer_layer_utils import (
+    merge_padding_and_attention_mask,
+)
+from keras_nlp.models.bloom.bloom_attention import BloomAttention
+from keras_nlp.utils.keras_utils import clone_initializer
+
+
+class BloomDecoder(keras.layers.Layer):
+    def __init__(
+        self,
+        num_heads,
+        intermediate_dim,
+        dropout=0.0,
+        layer_norm_epsilon=1e-5,
+        kernel_initializer="glorot_uniform",
+        bias_initializer="zeros",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.num_heads = num_heads
+        self.intermediate_dim = intermediate_dim
+        self.dropout = dropout
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.kernel_initializer = keras.initializers.get(kernel_initializer)
+        self.bias_initializer = keras.initializers.get(bias_initializer)
+
+    def build(self, decoder_sequence_shape):
+        hidden_dim = decoder_sequence_shape[-1]
+        head_dim = int(hidden_dim // self.num_heads)
+
+        if head_dim * self.num_heads != hidden_dim:
+            raise ValueError(
+                f"`hidden_dim` must be divisible by num_heads (got `hidden_dim`"
+                f": {hidden_dim} and `num_heads`: {self.num_heads})."
+            )
+
+        self._pre_attention_layernorm = keras.layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="pre_attention_layernorm",
+        )
+        self._pre_attention_layernorm.build(decoder_sequence_shape)
+
+        self._self_attention_layer = BloomAttention(
+            num_heads=self.num_heads,
+            dropout=self.dropout,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
+            name="self_attention",
+        )
+        self._self_attention_layer.build(decoder_sequence_shape)
+
+        self._post_attention_layernorm = keras.layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="post_attention_layernorm",
+        )
+        self._post_attention_layernorm.build(decoder_sequence_shape)
+
+        self._mlp_intermediate_dense = keras.layers.Dense(
+            self.intermediate_dim,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
+            name="mlp_intermediate_dense",
+        )
+        self._mlp_intermediate_dense.build(decoder_sequence_shape)
+
+        self._mlp_output_dense = keras.layers.Dense(
+            hidden_dim,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
+            name="mlp_output_dense",
+        )
+        intermediate_shape = list(decoder_sequence_shape)
+        intermediate_shape[-1] = self.intermediate_dim
+        self._mlp_output_dense.build(tuple(intermediate_shape))
+
+        self._dropout_layer = keras.layers.Dropout(
+            rate=self.dropout, dtype=self.dtype_policy, name="dropout"
+        )
+
+        self.built = True
+
+    def call(
+        self,
+        decoder_sequence,
+        decoder_padding_mask=None,
+        decoder_attention_mask=None,
+        attention_cache=None,
+        attention_cache_update_index=None,
+        use_causal_mask=True,
+    ):
+        self_attention_mask = self._compute_attention_mask(
+            decoder_sequence=decoder_sequence,
+            decoder_padding_mask=decoder_padding_mask,
+            decoder_attention_mask=decoder_attention_mask,
+            use_causal_mask=use_causal_mask,
+            attention_cache=attention_cache,
+            attention_cache_update_index=attention_cache_update_index,
+        )
+
+        residual = decoder_sequence
+        x = self._pre_attention_layernorm(decoder_sequence)
+
+        attention_output = self._self_attention_layer(
+            hidden_states=x,
+            attention_mask=self_attention_mask,
+            cache=attention_cache,
+            cache_update_index=attention_cache_update_index,
+        )
+
+        if attention_cache is None:
+            x = attention_output
+        else:
+            x, attention_cache = attention_output
+
+        x = x + residual
+        residual = x
+        x = self._post_attention_layernorm(x)
+        x = self._mlp_intermediate_dense(x)
+        x = keras.activations.gelu(x, approximate=True)
+        x = self._mlp_output_dense(x)
+        x = self._dropout_layer(x)
+        x = x + residual
+
+        if attention_cache is not None:
+            return x, attention_cache
+        else:
+            return x
+
+    def _compute_attention_mask(
+        self,
+        decoder_sequence,
+        decoder_padding_mask,
+        decoder_attention_mask,
+        use_causal_mask,
+        attention_cache,
+        attention_cache_update_index,
+    ):
+        decoder_mask = merge_padding_and_attention_mask(
+            decoder_sequence, decoder_padding_mask, decoder_attention_mask
+        )
+        if use_causal_mask:
+            batch_size = ops.shape(decoder_sequence)[0]
+            input_length = output_length = ops.shape(decoder_sequence)[1]
+            if attention_cache is not None:
+                input_length = ops.shape(attention_cache)[2]
+
+            causal_mask = compute_causal_mask(
+                batch_size,
+                input_length,
+                output_length,
+                0
+                if attention_cache_update_index is None
+                else attention_cache_update_index,
+            )
+            return (
+                ops.minimum(decoder_mask, causal_mask)
+                if decoder_mask is not None
+                else causal_mask
+            )
+        return decoder_mask
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_heads": self.num_heads,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "kernel_initializer": keras.initializers.serialize(
+                    self.kernel_initializer
+                ),
+                "bias_initializer": keras.initializers.serialize(
+                    self.bias_initializer
+                ),
+            }
+        )
+        return config
diff --git a/tools/checkpoint_conversion/convert_bloom_checkpoints.py b/tools/checkpoint_conversion/convert_bloom_checkpoints.py
new file mode 100644
index 0000000000..1fc895c54c
--- /dev/null
+++ b/tools/checkpoint_conversion/convert_bloom_checkpoints.py
@@ -0,0 +1,195 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import torch
+import transformers
+from absl import app
+from absl import flags
+from checkpoint_conversion_utils import get_md5_checksum
+
+from keras_nlp.models.bloom.bloom_backbone import BloomBackbone
+
+FLAGS = flags.FLAGS
+
+PRESET_MAP = {
+    "bloom_tiny": "bigscience/bloom-560m",
+    "bloom_extra_small": "bigscience/bloom-1b1",
+    "bloom_small": "bigscience/bloom-1b7",
+    "bloom_medium": "bigscience/bloom-3b",
+    "bloom_large": "bigscience/bloom-7b1",
+    "bloom_extra_large": "bigscience/bloom",
+}
+
+flags.DEFINE_string(
+    "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
+)
+flags.mark_flag_as_required("preset")
+
+
+def convert_checkpoints(hf_model):
+    # Get the huggingface model configuration.
+    hf_config = hf_model.config.to_dict()
+
+    cfg = {}
+    cfg["vocabulary_size"] = hf_config["vocab_size"]
+    cfg["num_layers"] = hf_config["n_layer"]
+    cfg["num_heads"] = hf_config["n_head"]
+    cfg["hidden_dim"] = hf_config["hidden_size"]
+    cfg["intermediate_dim"] = hf_config["hidden_size"] * 4
+    cfg["dropout"] = hf_config["hidden_dropout"]
+    cfg["layer_norm_epsilon"] = hf_config["layer_norm_epsilon"]
+
+    hidden_dim = cfg["hidden_dim"]
+    num_heads = cfg["num_heads"]
+    head_dim = hidden_dim // num_heads
+
+    # Initialize the Bloom model with the weights.
+    keras_model = BloomBackbone(**cfg)
+
+    # Get the huggingface model weights.
+    hf_wts = hf_model.state_dict()
+
+    # Assign the huggingface weights to the keras model.
+    # Embedding layer.
+    keras_model.get_layer("token_embedding").embeddings.assign(
+        hf_wts["word_embeddings.weight"]
+    )
+    # LayerNorm.
+    keras_model.get_layer("token_embedding_layernorm").gamma.assign(
+        hf_wts["word_embeddings_layernorm.weight"]
+    )
+    keras_model.get_layer("token_embedding_layernorm").beta.assign(
+        hf_wts["word_embeddings_layernorm.bias"]
+    )
+
+    keras_model.get_layer("final_layernorm").gamma.assign(hf_wts["ln_f.weight"])
+    keras_model.get_layer("final_layernorm").beta.assign(hf_wts["ln_f.bias"])
+
+    # Decoder layers.
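+    # Each HF BLOOM layer stores its query/key/value projections as a single
+    # fused `query_key_value` tensor laid out per attention head; it is
+    # reshaped and sliced below to recover separate kernels and biases.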
+    for i in range(cfg["num_layers"]):
+        decoder_layer = keras_model.get_layer(f"transformer_layer_{i}")
+        # LayerNorm.
+        decoder_layer._pre_attention_layernorm.gamma.assign(
+            hf_wts[f"h.{i}.input_layernorm.weight"]
+        )
+        decoder_layer._pre_attention_layernorm.beta.assign(
+            hf_wts[f"h.{i}.input_layernorm.bias"]
+        )
+        decoder_layer._post_attention_layernorm.gamma.assign(
+            hf_wts[f"h.{i}.post_attention_layernorm.weight"]
+        )
+        decoder_layer._post_attention_layernorm.beta.assign(
+            hf_wts[f"h.{i}.post_attention_layernorm.bias"]
+        )
+
+        # Attention layer.
+        attention_layer = decoder_layer._self_attention_layer
+
+        fused_qkv_kernel = hf_wts[
+            f"h.{i}.self_attention.query_key_value.weight"
+        ].T
+        fused_qkv_kernel = fused_qkv_kernel.view(
+            hidden_dim, num_heads, 3, head_dim
+        )
+        query_kernel = fused_qkv_kernel[..., 0, :]
+        key_kernel = fused_qkv_kernel[..., 1, :]
+        value_kernel = fused_qkv_kernel[..., 2, :]
+
+        fused_qkv_bias = hf_wts[f"h.{i}.self_attention.query_key_value.bias"]
+        fused_qkv_bias = fused_qkv_bias.view(num_heads, 3, head_dim)
+        query_bias = fused_qkv_bias[:, 0, :]
+        key_bias = fused_qkv_bias[:, 1, :]
+        value_bias = fused_qkv_bias[:, 2, :]
+
+        attention_layer._query_dense.kernel.assign(query_kernel)
+        attention_layer._query_dense.bias.assign(query_bias)
+        attention_layer._key_dense.kernel.assign(key_kernel)
+        attention_layer._key_dense.bias.assign(key_bias)
+        attention_layer._value_dense.kernel.assign(value_kernel)
+        attention_layer._value_dense.bias.assign(value_bias)
+
+        attention_layer._output_dense.kernel.assign(
+            hf_wts[f"h.{i}.self_attention.dense.weight"].T
+        )
+        attention_layer._output_dense.bias.assign(
+            hf_wts[f"h.{i}.self_attention.dense.bias"]
+        )
+
+        # MLP.
+        decoder_layer._mlp_intermediate_dense.kernel.assign(
+            hf_wts[f"h.{i}.mlp.dense_h_to_4h.weight"].T
+        )
+        decoder_layer._mlp_intermediate_dense.bias.assign(
+            hf_wts[f"h.{i}.mlp.dense_h_to_4h.bias"]
+        )
+        decoder_layer._mlp_output_dense.kernel.assign(
+            hf_wts[f"h.{i}.mlp.dense_4h_to_h.weight"].T
+        )
+        decoder_layer._mlp_output_dense.bias.assign(
+            hf_wts[f"h.{i}.mlp.dense_4h_to_h.bias"]
+        )
+
+    # Save the model.
+    print(f"\n-> Saving KerasNLP model weights to `{FLAGS.preset}.weights.h5`.")
+    keras_model.save_weights(f"{FLAGS.preset}.weights.h5")
+
+    return keras_model
+
+
+def check_output(keras_model, hf_model):
+    hf_model_input = {
+        "input_ids": torch.tensor([[59414, 15, 2670, 35433, 632, 207595]]),
+        "attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1]]),
+    }
+
+    hf_model_outputs = hf_model(**hf_model_input)
+    hf_model_outputs = hf_model_outputs.last_hidden_state
+    hf_model_outputs = hf_model_outputs.detach().numpy()
+
+    keras_model_input = {
+        "token_ids": torch.tensor([[59414, 15, 2670, 35433, 632, 207595]]),
+        "padding_mask": torch.tensor([[1, 1, 1, 1, 1, 1]]),
+    }
+
+    keras_model_outputs = keras_model.predict(keras_model_input)
+
+    # Compare the outputs.
+    print("KerasNLP output:", keras_model_outputs[0, 0, :10])
+    print("HF output:", hf_model_outputs[0, 0, :10])
+    print("Difference:", np.mean(keras_model_outputs - hf_model_outputs))
+
+    # Show the MD5 checksum of the model weights.
+    print("Model md5sum: ", get_md5_checksum(f"./{FLAGS.preset}.weights.h5"))
+
+
+def main(_):
+    assert (
+        FLAGS.preset in PRESET_MAP.keys()
+    ), f'Invalid preset {FLAGS.preset}. Must be one of {",".join(PRESET_MAP.keys())}'
+
+    hf_model_name = PRESET_MAP[FLAGS.preset]
+
+    print("\n-> Loading HF model.")
+    hf_model = transformers.AutoModel.from_pretrained(hf_model_name)
+
+    print("\n-> Converting model checkpoint.")
+    keras_model = convert_checkpoints(hf_model)
+
+    print("\n-> Checking keras model output.")
+    check_output(keras_model, hf_model)
+
+
+if __name__ == "__main__":
+    app.run(main)
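
A quick sketch (not part of the PR) of consuming the weights file written by the conversion script, after running it with `--preset bloom_tiny`. The config values below are assumptions meant to match `bigscience/bloom-560m`; any mismatch with the converted architecture will make `load_weights` fail.

```python
# Rebuild the backbone and load the weights saved by convert_bloom_checkpoints.py.
from keras_nlp.models.bloom.bloom_backbone import BloomBackbone

# Assumed bigscience/bloom-560m dimensions: 24 layers, 16 heads, hidden 1024.
backbone = BloomBackbone(
    vocabulary_size=250880,
    num_layers=24,
    num_heads=16,
    hidden_dim=1024,
    intermediate_dim=4096,
)
backbone.load_weights("bloom_tiny.weights.h5")
```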