diff --git a/keras_nlp/models/t5/t5_backbone.py b/keras_nlp/models/t5/t5_backbone.py index 5685fce2cb..aa4517d038 100644 --- a/keras_nlp/models/t5/t5_backbone.py +++ b/keras_nlp/models/t5/t5_backbone.py @@ -14,6 +14,8 @@ """T5 backbone model.""" +import copy + import tensorflow as tf from tensorflow import keras @@ -21,6 +23,7 @@ from keras_nlp.layers.transformer_layer_utils import compute_causal_mask from keras_nlp.models.backbone import Backbone from keras_nlp.models.t5.t5_layer_norm import T5LayerNorm +from keras_nlp.models.t5.t5_presets import backbone_presets from keras_nlp.models.t5.t5_transformer_layer import T5TransformerLayer from keras_nlp.utils.python_utils import classproperty @@ -54,6 +57,9 @@ class T5Backbone(Backbone): hidden_dim: int. The hidden size of the Transformer layers. intermediate_dim: int. The output dimension of the first Dense layer in a two-layer feedforward network for each Transformer layer. + key_value_dim: int. The dimension of each head of the key/value + projections in the multi-head attention layers. Defaults to + hidden_dim / num_heads. dropout: float. Dropout probability for the Transformer layers. activation: activation function (or activation string name). The activation to be used in the inner dense blocks of the @@ -75,6 +81,7 @@ def __init__( num_heads, hidden_dim, intermediate_dim, + key_value_dim=None, dropout=0.1, activation="gelu", use_gated_activation=True, @@ -123,6 +130,7 @@ def __init__( is_decoder=False, hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, + key_value_dim=key_value_dim or hidden_dim // num_heads, dropout=dropout, activation=activation, layer_norm_epsilon=layer_norm_epsilon, @@ -167,6 +175,7 @@ def __init__( is_decoder=True, hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, + key_value_dim=key_value_dim or hidden_dim // num_heads, dropout=dropout, activation=activation, layer_norm_epsilon=layer_norm_epsilon, @@ -237,4 +246,4 @@ def token_embedding(self): @classproperty def presets(cls): - return {} + return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/t5/t5_multi_head_attention.py b/keras_nlp/models/t5/t5_multi_head_attention.py index 5f8f2ad020..48e20c7b30 100644 --- a/keras_nlp/models/t5/t5_multi_head_attention.py +++ b/keras_nlp/models/t5/t5_multi_head_attention.py @@ -31,6 +31,7 @@ def __init__( self, is_decoder, hidden_dim, + key_value_dim, num_heads, dropout, use_relative_attention_bias=False, @@ -39,7 +40,7 @@ def __init__( super().__init__(**kwargs) self.is_decoder = is_decoder self.hidden_dim = hidden_dim - self.key_value_dim = hidden_dim // num_heads + self.key_value_dim = key_value_dim self.num_heads = num_heads self.use_relative_attention_bias = use_relative_attention_bias diff --git a/keras_nlp/models/t5/t5_presets.py b/keras_nlp/models/t5/t5_presets.py new file mode 100644 index 0000000000..e128410858 --- /dev/null +++ b/keras_nlp/models/t5/t5_presets.py @@ -0,0 +1,138 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""XLM-RoBERTa model preset configurations.""" + +backbone_presets = { + "t5_small_en": { + "metadata": { + "description": ( + "8-layer T5 model. Trained on the Colossal Clean Crawled " + "Corpus (C4)." + ), + "params": 0, + "official_name": "T5", + "path": "t5", + "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md", + }, + "config": { + "vocabulary_size": 32128, + "num_layers": 8, + "num_heads": 6, + "hidden_dim": 512, + "intermediate_dim": 1024, + "key_value_dim": 64, + "dropout": 0.1, + "activation": "gelu", + "use_gated_activation": True, + "layer_norm_epsilon": 1e-06, + }, + "preprocessor_config": {}, + }, + "t5_base_en": { + "metadata": { + "description": ( + "12-layer T5 model. Trained on the Colossal Clean Crawled " + "Corpus (C4)." + ), + "params": 0, + "official_name": "T5", + "path": "t5", + "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md", + }, + "config": { + "vocabulary_size": 32128, + "num_layers": 12, + "num_heads": 12, + "hidden_dim": 768, + "intermediate_dim": 2048, + "dropout": 0.1, + "activation": "gelu", + "use_gated_activation": True, + "layer_norm_epsilon": 1e-06, + }, + "preprocessor_config": {}, + }, + "t5_large_en": { + "metadata": { + "description": ( + "24-layer T5 model. Trained on the Colossal Clean Crawled " + "Corpus (C4)." + ), + "params": 0, + "official_name": "T5", + "path": "t5", + "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md", + }, + "config": { + "vocabulary_size": 32128, + "num_layers": 24, + "num_heads": 16, + "hidden_dim": 1024, + "intermediate_dim": 2816, + "dropout": 0.1, + "activation": "gelu", + "use_gated_activation": True, + "layer_norm_epsilon": 1e-06, + }, + "preprocessor_config": {}, + }, + "t5_extra_large_en": { + "metadata": { + "description": ( + "24-layer T5 model. Trained on the Colossal Clean Crawled " + "Corpus (C4)." + ), + "params": 0, + "official_name": "T5", + "path": "t5", + "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md", + }, + "config": { + "vocabulary_size": 32128, + "num_layers": 24, + "num_heads": 32, + "hidden_dim": 2048, + "intermediate_dim": 5120, + "dropout": 0.1, + "activation": "gelu", + "use_gated_activation": True, + "layer_norm_epsilon": 1e-06, + }, + "preprocessor_config": {}, + }, + "t5_extra_extra_large_en": { + "metadata": { + "description": ( + "24-layer T5 model. Trained on the Colossal Clean Crawled " + "Corpus (C4)." 
+ ), + "params": 0, + "official_name": "T5", + "path": "t5", + "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md", + }, + "config": { + "vocabulary_size": 32128, + "num_layers": 24, + "num_heads": 64, + "hidden_dim": 4096, + "intermediate_dim": 10240, + "dropout": 0.1, + "activation": "gelu", + "use_gated_activation": True, + "layer_norm_epsilon": 1e-06, + }, + "preprocessor_config": {}, + }, +} diff --git a/keras_nlp/models/t5/t5_transformer_layer.py b/keras_nlp/models/t5/t5_transformer_layer.py index f70010e42d..b69dfa5c51 100644 --- a/keras_nlp/models/t5/t5_transformer_layer.py +++ b/keras_nlp/models/t5/t5_transformer_layer.py @@ -23,6 +23,7 @@ def __init__( is_decoder, hidden_dim, intermediate_dim, + key_value_dim, dropout, activation, layer_norm_epsilon, @@ -36,10 +37,11 @@ def __init__( self.use_gated_activation = use_gated_activation self.self_attention = T5MultiHeadAttention( - is_decoder, - hidden_dim, - num_heads, - dropout, + is_decoder=is_decoder, + hidden_dim=hidden_dim, + key_value_dim=key_value_dim, + num_heads=num_heads, + dropout=dropout, use_relative_attention_bias=use_relative_attention_bias, ) self.self_attention_layernorm = T5LayerNorm(layer_norm_epsilon) @@ -47,15 +49,25 @@ def __init__( if self.is_decoder: self.cross_attention = T5MultiHeadAttention( - is_decoder, - hidden_dim, - num_heads, - dropout, + is_decoder=is_decoder, + hidden_dim=hidden_dim, + key_value_dim=key_value_dim, + num_heads=num_heads, + dropout=dropout, use_relative_attention_bias=False, ) self.cross_attention_layernorm = T5LayerNorm(layer_norm_epsilon) self.cross_attention_dropout = keras.layers.Dropout(dropout) + if activation == "gelu": + + def approx_gelu(x): + return keras.activations.gelu(x, approximate=True) + + activation = approx_gelu + else: + activation = keras.activations.get(activation) + self.input_projector = keras.layers.Dense( intermediate_dim, use_bias=False, @@ -123,7 +135,7 @@ def call( x = self.layer_norm(x) if self.use_gated_activation: hidden_activation = self.input_projector(x) - hidden_linear = self.gate_projector(hidden_states) + hidden_linear = self.gate_projector(x) x = hidden_activation * hidden_linear else: x = self.input_projector(x) diff --git a/tools/checkpoint_conversion/checkpoint_conversion_utils.py b/tools/checkpoint_conversion/checkpoint_conversion_utils.py index 2fea715727..a12b0ee6e3 100644 --- a/tools/checkpoint_conversion/checkpoint_conversion_utils.py +++ b/tools/checkpoint_conversion/checkpoint_conversion_utils.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import hashlib +import inspect + +import keras as keraslib +from keras.engine.base_layer_utils import make_variable def get_md5_checksum(file_path): @@ -20,3 +24,42 @@ def get_md5_checksum(file_path): for byte_block in iter(lambda: f.read(4096), b""): md5_hash.update(byte_block) return md5_hash.hexdigest() + + +def port_weights_by_creation_order(source_model_fn, dest_model_fn, debug=False): + """Assign weights between models by intercepting variable creation. 
+
+    For each model, build a flat list of all variables created, in order."""
+    ALL_VARS = []
+
+    def make_var(name, shape=None, **kwargs):
+        if debug:
+            stack = inspect.stack()
+            instance = stack[1][0].f_locals["self"]
+            cls = instance.__class__.__name__
+            print(f"Class {cls} creating {name} with shape {shape}.")
+
+        v = make_variable(name, shape=shape, **kwargs)
+        ALL_VARS.append(v)
+        return v
+
+    # Patch make variable.
+    keraslib.engine.base_layer_utils.make_variable = make_var
+
+    source_model = source_model_fn()
+    source_model_vars = ALL_VARS[:]
+
+    ALL_VARS.clear()
+    dest_model = dest_model_fn()
+    if len(ALL_VARS) != len(source_model_vars):
+        raise ValueError(
+            f"Variable counts do not match. 1st model: {len(source_model_vars)} "
+            f"vars, 2nd model: {len(ALL_VARS)} vars"
+        )
+    for v1, v2 in zip(source_model_vars, ALL_VARS):
+        v2.assign(v1.numpy())
+
+    # Unpatch make variable.
+    keraslib.engine.base_layer_utils.make_variable = make_variable
+
+    return source_model, dest_model
diff --git a/tools/checkpoint_conversion/convert_t5_checkpoints.py b/tools/checkpoint_conversion/convert_t5_checkpoints.py
new file mode 100644
index 0000000000..b04a94112f
--- /dev/null
+++ b/tools/checkpoint_conversion/convert_t5_checkpoints.py
@@ -0,0 +1,179 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import shutil
+
+import numpy as np
+import tensorflow as tf
+import transformers
+from absl import app
+from absl import flags
+from keras.utils.layer_utils import count_params
+
+import keras_nlp
+from tools.checkpoint_conversion.checkpoint_conversion_utils import (
+    get_md5_checksum,
+)
+from tools.checkpoint_conversion.checkpoint_conversion_utils import (
+    port_weights_by_creation_order,
+)
+
+PRESET_MAP = {
+    "t5_small_en": "google/t5-v1_1-small",
+    "t5_base_en": "google/t5-v1_1-base",
+    "t5_large_en": "google/t5-v1_1-large",
+    "t5_extra_large_en": "google/t5-v1_1-xl",
+    "t5_extra_extra_large_en": "google/t5-v1_1-xxl",
+}
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string(
+    "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
+)
+
+
+def extract_vocab(hf_tokenizer):
+    proto_path = f"./{FLAGS.preset}/vocab.spm"
+    print(f"\n-> Save KerasNLP vocab to `{proto_path}`.")
+
+    # Huggingface has a save_vocabulary function, but it does not match the
+    # original file byte-for-byte. Instead, copy the downloaded file directly.
+    shutil.copyfile(
+        transformers.utils.hub.get_file_from_repo(
+            hf_tokenizer.name_or_path, "spiece.model"
+        ),
+        proto_path,
+    )
+
+    keras_tokenizer = keras_nlp.models.T5Tokenizer(
+        proto=proto_path,
+    )
+
+    print("-> Print MD5 checksum of the vocab files.")
+    print(f"`{proto_path}` md5sum: ", get_md5_checksum(proto_path))
+
+    return keras_tokenizer
+
+
+def check_output(
+    keras_model,
+    keras_tokenizer,
+    hf_model,
+    hf_tokenizer,
+):
+    print("\n-> Compare the outputs.")
+    encoder_input = ["the quick brown fox jumped."]
+    decoder_input = ["the quick brown fox fell."]
+
+    sequence_length = 12
+
+    # KerasNLP tokenization.
+    packer = keras_nlp.layers.StartEndPacker(
+        sequence_length=sequence_length,
+        pad_value=keras_tokenizer.pad_token_id,
+        end_value=keras_tokenizer.end_token_id,
+    )
+    encoder_token_ids = packer(keras_tokenizer(encoder_input))
+    encoder_padding_mask = encoder_token_ids != keras_tokenizer.pad_token_id
+    decoder_token_ids = packer(keras_tokenizer(decoder_input))
+    decoder_padding_mask = decoder_token_ids != keras_tokenizer.pad_token_id
+    keras_inputs = {
+        "encoder_token_ids": encoder_token_ids,
+        "encoder_padding_mask": encoder_padding_mask,
+        "decoder_token_ids": decoder_token_ids,
+        "decoder_padding_mask": decoder_padding_mask,
+    }
+
+    # HF tokenization.
+    hf_encoder_inputs = hf_tokenizer(
+        encoder_input,
+        padding="max_length",
+        max_length=sequence_length,
+        return_tensors="tf",
+    )
+    hf_decoder_inputs = hf_tokenizer(
+        decoder_input,
+        padding="max_length",
+        max_length=sequence_length,
+        return_tensors="tf",
+    )
+    hf_inputs = {
+        "input_ids": hf_encoder_inputs["input_ids"],
+        "attention_mask": hf_encoder_inputs["attention_mask"],
+        "decoder_input_ids": hf_decoder_inputs["input_ids"],
+        "decoder_attention_mask": hf_decoder_inputs["attention_mask"],
+    }
+
+    # Compare tokenized inputs. This should be a complete match.
+    print("-> KerasNLP inputs:")
+    for k, v in keras_inputs.items():
+        print(k, v)
+    print("-> HF inputs:")
+    for k, v in hf_inputs.items():
+        print(k, v)
+
+    # Forward pass.
+    keras_outputs = keras_model(keras_inputs)
+    hf_outputs = hf_model(**hf_inputs)
+
+    # Only compare non-padded token ids.
+    keras_outputs = keras_outputs["decoder_sequence_output"]
+    keras_outputs = tf.gather_nd(keras_outputs, tf.where(decoder_padding_mask))
+    hf_outputs = hf_outputs.last_hidden_state
+    hf_outputs = tf.gather_nd(hf_outputs, tf.where(decoder_padding_mask))
+
+    print("-> KerasNLP output:", keras_outputs[0, :5])
+    print("-> HF output:", hf_outputs[0, :5])
+    np.testing.assert_allclose(
+        keras_outputs.numpy(), hf_outputs.numpy(), atol=1e-5
+    )
+
+
+def main(_):
+    hf_id = PRESET_MAP[FLAGS.preset]
+    shutil.rmtree(f"./{FLAGS.preset}", ignore_errors=True)
+    os.mkdir(f"./{FLAGS.preset}")
+
+    print("\n-> Convert weights.")
+    hf_model, keras_model = port_weights_by_creation_order(
+        lambda: transformers.TFAutoModel.from_pretrained(hf_id),
+        lambda: keras_nlp.models.T5Backbone.from_preset(
+            FLAGS.preset, load_weights=False
+        ),
+    )
+
+    # Save the model.
+ model_path = f"./{FLAGS.preset}/model.h5" + print(f"\n-> Save KerasNLP model weights to `{model_path}`.") + keras_model.save_weights(model_path) + print("-> Print MD5 checksum of the model weights files.") + print(f"`{model_path}` md5sum: ", get_md5_checksum(model_path)) + print(f"-> Param count {count_params(keras_model.weights)}") + + print("\n-> Convert vocab.") + hf_tokenizer = transformers.AutoTokenizer.from_pretrained(hf_id) + keras_tokenizer = extract_vocab(hf_tokenizer) + + check_output( + keras_model, + keras_tokenizer, + hf_model, + hf_tokenizer, + ) + + +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main)
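For reference, a minimal usage sketch of the presets registered by this change, not part of the diff itself. It assumes a preset name from `PRESET_MAP` above, and passes `load_weights=False` the same way the converter does, since hosted weight files for these presets may not exist yet.

# Run the converter for one preset, e.g.:
#   python tools/checkpoint_conversion/convert_t5_checkpoints.py --preset t5_small_en
# Then instantiate the matching architecture from the registered preset config.
import keras_nlp

# `from_preset` reads the config registered in `backbone_presets`; weight
# loading is skipped so this works before preset weights are hosted.
backbone = keras_nlp.models.T5Backbone.from_preset(
    "t5_small_en", load_weights=False
)
print("Param count:", backbone.count_params())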