7 changes: 5 additions & 2 deletions keras_hub/src/models/backbone.py
@@ -91,21 +91,24 @@ def get_config(self):
}

# Add quantization support by utilizing `DTypePolicyMap`
dtype = None
try:
if isinstance(
self.dtype_policy, keras.dtype_policies.DTypePolicyMap
):
config.update({"dtype": self.dtype_policy})
dtype = self.dtype_policy
else:
policy_map = keras.dtype_policies.DTypePolicyMap()
for layer in self._flatten_layers():
if layer.quantization_mode is not None:
policy_map[layer.path] = layer.dtype_policy
if len(policy_map) > 0:
config.update({"dtype": policy_map})
dtype = policy_map
# Before Keras 3.2, there is no `keras.dtype_policies.get`.
except AttributeError:
pass

config.update({"dtype": dtype})
return config

@classmethod
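For context, a minimal sketch (not part of the diff) of what the new serialization path produces. It assumes a Keras version that ships `keras.dtype_policies.DTypePolicyMap` and uses an illustrative, deliberately tiny backbone configuration:

import keras
from keras_hub.src.models.bert.bert_backbone import BertBackbone

# Tiny configuration so the sketch is cheap to run.
backbone = BertBackbone(
    vocabulary_size=100,
    num_layers=2,
    num_heads=2,
    hidden_dim=8,
    intermediate_dim=16,
)
backbone.quantize("int8")
config = backbone.get_config()

# With the change above, a quantized model serializes its dtype as a
# `DTypePolicyMap` keyed by layer path rather than a plain dtype string.
print(isinstance(config["dtype"], keras.dtype_policies.DTypePolicyMap))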
95 changes: 95 additions & 0 deletions keras_hub/src/models/task_test.py
@@ -4,6 +4,7 @@
import keras
import numpy as np
import pytest
from absl.testing import parameterized

from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier
from keras_hub.src.models.causal_lm import CausalLM
@@ -107,6 +108,100 @@ def test_summary_without_preprocessor(self):
model.summary(print_fn=lambda x, line_break=False: summary.append(x))
self.assertNotRegex("\n".join(summary), "Preprocessor:")

@pytest.mark.large
@parameterized.named_parameters(
{
"testcase_name": "load_with_quantized_weights",
"load_weights": True,
"dtype_override": None,
"expected_dtype": "int8",
},
{
"testcase_name": "override_dtype_without_loading_weights",
"load_weights": False,
"dtype_override": "float32",
"expected_dtype": "float32",
},
)
def test_quantized_preset_loading_and_saving(
self, load_weights, dtype_override, expected_dtype
):
# Create, quantize, and save the model preset.
save_dir = self.get_temp_dir()
task = TextClassifier.from_preset("bert_tiny_en_uncased", num_classes=2)
task.quantize(mode="int8")
task.save_to_preset(save_dir)

# Verify that all necessary files were created.
path = pathlib.Path(save_dir)
self.assertTrue(os.path.exists(path / CONFIG_FILE))
self.assertTrue(os.path.exists(path / MODEL_WEIGHTS_FILE))
self.assertTrue(os.path.exists(path / METADATA_FILE))
self.assertTrue(os.path.exists(path / TASK_CONFIG_FILE))
self.assertTrue(os.path.exists(path / TASK_WEIGHTS_FILE))

# Verify the contents of the task config file.
task_config = load_json(save_dir, TASK_CONFIG_FILE)
self.assertNotIn("build_config", task_config)
self.assertNotIn("compile_config", task_config)
self.assertIn("backbone", task_config["config"])
self.assertIn("preprocessor", task_config["config"])
self.assertEqual(BertTextClassifier, check_config_class(task_config))

# Restore the task from the preset using parameterized arguments.
restored_task = TextClassifier.from_preset(
save_dir,
num_classes=2,
load_weights=load_weights,
dtype=dtype_override,
)

# Check that the layers have the expected data type.
for layer in restored_task._flatten_layers():
if isinstance(layer, keras.layers.Dense) and layer.name != "logits":
self.assertEqual(
layer.kernel.dtype,
expected_dtype,
f"Layer '{layer.name}' kernel "
"should have dtype '{expected_dtype}'",
)

# Ensure inference runs without errors.
data = ["the quick brown fox.", "the slow brown fox."]
_ = restored_task.predict(data)

@pytest.mark.large
def test_load_quantized_preset_with_dtype_override(self):
save_dir = self.get_temp_dir()
task = TextClassifier.from_preset("bert_tiny_en_uncased", num_classes=2)
task.quantize(mode="int8")
task.save_to_preset(save_dir)

# Check existence of files.
path = pathlib.Path(save_dir)
self.assertTrue(os.path.exists(path / CONFIG_FILE))
self.assertTrue(os.path.exists(path / MODEL_WEIGHTS_FILE))
self.assertTrue(os.path.exists(path / METADATA_FILE))
self.assertTrue(os.path.exists(path / TASK_CONFIG_FILE))
self.assertTrue(os.path.exists(path / TASK_WEIGHTS_FILE))

# Check the task config (`task.json`).
task_config = load_json(save_dir, TASK_CONFIG_FILE)
self.assertTrue("build_config" not in task_config)
self.assertTrue("compile_config" not in task_config)
self.assertTrue("backbone" in task_config["config"])
self.assertTrue("preprocessor" in task_config["config"])

# Check the preset directory task class.
self.assertEqual(BertTextClassifier, check_config_class(task_config))

# Loading the model in full precision should cause an error during
# initialization. The serialized quantized layers include additional
# quantization-specific weights (kernel_scale, etc.) that the
# full-precision layer is not aware of and cannot handle.
with self.assertRaises(ValueError):
TextClassifier.from_preset(save_dir, num_classes=2, dtype="float32")

@pytest.mark.large
def test_save_to_preset(self):
save_dir = self.get_temp_dir()
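To make the failure mode in `test_load_quantized_preset_with_dtype_override` concrete, here is a small sketch, assuming Keras's built-in int8 quantization of `Dense` layers (a hedged illustration, not asserted by the tests themselves):

import keras

layer = keras.layers.Dense(4)
layer.build((None, 8))
layer.quantize("int8")

# After int8 quantization the kernel is stored as int8 and the layer
# tracks extra variables such as a kernel scale.
print(layer.kernel.dtype)  # "int8"
print([v.name for v in layer.weights])

# A plain float32 Dense layer has no slots for these extra weights, so
# restoring the quantized checkpoint into it raises a ValueError.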
50 changes: 50 additions & 0 deletions keras_hub/src/utils/preset_utils.py
@@ -10,6 +10,7 @@
from absl import logging

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.utils import tensor_utils
from keras_hub.src.utils.keras_utils import print_msg
from keras_hub.src.utils.keras_utils import sharded_weights_available
from keras_hub.src.utils.tensor_utils import get_tensor_size_in_bits
@@ -687,6 +688,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
)
# We found a `task.json` with a complete config for our class.
# Forward backbone args.
kwargs["dtype"] = self._resolve_dtype(self.config, kwargs)
backbone_kwargs, kwargs = self.get_backbone_kwargs(**kwargs)
if "backbone" in task_config["config"]:
backbone_config = task_config["config"]["backbone"]["config"]
Expand All @@ -708,6 +710,54 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
self._load_backbone_weights(task.backbone)
return task

def _resolve_dtype(self, config, kwargs):
"""Resolves the Model's dtype based on the provided config and kwargs.

The data type is resolved based on the following priority:
1. If a user specified dtype is passed, use that.
2. If no user specified dtype is passed, and the save dtype is castable
to the current keras default dtype convert weights on load (float type
to float type).
3. If not user specified dtype is passed, and the save dtype is not
castable to the current default dtype (quantized dtypes). Load the
saved types verbatim.

Args:
config: The model configuration.
kwargs: Additional keyword arguments, potentially including `dtype`.

Returns:
The resolved dtype.
"""
# 1. If a user-specified dtype is passed, use it.
if "dtype" in kwargs and kwargs["dtype"] is not None:
return kwargs["dtype"]

saved_dtype = config.get("config", {}).get("dtype")

# If there's no saved dtype, we don't need to do anything.
if saved_dtype is None:
return None

# If the saved dtype is a string (e.g. "float32"), check if it is a
# floating point type.
is_float = isinstance(saved_dtype, str) and tensor_utils.is_float_dtype(
saved_dtype
)
if is_float:
# 2. If the saved dtype is a float, we can safely cast to the
# default backend float type.
logging.info(
"No dtype specified during loading. "
f"Using {keras.config.dtype_policy} as default. "
"This may result in type casting."
)
return keras.config.dtype_policy()
else:
# 3. Otherwise, the dtype is a complex object (e.g. a
# DTypePolicyMap for quantization), and should be used as is.
return saved_dtype

def load_preprocessor(
self, cls, config_file=PREPROCESSOR_CONFIG_FILE, **kwargs
):
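From the user's side, the three branches of `_resolve_dtype` look roughly like the sketch below (a hedged illustration; the local int8 preset path is hypothetical and stands in for a directory produced by `save_to_preset` after `quantize("int8")`):

from keras_hub.src.models.text_classifier import TextClassifier

# 1. A user-specified dtype always wins.
task = TextClassifier.from_preset(
    "bert_tiny_en_uncased", num_classes=2, dtype="bfloat16"
)

# 2. Float-saved preset, no dtype passed: weights are cast to the current
# default dtype policy on load.
task = TextClassifier.from_preset("bert_tiny_en_uncased", num_classes=2)

# 3. Quantized preset, no dtype passed: the saved `DTypePolicyMap` is used
# verbatim, so int8 layers stay int8 after loading.
task = TextClassifier.from_preset("path/to/local_int8_preset", num_classes=2)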