diff --git a/keras_hub/src/utils/transformers/export/gpt2.py b/keras_hub/src/utils/transformers/export/gpt2.py
new file mode 100644
index 0000000000..5135bfaafa
--- /dev/null
+++ b/keras_hub/src/utils/transformers/export/gpt2.py
@@ -0,0 +1,152 @@
+import keras.ops as ops
+import transformers
+
+
+def get_gpt2_config(keras_model):
+    """Convert Keras GPT-2 config to Hugging Face GPT2Config."""
+    return transformers.GPT2Config(
+        vocab_size=keras_model.vocabulary_size,
+        n_positions=keras_model.max_sequence_length,
+        n_embd=keras_model.hidden_dim,
+        n_layer=keras_model.num_layers,
+        n_head=keras_model.num_heads,
+        n_inner=keras_model.intermediate_dim,
+        activation_function="gelu_new",
+        resid_pdrop=0.1,
+        embd_pdrop=0.1,
+        attn_pdrop=0.1,
+        layer_norm_epsilon=1e-5,
+        initializer_range=0.02,
+        summary_type="cls_index",
+        summary_use_proj=True,
+        summary_activation=None,
+        summary_proj_to_labels=True,
+        summary_first_dropout=0.1,
+        scale_attn_weights=True,
+        use_cache=True,
+        bos_token_id=50256,
+        eos_token_id=50256,
+    )
+
+
+def get_gpt2_weights_map(keras_model, include_lm_head=False):
+    """Create a weights map for a given GPT-2 model."""
+    weights_map = {}
+
+    # Token and position embeddings
+    weights_map["transformer.wte.weight"] = keras_model.get_layer(
+        "token_embedding"
+    ).embeddings
+    weights_map["transformer.wpe.weight"] = keras_model.get_layer(
+        "position_embedding"
+    ).position_embeddings
+
+    for i in range(keras_model.num_layers):
+        # Attention weights
+        # KerasHub uses Dense layers:
+        # kernel shape [hidden_dim, num_heads, key_dim]
+        # HF uses Conv1D: weight shape [hidden_dim, 3 * hidden_dim]
+        q_w = keras_model.get_layer(
+            f"transformer_layer_{i}"
+        )._self_attention_layer._query_dense.kernel
+        k_w = keras_model.get_layer(
+            f"transformer_layer_{i}"
+        )._self_attention_layer._key_dense.kernel
+        v_w = keras_model.get_layer(
+            f"transformer_layer_{i}"
+        )._self_attention_layer._value_dense.kernel
+        q_b = keras_model.get_layer(
+            f"transformer_layer_{i}"
+        )._self_attention_layer._query_dense.bias
+        k_b = keras_model.get_layer(
+            f"transformer_layer_{i}"
+        )._self_attention_layer._key_dense.bias
+        v_b = keras_model.get_layer(
+            f"transformer_layer_{i}"
+        )._self_attention_layer._value_dense.bias
+
+        # Flatten the head dimensions to match HF Conv1D input
+        q_w = ops.reshape(q_w, (keras_model.hidden_dim, keras_model.hidden_dim))
+        k_w = ops.reshape(k_w, (keras_model.hidden_dim, keras_model.hidden_dim))
+        v_w = ops.reshape(v_w, (keras_model.hidden_dim, keras_model.hidden_dim))
+
+        # Concatenate Q, K, V
+        c_attn_w = ops.concatenate([q_w, k_w, v_w], axis=-1)
+        weights_map[f"transformer.h.{i}.attn.c_attn.weight"] = c_attn_w
+
+        # Reshape biases
+        q_b = ops.reshape(q_b, [-1])
+        k_b = ops.reshape(k_b, [-1])
+        v_b = ops.reshape(v_b, [-1])
+
+        c_attn_b = ops.concatenate([q_b, k_b, v_b], axis=-1)
+        weights_map[f"transformer.h.{i}.attn.c_attn.bias"] = c_attn_b
+
+        # Attention projection
+        c_proj_w = keras_model.get_layer(
+            f"transformer_layer_{i}"
+        )._self_attention_layer._output_dense.kernel
+        c_proj_w = ops.reshape(
+            c_proj_w, (keras_model.hidden_dim, keras_model.hidden_dim)
+        )
+        weights_map[f"transformer.h.{i}.attn.c_proj.weight"] = c_proj_w
+        weights_map[f"transformer.h.{i}.attn.c_proj.bias"] = (
+            keras_model.get_layer(
+                f"transformer_layer_{i}"
+            )._self_attention_layer._output_dense.bias
+        )
+
+        # Layer norms
+        weights_map[f"transformer.h.{i}.ln_1.weight"] = keras_model.get_layer(
+            f"transformer_layer_{i}"
+        )._self_attention_layer_norm.gamma
weights_map[f"transformer.h.{i}.ln_1.bias"] = keras_model.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer_norm.beta + weights_map[f"transformer.h.{i}.ln_2.weight"] = keras_model.get_layer( + f"transformer_layer_{i}" + )._feedforward_layer_norm.gamma + weights_map[f"transformer.h.{i}.ln_2.bias"] = keras_model.get_layer( + f"transformer_layer_{i}" + )._feedforward_layer_norm.beta + + # MLP + c_fc_w = keras_model.get_layer( + f"transformer_layer_{i}" + )._feedforward_intermediate_dense.kernel + weights_map[f"transformer.h.{i}.mlp.c_fc.weight"] = c_fc_w + weights_map[f"transformer.h.{i}.mlp.c_fc.bias"] = keras_model.get_layer( + f"transformer_layer_{i}" + )._feedforward_intermediate_dense.bias + c_proj_w_mlp = keras_model.get_layer( + f"transformer_layer_{i}" + )._feedforward_output_dense.kernel + weights_map[f"transformer.h.{i}.mlp.c_proj.weight"] = c_proj_w_mlp + weights_map[f"transformer.h.{i}.mlp.c_proj.bias"] = ( + keras_model.get_layer( + f"transformer_layer_{i}" + )._feedforward_output_dense.bias + ) + + # Final layer norm + weights_map["transformer.ln_f.weight"] = keras_model.get_layer( + "layer_norm" + ).gamma + weights_map["transformer.ln_f.bias"] = keras_model.get_layer( + "layer_norm" + ).beta + + if include_lm_head: + # lm_head is tied to token embeddings + weights_map["lm_head.weight"] = weights_map["transformer.wte.weight"] + + return weights_map + + +def get_gpt2_tokenizer_config(tokenizer): + return { + "model_type": "gpt2", + "bos_token": "<|endoftext|>", + "eos_token": "<|endoftext|>", + "unk_token": "<|endoftext|>", + } diff --git a/keras_hub/src/utils/transformers/export/gpt2_test.py b/keras_hub/src/utils/transformers/export/gpt2_test.py new file mode 100644 index 0000000000..146ef27c49 --- /dev/null +++ b/keras_hub/src/utils/transformers/export/gpt2_test.py @@ -0,0 +1,71 @@ +import os +import shutil +import tempfile + +import keras.ops as ops +from absl.testing import parameterized +from transformers import AutoModelForCausalLM +from transformers import AutoTokenizer + +from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM +from keras_hub.src.tests.test_case import TestCase +from keras_hub.src.utils.transformers.export.hf_exporter import ( + export_to_safetensors, +) + + +class GPT2ExportTest(TestCase): + @parameterized.named_parameters( + ("gpt2_base_en", "gpt2_base_en"), + ) + def test_gpt2_export(self, preset): + # Create a temporary directory to save the converted model. + temp_dir = tempfile.mkdtemp() + output_path = os.path.join(temp_dir, preset) + + # Load Keras model. + keras_model = GPT2CausalLM.from_preset(preset) + + # Export to Hugging Face format. + export_to_safetensors(keras_model, output_path) + + # Load the converted model with Hugging Face Transformers. + hf_model = AutoModelForCausalLM.from_pretrained(output_path) + hf_tokenizer = AutoTokenizer.from_pretrained(output_path) + + # Assertions for config parameters. + self.assertEqual( + keras_model.backbone.hidden_dim, hf_model.config.hidden_size + ) + self.assertEqual( + keras_model.backbone.num_layers, hf_model.config.n_layer + ) + self.assertEqual(keras_model.backbone.num_heads, hf_model.config.n_head) + self.assertEqual( + keras_model.backbone.intermediate_dim, hf_model.config.n_inner + ) + self.assertEqual( + keras_model.backbone.vocabulary_size, hf_model.config.vocab_size + ) + self.assertEqual( + keras_model.backbone.max_sequence_length, + hf_model.config.n_positions, + ) + + # Test logits. 
+ prompt = "Hello, my name is" + token_ids = ops.array(keras_model.preprocessor.tokenizer([prompt])) + padding_mask = ops.ones_like(token_ids, dtype="int32") + keras_inputs = {"token_ids": token_ids, "padding_mask": padding_mask} + keras_logits = keras_model(keras_inputs) + + hf_inputs = hf_tokenizer(prompt, return_tensors="pt") + hf_logits = hf_model(**hf_inputs).logits + + keras_logits_np = ops.convert_to_numpy(keras_logits) + hf_logits_np = hf_logits.detach().cpu().numpy() + + self.assertAllClose(keras_logits_np, hf_logits_np, atol=1e-3, rtol=1e-3) + + # Clean up the temporary directory. + shutil.rmtree(temp_dir) diff --git a/keras_hub/src/utils/transformers/export/hf_exporter.py b/keras_hub/src/utils/transformers/export/hf_exporter.py index 1593987ca9..df491766f9 100644 --- a/keras_hub/src/utils/transformers/export/hf_exporter.py +++ b/keras_hub/src/utils/transformers/export/hf_exporter.py @@ -4,25 +4,35 @@ import warnings import keras +import torch +# --- Import GPT2Tokenizer --- from keras_hub.src.utils.transformers.export.gemma import get_gemma_config from keras_hub.src.utils.transformers.export.gemma import ( get_gemma_tokenizer_config, ) from keras_hub.src.utils.transformers.export.gemma import get_gemma_weights_map +from keras_hub.src.utils.transformers.export.gpt2 import get_gpt2_config +from keras_hub.src.utils.transformers.export.gpt2 import ( + get_gpt2_tokenizer_config, +) +from keras_hub.src.utils.transformers.export.gpt2 import get_gpt2_weights_map MODEL_CONFIGS = { "GemmaBackbone": get_gemma_config, + "GPT2Backbone": get_gpt2_config, # Add for future models, e.g., "MistralBackbone": get_mistral_config } MODEL_EXPORTERS = { "GemmaBackbone": get_gemma_weights_map, + "GPT2Backbone": get_gpt2_weights_map, # Add for future models, e.g., "MistralBackbone": get_mistral_weights_map } MODEL_TOKENIZER_CONFIGS = { "GemmaTokenizer": get_gemma_tokenizer_config, + "GPT2Tokenizer": get_gpt2_tokenizer_config, # Add for future models, e.g., "MistralTokenizer": # get_mistral_tokenizer_config } @@ -54,23 +64,54 @@ def export_backbone(backbone, path, include_lm_head=False): weights_dict = get_weights_fn(backbone, include_lm_head=include_lm_head) if not weights_dict: raise ValueError("No weights to save.") + # Save config os.makedirs(path, exist_ok=True) config_path = os.path.join(path, "config.json") + + # Handle Config Objects vs Dicts + config_to_save = hf_config + if hasattr(hf_config, "to_dict"): + config_to_save = hf_config.to_dict() + with open(config_path, "w") as f: - json.dump(hf_config, f) + json.dump(config_to_save, f, indent=2) + # Save weights based on backend weights_path = os.path.join(path, "model.safetensors") if backend == "torch": from safetensors.torch import save_file - weights_dict_contiguous = { - k: v.value.contiguous() if hasattr(v, "value") else v.contiguous() - for k, v in weights_dict.items() - } - save_file( - weights_dict_contiguous, weights_path, metadata={"format": "pt"} - ) + weights_dict_torch = {} + for k, v in weights_dict.items(): + tensor = v.value if hasattr(v, "value") else v + + if isinstance(tensor, torch.Tensor): + t = tensor.detach().to("cpu") + elif hasattr(tensor, "numpy"): + t = torch.tensor(tensor.numpy()) + elif hasattr(tensor, "__array__"): + t = torch.tensor(tensor) + else: + t = tensor + + if hasattr(t, "contiguous"): + t = t.contiguous() + + weights_dict_torch[k] = t + + # Handle Tied Weights (GPT-2) + if ( + "lm_head.weight" in weights_dict_torch + and "transformer.wte.weight" in weights_dict_torch + ): + wte = 
weights_dict_torch["transformer.wte.weight"] + lm = weights_dict_torch["lm_head.weight"] + if wte.data_ptr() == lm.data_ptr(): + weights_dict_torch["lm_head.weight"] = lm.clone().contiguous() + + save_file(weights_dict_torch, weights_path, metadata={"format": "pt"}) + elif backend == "tensorflow": from safetensors.tensorflow import save_file @@ -91,31 +132,39 @@ def export_tokenizer(tokenizer, path): path: str. Path to save the exported tokenizer. """ os.makedirs(path, exist_ok=True) + # Save tokenizer assets tokenizer.save_assets(path) + # Export tokenizer config tokenizer_type = tokenizer.__class__.__name__ if tokenizer_type not in MODEL_TOKENIZER_CONFIGS: raise ValueError( - "Export to Transformers format not implemented for {tokenizer_type}" + f"Export to Transformer format not implemented for {tokenizer_type}" ) get_tokenizer_config_fn = MODEL_TOKENIZER_CONFIGS[tokenizer_type] tokenizer_config = get_tokenizer_config_fn(tokenizer) tokenizer_config_path = os.path.join(path, "tokenizer_config.json") with open(tokenizer_config_path, "w") as f: json.dump(tokenizer_config, f, indent=4) - # Rename vocabulary file - vocab_spm_path = os.path.join(path, "vocabulary.spm") - tokenizer_model_path = os.path.join(path, "tokenizer.model") - if os.path.exists(vocab_spm_path): - shutil.move(vocab_spm_path, tokenizer_model_path) - else: - warnings.warn( - f"{vocab_spm_path} not found. Tokenizer may not load " - "correctly. Ensure that the tokenizer configuration " - "is correct and that the vocabulary file is present " - "in the original model." - ) + + # 2. Rename files to match Hugging Face expectations + if tokenizer_type == "GemmaTokenizer": + vocab_spm_path = os.path.join(path, "vocabulary.spm") + tokenizer_model_path = os.path.join(path, "tokenizer.model") + if os.path.exists(vocab_spm_path): + shutil.move(vocab_spm_path, tokenizer_model_path) + else: + warnings.warn(f"{vocab_spm_path} not found.") + + elif tokenizer_type == "GPT2Tokenizer": + # Rename vocabulary.json -> vocab.json + vocab_json_path = os.path.join(path, "vocabulary.json") + vocab_hf_path = os.path.join(path, "vocab.json") + if os.path.exists(vocab_json_path): + shutil.move(vocab_json_path, vocab_hf_path) + else: + warnings.warn(f"{vocab_json_path} not found.") def export_to_safetensors(keras_model, path):