16 changes: 16 additions & 0 deletions keras_hub/src/models/backbone.py
@@ -277,3 +277,19 @@ def load_lora_weights(self, filepath):
layer.lora_kernel_a.assign(lora_kernel_a)
layer.lora_kernel_b.assign(lora_kernel_b)
store.close()

def export_to_transformers(self, path):
"""Export the backbone model to HuggingFace Transformers format.

This saves the backbone's configuration and weights in a format
compatible with HuggingFace Transformers. For unsupported model
architectures, a ValueError is raised.

Args:
path: str. Path to save the exported model.
"""
from keras_hub.src.utils.transformers.export.hf_exporter import (
export_backbone,
)

export_backbone(self, path)
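
A minimal usage sketch of the new backbone export (the preset name and output path below are illustrative, not part of this PR):

from keras_hub.models import GemmaBackbone

# Load a supported backbone; "gemma_2b_en" is an illustrative preset name.
backbone = GemmaBackbone.from_preset("gemma_2b_en")
# Writes config.json and safetensors weights to the target directory.
backbone.export_to_transformers("./gemma_hf_backbone")
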
38 changes: 38 additions & 0 deletions keras_hub/src/models/backbone_test.py
@@ -5,6 +5,7 @@

from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.bert.bert_backbone import BertBackbone
from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone
from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone
from keras_hub.src.tests.test_case import TestCase
from keras_hub.src.utils.preset_utils import CONFIG_FILE
@@ -105,3 +106,40 @@ def test_save_to_preset(self):
ref_out = backbone(data)
new_out = restored_backbone(data)
self.assertAllClose(ref_out, new_out)

def test_export_supported_model(self):
backbone_config = {
"vocabulary_size": 1000,
"num_layers": 2,
"num_query_heads": 4,
"num_key_value_heads": 1,
"hidden_dim": 512,
"intermediate_dim": 1024,
"head_dim": 128,
}
backbone = GemmaBackbone(**backbone_config)
export_path = os.path.join(self.get_temp_dir(), "export_backbone")
backbone.export_to_transformers(export_path)
# Basic check: config file exists
self.assertTrue(
os.path.exists(os.path.join(export_path, "config.json"))
)

def test_export_unsupported_model(self):
backbone_config = {
"vocabulary_size": 1000,
"num_layers": 2,
"num_query_heads": 4,
"num_key_value_heads": 1,
"hidden_dim": 512,
"intermediate_dim": 1024,
"head_dim": 128,
}

class UnsupportedBackbone(GemmaBackbone):
pass

backbone = UnsupportedBackbone(**backbone_config)
export_path = os.path.join(self.get_temp_dir(), "unsupported")
with self.assertRaises(ValueError):
backbone.export_to_transformers(export_path)
21 changes: 21 additions & 0 deletions keras_hub/src/models/causal_lm.py
@@ -392,3 +392,24 @@ def postprocess(x):
outputs = [postprocess(x) for x in outputs]

return self._normalize_generate_outputs(outputs, input_is_scalar)

def export_to_transformers(self, path):
"""Export the full CausalLM model to HuggingFace Transformers format.

This exports the trainable model, tokenizer, and configurations in a
format compatible with HuggingFace Transformers. For unsupported model
architectures, a ValueError is raised.

If the preprocessor is attached (default), both the trainable model and
tokenizer are exported. To export only the trainable model, set
`self.preprocessor = None` before calling this method, then export the
preprocessor separately via `preprocessor.export_to_transformers(path)`.

Args:
path: str. Path to save the exported model.
"""
from keras_hub.src.utils.transformers.export.hf_exporter import (
export_to_safetensors,
)

export_to_safetensors(self, path)
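
A sketch of both export modes the docstring describes (preset name and paths are illustrative):

from keras_hub.models import GemmaCausalLM

causal_lm = GemmaCausalLM.from_preset("gemma_2b_en")

# Attached (default): weights, config, and tokenizer assets in one directory.
causal_lm.export_to_transformers("./gemma_hf")

# Detached: export the model alone, then the tokenizer separately.
preprocessor = causal_lm.preprocessor
causal_lm.preprocessor = None
causal_lm.export_to_transformers("./gemma_hf_model")
causal_lm.preprocessor = preprocessor  # restore the attached preprocessor
preprocessor.export_to_transformers("./gemma_hf_tokenizer")
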
14 changes: 14 additions & 0 deletions keras_hub/src/models/causal_lm_preprocessor.py
@@ -180,3 +180,17 @@ def sequence_length(self, value):
self._sequence_length = value
if self.packer is not None:
self.packer.sequence_length = value

def export_to_transformers(self, path):
"""Export the preprocessor to HuggingFace Transformers format.

Args:
path: str. Path to save the exported preprocessor/tokenizer.
"""
if self.tokenizer is None:
raise ValueError("Preprocessor must have a tokenizer for export.")
from keras_hub.src.utils.transformers.export.hf_exporter import (
export_tokenizer,
)

export_tokenizer(self.tokenizer, path)
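
The exported directory is meant to load with stock Transformers APIs; a quick round-trip sketch (assumes the "./gemma_hf" directory from the example above, with transformers and torch installed):

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("./gemma_hf")
tokenizer = AutoTokenizer.from_pretrained("./gemma_hf")
inputs = tokenizer("Hello", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output_ids[0]))
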
25 changes: 25 additions & 0 deletions keras_hub/src/models/causal_lm_preprocessor_test.py
@@ -1,7 +1,13 @@
import os

import pytest

from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer
from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import (
GemmaCausalLMPreprocessor,
)
from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer
from keras_hub.src.models.gpt2.gpt2_causal_lm_preprocessor import (
GPT2CausalLMPreprocessor,
)
@@ -43,3 +49,22 @@ def test_from_preset_errors(self):
with self.assertRaises(ValueError):
# No loading on a non-keras model.
GPT2CausalLMPreprocessor.from_preset("hf://spacy/en_core_web_sm")

def test_export_supported_preprocessor(self):
proto = os.path.join(self.get_test_data_dir(), "gemma_export_vocab.spm")
tokenizer = GemmaTokenizer(proto=proto)
preprocessor = GemmaCausalLMPreprocessor(tokenizer=tokenizer)
export_path = os.path.join(self.get_temp_dir(), "export_preprocessor")
preprocessor.export_to_transformers(export_path)
# Basic check: tokenizer config exists
self.assertTrue(
os.path.exists(os.path.join(export_path, "tokenizer_config.json"))
)

def test_export_missing_tokenizer(self):
preprocessor = GemmaCausalLMPreprocessor(tokenizer=None)
export_path = os.path.join(
self.get_temp_dir(), "export_missing_tokenizer"
)
with self.assertRaises(ValueError):
preprocessor.export_to_transformers(export_path)
88 changes: 88 additions & 0 deletions keras_hub/src/models/task_test.py
@@ -7,6 +7,12 @@

from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier
from keras_hub.src.models.causal_lm import CausalLM
from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone
from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM
from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import (
GemmaCausalLMPreprocessor,
)
from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer
from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM
from keras_hub.src.models.image_classifier import ImageClassifier
from keras_hub.src.models.preprocessor import Preprocessor
@@ -171,3 +177,85 @@ def test_save_to_preset_custom_backbone_and_preprocessor(self):
restored_task = ImageClassifier.from_preset(save_dir)
actual = restored_task.predict(batch)
self.assertAllClose(expected, actual)

def _create_gemma_for_export_tests(self):
proto = os.path.join(self.get_test_data_dir(), "gemma_export_vocab.spm")
tokenizer = GemmaTokenizer(proto=proto)
backbone = GemmaBackbone(
vocabulary_size=tokenizer.vocabulary_size(),
num_layers=2,
num_query_heads=4,
num_key_value_heads=1,
hidden_dim=512,
intermediate_dim=1024,
head_dim=128,
)
preprocessor = GemmaCausalLMPreprocessor(tokenizer=tokenizer)
causal_lm = GemmaCausalLM(backbone=backbone, preprocessor=preprocessor)
return causal_lm, preprocessor

def test_export_attached(self):
causal_lm, _ = self._create_gemma_for_export_tests()
export_path = os.path.join(self.get_temp_dir(), "export_attached")
causal_lm.export_to_transformers(export_path)
# Basic check: config and tokenizer files exist
self.assertTrue(
os.path.exists(os.path.join(export_path, "config.json"))
)
self.assertTrue(
os.path.exists(os.path.join(export_path, "tokenizer_config.json"))
)

def test_export_attached_with_lm_head(self):
# Attached export always passes lm_head=True, so this test exercises the
# same path as test_export_attached; it exists to make that behavior explicit.
causal_lm, _ = self._create_gemma_for_export_tests()
export_path = os.path.join(
self.get_temp_dir(), "export_attached_lm_head"
)
causal_lm.export_to_transformers(export_path)
# Basic check: config and tokenizer files exist
self.assertTrue(
os.path.exists(os.path.join(export_path, "config.json"))
)
self.assertTrue(
os.path.exists(os.path.join(export_path, "tokenizer_config.json"))
)

def test_export_detached(self):
causal_lm, preprocessor = self._create_gemma_for_export_tests()
export_path_backbone = os.path.join(
self.get_temp_dir(), "export_detached_backbone"
)
export_path_preprocessor = os.path.join(
self.get_temp_dir(), "export_detached_preprocessor"
)
original_preprocessor = causal_lm.preprocessor
causal_lm.preprocessor = None
causal_lm.export_to_transformers(export_path_backbone)
causal_lm.preprocessor = original_preprocessor
preprocessor.export_to_transformers(export_path_preprocessor)
# Basic check: backbone has config, no tokenizer; preprocessor has
# tokenizer config
self.assertTrue(
os.path.exists(os.path.join(export_path_backbone, "config.json"))
)
self.assertFalse(
os.path.exists(
os.path.join(export_path_backbone, "tokenizer_config.json")
)
)
self.assertTrue(
os.path.exists(
os.path.join(export_path_preprocessor, "tokenizer_config.json")
)
)

def test_export_missing_tokenizer(self):
causal_lm, preprocessor = self._create_gemma_for_export_tests()
preprocessor.tokenizer = None
export_path = os.path.join(
self.get_temp_dir(), "export_missing_tokenizer"
)
with self.assertRaises(ValueError):
causal_lm.export_to_transformers(export_path)
Binary test-data files not shown (the new tests reference gemma_export_vocab.spm).
15 changes: 15 additions & 0 deletions keras_hub/src/tokenizers/tokenizer.py
@@ -261,3 +261,18 @@ class like `keras_hub.models.Tokenizer.from_preset()`, or from
if cls.backbone_cls != backbone_cls:
cls = find_subclass(preset, cls, backbone_cls)
return loader.load_tokenizer(cls, config_file, **kwargs)

def export_to_transformers(self, path):
"""Export the tokenizer to HuggingFace Transformers format.

This saves tokenizer assets in a format compatible with HuggingFace
Transformers.

Args:
path: str. Path to save the exported tokenizer.
"""
from keras_hub.src.utils.transformers.export.hf_exporter import (
export_tokenizer,
)

export_tokenizer(self, path)
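
For a standalone tokenizer the call mirrors the preprocessor path; a short sketch (preset name illustrative):

from keras_hub.models import GemmaTokenizer

tokenizer = GemmaTokenizer.from_preset("gemma_2b_en")
# Writes tokenizer_config.json plus the vocabulary assets.
tokenizer.export_to_transformers("./gemma_hf_tokenizer")
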
22 changes: 22 additions & 0 deletions keras_hub/src/tokenizers/tokenizer_test.py
@@ -6,6 +6,7 @@

from keras_hub.src.models.albert.albert_tokenizer import AlbertTokenizer
from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer
from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer
from keras_hub.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
from keras_hub.src.tests.test_case import TestCase
@@ -113,3 +114,24 @@ def test_save_to_preset(self, cls, preset_name, tokenizer_type):
# Check config class.
tokenizer_config = load_json(save_dir, TOKENIZER_CONFIG_FILE)
self.assertEqual(cls, check_config_class(tokenizer_config))

def test_export_supported_tokenizer(self):
proto = os.path.join(self.get_test_data_dir(), "gemma_export_vocab.spm")
tokenizer = GemmaTokenizer(proto=proto)
export_path = os.path.join(self.get_temp_dir(), "export_tokenizer")
tokenizer.export_to_transformers(export_path)
# Basic check: tokenizer config exists
self.assertTrue(
os.path.exists(os.path.join(export_path, "tokenizer_config.json"))
)

def test_export_unsupported_tokenizer(self):
proto = os.path.join(self.get_test_data_dir(), "gemma_export_vocab.spm")

class UnsupportedTokenizer(GemmaTokenizer):
pass

tokenizer = UnsupportedTokenizer(proto=proto)
export_path = os.path.join(self.get_temp_dir(), "unsupported_tokenizer")
with self.assertRaises(ValueError):
tokenizer.export_to_transformers(export_path)
53 changes: 49 additions & 4 deletions keras_hub/src/utils/transformers/export/gemma.py
@@ -2,6 +2,7 @@


def get_gemma_config(backbone):
token_embedding_layer = backbone.get_layer("token_embedding")
hf_config = {
"vocab_size": backbone.vocabulary_size,
"num_hidden_layers": backbone.num_layers,
@@ -11,11 +12,16 @@ def get_gemma_config(backbone):
"intermediate_size": backbone.intermediate_dim // 2,
"head_dim": backbone.head_dim,
"max_position_embeddings": 8192,
"tie_word_embeddings": token_embedding_layer.tie_weights,
"pad_token_id": 0,
"bos_token_id": 2,
"eos_token_id": 1,
"model_type": "gemma",
}
return hf_config


def get_gemma_weights_map(backbone):
def get_gemma_weights_map(backbone, include_lm_head=False):
weights_dict = {}

# Map token embedding
@@ -83,7 +89,46 @@ def get_gemma_weights_map(backbone):
"final_normalization"
).weights[0]

# Tie weights, but clone to avoid sharing memory issues
weights_dict["lm_head.weight"] = ops.copy(token_embedding_layer.weights[0])

# Map lm_head if embeddings are not tied
if include_lm_head and not token_embedding_layer.tie_weights:
weights_dict["lm_head.weight"] = ops.transpose(
token_embedding_layer.reverse_embeddings
)
return weights_dict


def get_gemma_tokenizer_config(tokenizer):
tokenizer_config = {
"tokenizer_class": "GemmaTokenizer",
"clean_up_tokenization_spaces": False,
"bos_token": "<bos>",
"eos_token": "<eos>",
"pad_token": "<pad>",
"unk_token": "<unk>",
"add_bos_token": True,
"add_eos_token": False,
"model_max_length": 8192,
}
# Add added_tokens_decoder
added_tokens_decoder = {}
special_tokens = [
"<pad>",
"<bos>",
"<eos>",
"<unk>",
"<start_of_turn>",
"<end_of_turn>",
]
for token in special_tokens:
token_id = tokenizer.token_to_id(token)
if token_id is not None:
added_tokens_decoder[str(token_id)] = {
"content": token,
"special": True,
"single_word": False,
"lstrip": False,
"rstrip": False,
"normalized": False,
}
tokenizer_config["added_tokens_decoder"] = added_tokens_decoder
return tokenizer_config
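
One way to sanity-check the weight mapping above end to end is to compare next-token logits between the exported model and the source KerasHub model; a sketch (paths reuse the earlier examples; assumes torch and transformers are installed):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

hf_model = AutoModelForCausalLM.from_pretrained("./gemma_hf")
hf_tokenizer = AutoTokenizer.from_pretrained("./gemma_hf")
inputs = hf_tokenizer("The quick brown fox", return_tensors="pt")
with torch.no_grad():
    hf_logits = hf_model(**inputs).logits[0, -1]
# Compare hf_logits against the KerasHub model's logits for the same prompt
# to catch mapping errors (transposed kernels, tied embeddings, etc.).
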