-
Notifications
You must be signed in to change notification settings - Fork 315
Gpt2safetensors #2459
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
LakshmiKalaKadali
wants to merge
17
commits into
master
Choose a base branch
from
gpt2safetensors
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+293
−21
Open
Gpt2safetensors #2459
Changes from 2 commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
5add1a3
GPT2 kerashub to hf safetensor conversion
LakshmiKalaKadali 7146a5c
GPT2 kerashub to hf safe tensor conversion
LakshmiKalaKadali a8b6d07
Update keras_hub/src/utils/transformers/export/gpt2.py
LakshmiKalaKadali 3c59953
Update keras_hub/src/utils/transformers/export/gpt2.py
LakshmiKalaKadali adcc364
Update keras_hub/src/utils/transformers/export/gpt2_test.py
LakshmiKalaKadali d8e71af
Update keras_hub/src/utils/transformers/export/gpt2_test.py
LakshmiKalaKadali 45e2e9d
Update tools/checkpoint_conversion/convert_gpt2_checkpoints.py
LakshmiKalaKadali 9f03655
Fix GPT-2 export to handle both dict and object configs
LakshmiKalaKadali a140862
Fix GPT-2 export to handle both dict and object configs
LakshmiKalaKadali 621adc3
Revert " Fix GPT-2 export to handle both dict and object configs"
LakshmiKalaKadali f87f1cf
Remove virtual environment files from repository and update .gitignore
LakshmiKalaKadali 1263431
removing dependency on gemma.py
LakshmiKalaKadali ae31f19
Test failures are fixed
LakshmiKalaKadali 57fe8c7
Jax backend errors are fixed
LakshmiKalaKadali f5b2024
Updated export logic
LakshmiKalaKadali c5dd3c1
Reverted gitignore file and updated few comments in hf_exporter file
LakshmiKalaKadali 6cc159d
Revert .gitignore to match master branch
LakshmiKalaKadali File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,146 @@ | ||
| import tensorflow as tf | ||
LakshmiKalaKadali marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| import transformers | ||
|
|
||
|
|
||
| def get_gpt2_config(keras_model): | ||
| """Convert Keras GPT-2 config to Hugging Face GPT2Config.""" | ||
| return transformers.GPT2Config( | ||
| vocab_size=keras_model.vocabulary_size, | ||
| n_positions=keras_model.max_sequence_length, | ||
| n_embd=keras_model.hidden_dim, | ||
| n_layer=keras_model.num_layers, | ||
| n_head=keras_model.num_heads, | ||
| n_inner=keras_model.intermediate_dim, | ||
| activation_function="gelu_new", | ||
| resid_pdrop=0.1, | ||
| embd_pdrop=0.1, | ||
| attn_pdrop=0.1, | ||
| layer_norm_epsilon=1e-5, | ||
| initializer_range=0.02, | ||
| summary_type="cls_index", | ||
| summary_use_proj=True, | ||
| summary_activation=None, | ||
| summary_proj_to_labels=True, | ||
| summary_first_dropout=0.1, | ||
| scale_attn_weights=True, | ||
| use_cache=True, | ||
| bos_token_id=50256, | ||
| eos_token_id=50256, | ||
| ) | ||
|
|
||
|
|
||
| def get_gpt2_weights_map(keras_model, include_lm_head=False): | ||
| """Create a weights map for a given GPT-2 model.""" | ||
| weights_map = {} | ||
|
|
||
| # Token and position embeddings | ||
| weights_map["transformer.wte.weight"] = keras_model.get_layer( | ||
| "token_embedding" | ||
| ).embeddings | ||
| weights_map["transformer.wpe.weight"] = keras_model.get_layer( | ||
| "position_embedding" | ||
| ).position_embeddings | ||
|
|
||
| for i in range(keras_model.num_layers): | ||
| # Attention weights | ||
| q_w = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._self_attention_layer._query_dense.kernel | ||
| k_w = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._self_attention_layer._key_dense.kernel | ||
| v_w = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._self_attention_layer._value_dense.kernel | ||
LakshmiKalaKadali marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| q_b = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._self_attention_layer._query_dense.bias | ||
| k_b = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._self_attention_layer._key_dense.bias | ||
| v_b = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._self_attention_layer._value_dense.bias | ||
|
|
||
| q_w = tf.reshape(q_w, (keras_model.hidden_dim, keras_model.hidden_dim)) | ||
| k_w = tf.reshape(k_w, (keras_model.hidden_dim, keras_model.hidden_dim)) | ||
| v_w = tf.reshape(v_w, (keras_model.hidden_dim, keras_model.hidden_dim)) | ||
LakshmiKalaKadali marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| c_attn_w = tf.concat([q_w, k_w, v_w], axis=-1) | ||
| weights_map[f"transformer.h.{i}.attn.c_attn.weight"] = c_attn_w | ||
|
|
||
| q_b = tf.reshape(q_b, [-1]) | ||
| k_b = tf.reshape(k_b, [-1]) | ||
| v_b = tf.reshape(v_b, [-1]) | ||
|
|
||
| c_attn_b = tf.concat([q_b, k_b, v_b], axis=-1) | ||
| weights_map[f"transformer.h.{i}.attn.c_attn.bias"] = c_attn_b | ||
|
|
||
| # Attention projection | ||
| c_proj_w = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._self_attention_layer._output_dense.kernel | ||
| c_proj_w = tf.reshape( | ||
| c_proj_w, (keras_model.hidden_dim, keras_model.hidden_dim) | ||
| ) | ||
| weights_map[f"transformer.h.{i}.attn.c_proj.weight"] = c_proj_w | ||
| weights_map[f"transformer.h.{i}.attn.c_proj.bias"] = ( | ||
| keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._self_attention_layer._output_dense.bias | ||
| ) | ||
|
|
||
| # Layer norms | ||
| weights_map[f"transformer.h.{i}.ln_1.weight"] = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._self_attention_layer_norm.gamma | ||
| weights_map[f"transformer.h.{i}.ln_1.bias"] = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._self_attention_layer_norm.beta | ||
| weights_map[f"transformer.h.{i}.ln_2.weight"] = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._feedforward_layer_norm.gamma | ||
| weights_map[f"transformer.h.{i}.ln_2.bias"] = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._feedforward_layer_norm.beta | ||
|
|
||
| # MLP | ||
| c_fc_w = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._feedforward_intermediate_dense.kernel | ||
| weights_map[f"transformer.h.{i}.mlp.c_fc.weight"] = c_fc_w | ||
| weights_map[f"transformer.h.{i}.mlp.c_fc.bias"] = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._feedforward_intermediate_dense.bias | ||
| c_proj_w_mlp = keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._feedforward_output_dense.kernel | ||
| weights_map[f"transformer.h.{i}.mlp.c_proj.weight"] = c_proj_w_mlp | ||
| weights_map[f"transformer.h.{i}.mlp.c_proj.bias"] = ( | ||
| keras_model.get_layer( | ||
| f"transformer_layer_{i}" | ||
| )._feedforward_output_dense.bias | ||
| ) | ||
|
|
||
| # Final layer norm | ||
| weights_map["transformer.ln_f.weight"] = keras_model.get_layer( | ||
| "layer_norm" | ||
| ).gamma | ||
| weights_map["transformer.ln_f.bias"] = keras_model.get_layer( | ||
| "layer_norm" | ||
| ).beta | ||
|
|
||
| if include_lm_head: | ||
| # lm_head is tied to token embeddings | ||
| weights_map["lm_head.weight"] = weights_map["transformer.wte.weight"] | ||
|
|
||
| return weights_map | ||
|
|
||
|
|
||
| def get_gpt2_tokenizer_config(tokenizer): | ||
| return { | ||
| "model_type": "gpt2", | ||
| "bos_token": "<|endoftext|>", | ||
| "eos_token": "<|endoftext|>", | ||
| "unk_token": "<|endoftext|>", | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,122 @@ | ||
| import os | ||
| import shutil | ||
| import sys | ||
| import tempfile | ||
| from os.path import abspath | ||
| from os.path import dirname | ||
|
|
||
| # import keras | ||
| import numpy as np | ||
| import tensorflow as tf | ||
LakshmiKalaKadali marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| # import torch | ||
| from absl.testing import parameterized | ||
| from transformers import AutoModelForCausalLM | ||
| from transformers import AutoTokenizer | ||
|
|
||
| # Add the project root to the Python path. | ||
| sys.path.insert( | ||
| 0, dirname(dirname(dirname(dirname(dirname(abspath(__file__)))))) | ||
| ) | ||
|
|
||
| from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM | ||
| from keras_hub.src.utils.transformers.export.hf_exporter import ( | ||
| export_to_safetensors, | ||
| ) | ||
|
|
||
|
|
||
| def to_numpy(x): | ||
| # Torch tensor | ||
| if hasattr(x, "detach") and hasattr(x, "cpu"): | ||
| return x.detach().cpu().numpy() | ||
|
|
||
| # TF tensor | ||
| if hasattr(x, "numpy"): | ||
| return x.numpy() | ||
|
|
||
| # Numpy | ||
| if isinstance(x, np.ndarray): | ||
| return x | ||
|
|
||
| # KerasTensor or ragged wrapper → convert to TF → numpy | ||
| try: | ||
| import tensorflow as tf | ||
|
|
||
| return tf.convert_to_tensor(x).numpy() | ||
| except Exception: | ||
| pass | ||
|
|
||
| raise TypeError(f"Cannot convert value of type {type(x)} to numpy") | ||
|
|
||
|
|
||
| class GPT2ExportTest(tf.test.TestCase, parameterized.TestCase): | ||
| @parameterized.named_parameters( | ||
| ("gpt2_base_en", "gpt2_base_en"), | ||
| ) | ||
| def test_gpt2_export(self, preset): | ||
| # Create a temporary directory to save the converted model. | ||
| temp_dir = tempfile.mkdtemp() | ||
| output_path = os.path.join(temp_dir, preset) | ||
|
|
||
| # Load Keras model. | ||
| keras_model = GPT2CausalLM.from_preset(preset) | ||
|
|
||
| # Export to Hugging Face format. | ||
| export_to_safetensors(keras_model, output_path) | ||
|
|
||
| # Load the converted model with Hugging Face Transformers. | ||
| hf_model = AutoModelForCausalLM.from_pretrained(output_path) | ||
| hf_tokenizer = AutoTokenizer.from_pretrained(output_path) | ||
|
|
||
| # Assertions for config parameters. | ||
| self.assertEqual( | ||
| keras_model.backbone.hidden_dim, hf_model.config.hidden_size | ||
| ) | ||
| self.assertEqual( | ||
| keras_model.backbone.num_layers, hf_model.config.n_layer | ||
| ) | ||
| self.assertEqual(keras_model.backbone.num_heads, hf_model.config.n_head) | ||
| self.assertEqual( | ||
| keras_model.backbone.intermediate_dim, hf_model.config.n_inner | ||
| ) | ||
| self.assertEqual( | ||
| keras_model.backbone.vocabulary_size, hf_model.config.vocab_size | ||
| ) | ||
| self.assertEqual( | ||
| keras_model.backbone.max_sequence_length, | ||
| hf_model.config.n_positions, | ||
| ) | ||
|
|
||
| # Test logits. | ||
| prompt = "Hello, my name is" | ||
| token_ids = tf.constant( | ||
| keras_model.preprocessor.tokenizer(tf.constant([prompt])) | ||
| ) | ||
LakshmiKalaKadali marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| padding_mask = tf.ones_like(token_ids, dtype=tf.int32) | ||
| keras_inputs = {"token_ids": token_ids, "padding_mask": padding_mask} | ||
| keras_logits = keras_model(keras_inputs) | ||
|
|
||
| hf_inputs = hf_tokenizer(prompt, return_tensors="pt") | ||
| hf_logits = hf_model(**hf_inputs).logits | ||
| print(hf_logits) | ||
|
|
||
| # Compare logits. | ||
| # Keras logits are (batch_size, sequence_length, vocab_size) | ||
| # HF logits are (batch_size, sequence_length, vocab_size) | ||
| # We need to convert Keras logits to numpy and then to torch tensor | ||
| # for comparison. | ||
|
|
||
| # Convert Keras logits (TF) -> numpy | ||
| keras_logits_np = to_numpy(keras_logits) | ||
|
|
||
| # Convert HF logits (Torch, possibly MPS) -> numpy | ||
| hf_logits_np = to_numpy(hf_logits) | ||
|
|
||
| self.assertAllClose(keras_logits_np, hf_logits_np, atol=1e-3, rtol=1e-3) | ||
|
|
||
| # Clean up the temporary directory. | ||
| shutil.rmtree(temp_dir) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| tf.test.main() | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.