From 1b07517d1177d5aa97a880e879fd0b27f062b3f3 Mon Sep 17 00:00:00 2001
From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com>
Date: Wed, 20 Aug 2025 11:09:26 +0530
Subject: [PATCH 01/10] ensures DTypePolicyMap is added to backbone kwargs
 during load_task

---
 keras_hub/src/utils/preset_utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py
index 340dfcbe57..f7287386ed 100644
--- a/keras_hub/src/utils/preset_utils.py
+++ b/keras_hub/src/utils/preset_utils.py
@@ -687,6 +687,8 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
             )
         # We found a `task.json` with a complete config for our class.
         # Forward backbone args.
+        if "config" in self.config and "dtype" in self.config["config"]:
+            kwargs["dtype"] = self.config["config"]["dtype"]
         backbone_kwargs, kwargs = self.get_backbone_kwargs(**kwargs)
         if "backbone" in task_config["config"]:
             backbone_config = task_config["config"]["backbone"]["config"]

From e5eff0ab3307bdcbee6d842f4d3716c3872e5ed1 Mon Sep 17 00:00:00 2001
From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com>
Date: Wed, 20 Aug 2025 13:18:58 +0530
Subject: [PATCH 02/10] Added test for loading quantized presets

---
 keras_hub/src/models/task_test.py | 33 +++++++++++++++++++++++++++++++
 1 file changed, 33 insertions(+)

diff --git a/keras_hub/src/models/task_test.py b/keras_hub/src/models/task_test.py
index b46e46b361..5fb727a7ab 100644
--- a/keras_hub/src/models/task_test.py
+++ b/keras_hub/src/models/task_test.py
@@ -107,6 +107,39 @@ def test_summary_without_preprocessor(self):
         model.summary(print_fn=lambda x, line_break=False: summary.append(x))
         self.assertNotRegex("\n".join(summary), "Preprocessor:")

+    # @pytest.mark.large
+    def test_save_to_preset_with_quantization(self):
+        save_dir = self.get_temp_dir()
+        task = TextClassifier.from_preset("bert_tiny_en_uncased", num_classes=2)
+        task.quantize(mode="int8")
+        task.save_to_preset(save_dir)
+
+        # Check existence of files.
+        path = pathlib.Path(save_dir)
+        self.assertTrue(os.path.exists(path / CONFIG_FILE))
+        self.assertTrue(os.path.exists(path / MODEL_WEIGHTS_FILE))
+        self.assertTrue(os.path.exists(path / METADATA_FILE))
+        self.assertTrue(os.path.exists(path / TASK_CONFIG_FILE))
+        self.assertTrue(os.path.exists(path / TASK_WEIGHTS_FILE))
+
+        # Check the task config (`task.json`).
+        task_config = load_json(save_dir, TASK_CONFIG_FILE)
+        self.assertTrue("build_config" not in task_config)
+        self.assertTrue("compile_config" not in task_config)
+        self.assertTrue("backbone" in task_config["config"])
+        self.assertTrue("preprocessor" in task_config["config"])
+
+        # Check the preset directory task class.
+        self.assertEqual(BertTextClassifier, check_config_class(task_config))
+
+        # Try loading the model from preset directory.
+        restored_task = TextClassifier.from_preset(save_dir, num_classes=2)
+
+        # Test whether inference works.
+        data = ["the quick brown fox.", "the slow brown fox."]
+
+        _ = restored_task.predict(data)
+
     @pytest.mark.large
     def test_save_to_preset(self):
         save_dir = self.get_temp_dir()

From a3b6f48042dc3c515ba5896a1d9eac018159aa19 Mon Sep 17 00:00:00 2001
From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com>
Date: Wed, 20 Aug 2025 13:20:17 +0530
Subject: [PATCH 03/10] marks test as large

---
 keras_hub/src/models/task_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/keras_hub/src/models/task_test.py b/keras_hub/src/models/task_test.py
index 5fb727a7ab..f80ff90386 100644
--- a/keras_hub/src/models/task_test.py
+++ b/keras_hub/src/models/task_test.py
@@ -107,7 +107,7 @@ def test_summary_without_preprocessor(self):
         model.summary(print_fn=lambda x, line_break=False: summary.append(x))
         self.assertNotRegex("\n".join(summary), "Preprocessor:")

-    # @pytest.mark.large
+    @pytest.mark.large
     def test_save_to_preset_with_quantization(self):
         save_dir = self.get_temp_dir()
         task = TextClassifier.from_preset("bert_tiny_en_uncased", num_classes=2)

From 37db80df35498312ed9256588abb68f4bfa1da33 Mon Sep 17 00:00:00 2001
From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com>
Date: Wed, 20 Aug 2025 13:38:08 +0530
Subject: [PATCH 04/10] validate quantized dtypes in test

---
 keras_hub/src/models/task_test.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/keras_hub/src/models/task_test.py b/keras_hub/src/models/task_test.py
index f80ff90386..28489c8471 100644
--- a/keras_hub/src/models/task_test.py
+++ b/keras_hub/src/models/task_test.py
@@ -107,7 +107,7 @@ def test_summary_without_preprocessor(self):
         model.summary(print_fn=lambda x, line_break=False: summary.append(x))
         self.assertNotRegex("\n".join(summary), "Preprocessor:")

-    @pytest.mark.large
+    # @pytest.mark.large
     def test_save_to_preset_with_quantization(self):
         save_dir = self.get_temp_dir()
         task = TextClassifier.from_preset("bert_tiny_en_uncased", num_classes=2)
@@ -135,6 +135,15 @@ def test_save_to_preset_with_quantization(self):
         # Try loading the model from preset directory.
         restored_task = TextClassifier.from_preset(save_dir, num_classes=2)

+        # Validate dtypes for quantized layers are in lower precision.
+        for layer in restored_task._flatten_layers():
+            if isinstance(layer, keras.layers.Dense) and layer.name != "logits":
+                self.assertEqual(
+                    layer.kernel.dtype,
+                    "int8",
+                    f"{layer.name=} should be in lower precision (int8)",
+                )
+
         # Test whether inference works.
         data = ["the quick brown fox.", "the slow brown fox."]


From 84f26d2738bafaed8b08940b46e02119ef2dc48e Mon Sep 17 00:00:00 2001
From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com>
Date: Wed, 20 Aug 2025 13:44:27 +0530
Subject: [PATCH 05/10] add comments

---
 keras_hub/src/utils/preset_utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py
index f7287386ed..2b2d62da48 100644
--- a/keras_hub/src/utils/preset_utils.py
+++ b/keras_hub/src/utils/preset_utils.py
@@ -688,6 +688,8 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
         # We found a `task.json` with a complete config for our class.
         # Forward backbone args.
         if "config" in self.config and "dtype" in self.config["config"]:
+            # Forward the serialized dtype from the config. This is critical for
+            # restoring quantized models, which rely on a `DTypePolicyMap`.
kwargs["dtype"] = self.config["config"]["dtype"] backbone_kwargs, kwargs = self.get_backbone_kwargs(**kwargs) if "backbone" in task_config["config"]: From 58dfab9fbc503e2cd2a5eb2bd4d331c1b88dd8d5 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Thu, 21 Aug 2025 13:46:02 +0530 Subject: [PATCH 06/10] implements priority-based dtype resolution + tests --- keras_hub/src/models/backbone.py | 7 ++- keras_hub/src/models/task_test.py | 87 +++++++++++++++++++++++------ keras_hub/src/utils/preset_utils.py | 54 ++++++++++++++++-- 3 files changed, 125 insertions(+), 23 deletions(-) diff --git a/keras_hub/src/models/backbone.py b/keras_hub/src/models/backbone.py index 55aaec239d..7a8545e0af 100644 --- a/keras_hub/src/models/backbone.py +++ b/keras_hub/src/models/backbone.py @@ -91,21 +91,24 @@ def get_config(self): } # Add quantization support by utilizing `DTypePolicyMap` + dtype = None try: if isinstance( self.dtype_policy, keras.dtype_policies.DTypePolicyMap ): - config.update({"dtype": self.dtype_policy}) + dtype = self.dtype_policy else: policy_map = keras.dtype_policies.DTypePolicyMap() for layer in self._flatten_layers(): if layer.quantization_mode is not None: policy_map[layer.path] = layer.dtype_policy if len(policy_map) > 0: - config.update({"dtype": policy_map}) + dtype = policy_map # Before Keras 3.2, there is no `keras.dtype_policies.get`. except AttributeError: pass + + config.update({"dtype": dtype}) return config @classmethod diff --git a/keras_hub/src/models/task_test.py b/keras_hub/src/models/task_test.py index 28489c8471..6e0990d908 100644 --- a/keras_hub/src/models/task_test.py +++ b/keras_hub/src/models/task_test.py @@ -4,6 +4,7 @@ import keras import numpy as np import pytest +from absl.testing import parameterized from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier from keras_hub.src.models.causal_lm import CausalLM @@ -107,14 +108,31 @@ def test_summary_without_preprocessor(self): model.summary(print_fn=lambda x, line_break=False: summary.append(x)) self.assertNotRegex("\n".join(summary), "Preprocessor:") - # @pytest.mark.large - def test_save_to_preset_with_quantization(self): + @pytest.mark.large + @parameterized.named_parameters( + { + "testcase_name": "load_with_quantized_weights", + "load_weights": True, + "dtype_override": None, + "expected_dtype": "int8", + }, + { + "testcase_name": "override_dtype_without_loading_weights", + "load_weights": False, + "dtype_override": "float32", + "expected_dtype": "float32", + }, + ) + def test_quantized_preset_loading_and_saving( + self, load_weights, dtype_override, expected_dtype + ): + # Create, quantize, and save the model preset. save_dir = self.get_temp_dir() task = TextClassifier.from_preset("bert_tiny_en_uncased", num_classes=2) task.quantize(mode="int8") task.save_to_preset(save_dir) - # Check existence of files. + # Verify that all necessary files were created. path = pathlib.Path(save_dir) self.assertTrue(os.path.exists(path / CONFIG_FILE)) self.assertTrue(os.path.exists(path / MODEL_WEIGHTS_FILE)) @@ -122,33 +140,68 @@ def test_save_to_preset_with_quantization(self): self.assertTrue(os.path.exists(path / TASK_CONFIG_FILE)) self.assertTrue(os.path.exists(path / TASK_WEIGHTS_FILE)) - # Check the task config (`task.json`). + # Verify the contents of the task config file. 
         task_config = load_json(save_dir, TASK_CONFIG_FILE)
-        self.assertTrue("build_config" not in task_config)
-        self.assertTrue("compile_config" not in task_config)
-        self.assertTrue("backbone" in task_config["config"])
-        self.assertTrue("preprocessor" in task_config["config"])
-
-        # Check the preset directory task class.
+        self.assertNotIn("build_config", task_config)
+        self.assertNotIn("compile_config", task_config)
+        self.assertIn("backbone", task_config["config"])
+        self.assertIn("preprocessor", task_config["config"])
         self.assertEqual(BertTextClassifier, check_config_class(task_config))

-        # Try loading the model from preset directory.
-        restored_task = TextClassifier.from_preset(save_dir, num_classes=2)
+        # Restore the task from the preset using parameterized arguments.
+        restored_task = TextClassifier.from_preset(
+            save_dir,
+            num_classes=2,
+            load_weights=load_weights,
+            dtype=dtype_override,
+        )

-        # Validate dtypes for quantized layers are in lower precision.
+        # Check that the layers have the expected data type.
         for layer in restored_task._flatten_layers():
             if isinstance(layer, keras.layers.Dense) and layer.name != "logits":
                 self.assertEqual(
                     layer.kernel.dtype,
-                    "int8",
-                    f"{layer.name=} should be in lower precision (int8)",
+                    expected_dtype,
+                    f"Layer '{layer.name}' kernel "
+                    "should have dtype '{expected_dtype}'",
                 )

-        # Test whether inference works.
+        # Ensure inference runs without errors.
         data = ["the quick brown fox.", "the slow brown fox."]
-
         _ = restored_task.predict(data)

+    @pytest.mark.large
+    def test_load_quantized_preset_with_dtype_override(self):
+        save_dir = self.get_temp_dir()
+        task = TextClassifier.from_preset("bert_tiny_en_uncased", num_classes=2)
+        task.quantize(mode="int8")
+        task.save_to_preset(save_dir)
+
+        # Check existence of files.
+        path = pathlib.Path(save_dir)
+        self.assertTrue(os.path.exists(path / CONFIG_FILE))
+        self.assertTrue(os.path.exists(path / MODEL_WEIGHTS_FILE))
+        self.assertTrue(os.path.exists(path / METADATA_FILE))
+        self.assertTrue(os.path.exists(path / TASK_CONFIG_FILE))
+        self.assertTrue(os.path.exists(path / TASK_WEIGHTS_FILE))
+
+        # Check the task config (`task.json`).
+        task_config = load_json(save_dir, TASK_CONFIG_FILE)
+        self.assertTrue("build_config" not in task_config)
+        self.assertTrue("compile_config" not in task_config)
+        self.assertTrue("backbone" in task_config["config"])
+        self.assertTrue("preprocessor" in task_config["config"])
+
+        # Check the preset directory task class.
+        self.assertEqual(BertTextClassifier, check_config_class(task_config))
+
+        # Loading the model in full-precision should cause an error during
+        # initialization. The serialized quantized layers include additional
+        # quantization specific weights (kernel_scale, etc.) which the
+        # full-precision layer is not aware of and can't handle.
+        with self.assertRaises(ValueError):
+            TextClassifier.from_preset(save_dir, num_classes=2, dtype="float32")
+
     @pytest.mark.large
     def test_save_to_preset(self):
         save_dir = self.get_temp_dir()
diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py
index 2b2d62da48..ccedc32cdc 100644
--- a/keras_hub/src/utils/preset_utils.py
+++ b/keras_hub/src/utils/preset_utils.py
@@ -10,6 +10,7 @@
 from absl import logging

 from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.utils import tensor_utils
 from keras_hub.src.utils.keras_utils import print_msg
 from keras_hub.src.utils.keras_utils import sharded_weights_available
 from keras_hub.src.utils.tensor_utils import get_tensor_size_in_bits
@@ -687,10 +688,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
             )
         # We found a `task.json` with a complete config for our class.
         # Forward backbone args.
-        if "config" in self.config and "dtype" in self.config["config"]:
-            # Forward the serialized dtype from the config. This is critical for
-            # restoring quantized models, which rely on a `DTypePolicyMap`.
-            kwargs["dtype"] = self.config["config"]["dtype"]
+        kwargs["dtype"] = self._resolve_dtype(self.config, kwargs)
         backbone_kwargs, kwargs = self.get_backbone_kwargs(**kwargs)
         if "backbone" in task_config["config"]:
             backbone_config = task_config["config"]["backbone"]["config"]
@@ -712,6 +710,54 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
             self._load_backbone_weights(task.backbone)
         return task

+    def _resolve_dtype(self, config, kwargs):
+        """Resolves the Model's dtype based on the provided config and kwargs.
+
+        The data type is resolved based on the following priority:
+        1. If a user specified dtype is passed, use that.
+        2. If no user specified dtype is passed, and the saved dtype is
+        castable to the current keras default dtype, convert weights on load
+        (float type to float type).
+        3. If no user specified dtype is passed, and the saved dtype is not
+        castable to the current default dtype (quantized dtypes), load the
+        saved types verbatim.
+
+        Args:
+            config: The model configuration.
+            kwargs: Additional keyword arguments, potentially including `dtype`.
+
+        Returns:
+            The resolved dtype.
+        """
+        # 1. If a user specified dtype is passed, use that.
+        if "dtype" in kwargs and kwargs["dtype"] is not None:
+            return kwargs["dtype"]
+
+        saved_dtype = config.get("config", {}).get("dtype")
+
+        # If there's no saved dtype, we don't need to do anything.
+        if saved_dtype is None:
+            return None
+
+        # If the saved dtype is a string (e.g. "float32"), check if it is a
+        # floating point type.
+        is_float = isinstance(saved_dtype, str) and tensor_utils.is_float_dtype(
+            saved_dtype
+        )
+        if is_float:
+            # 2. If the saved dtype is a float, we can safely cast to the
+            # default backend float type.
+            logging.info(
+                "No dtype specified during loading. "
+                f"Using {keras.config.dtype_policy} as default. "
+                "This may result in type casting."
+            )
+            return keras.config.dtype_policy
+        else:
+            # 3. Otherwise, the dtype is a complex object (e.g. a
+            # DTypePolicyMap for quantization), and should be used as is.
+            return saved_dtype
+
     def load_preprocessor(
         self, cls, config_file=PREPROCESSOR_CONFIG_FILE, **kwargs
     ):

From ef053d6e0ecd27eebe612857727940d10b8d6f2c Mon Sep 17 00:00:00 2001
From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com>
Date: Sat, 23 Aug 2025 14:48:57 +0530
Subject: [PATCH 07/10] Fixes float check + improves logging

---
 keras_hub/src/models/backbone.py    | 24 ++++++++----------------
 keras_hub/src/utils/preset_utils.py | 26 ++++++++++++--------------
 2 files changed, 20 insertions(+), 30 deletions(-)

diff --git a/keras_hub/src/models/backbone.py b/keras_hub/src/models/backbone.py
index 7a8545e0af..af2d57ec08 100644
--- a/keras_hub/src/models/backbone.py
+++ b/keras_hub/src/models/backbone.py
@@ -91,22 +91,14 @@ def get_config(self):
         }

         # Add quantization support by utilizing `DTypePolicyMap`
-        dtype = None
-        try:
-            if isinstance(
-                self.dtype_policy, keras.dtype_policies.DTypePolicyMap
-            ):
-                dtype = self.dtype_policy
-            else:
-                policy_map = keras.dtype_policies.DTypePolicyMap()
-                for layer in self._flatten_layers():
-                    if layer.quantization_mode is not None:
-                        policy_map[layer.path] = layer.dtype_policy
-                if len(policy_map) > 0:
-                    dtype = policy_map
-        # Before Keras 3.2, there is no `keras.dtype_policies.get`.
-        except AttributeError:
-            pass
+        dtype = self.dtype_policy
+        if not isinstance(dtype, keras.dtype_policies.DTypePolicyMap):
+            policy_map = keras.dtype_policies.DTypePolicyMap()
+            for layer in self._flatten_layers():
+                if layer.quantization_mode is not None:
+                    policy_map[layer.path] = layer.dtype_policy
+            if len(policy_map) > 0:
+                dtype = policy_map

         config.update({"dtype": dtype})
         return config
diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py
index ccedc32cdc..b392ee1580 100644
--- a/keras_hub/src/utils/preset_utils.py
+++ b/keras_hub/src/utils/preset_utils.py
@@ -739,20 +739,18 @@ def _resolve_dtype(self, config, kwargs):
         # If there's no saved dtype, we don't need to do anything.
         if saved_dtype is None:
             return None

-        # If the saved dtype is a string (e.g. "float32"), check if it is a
-        # floating point type.
-        is_float = isinstance(saved_dtype, str) and tensor_utils.is_float_dtype(
-            saved_dtype
-        )
-        if is_float:
-            # 2. If the saved dtype is a float, we can safely cast to the
-            # default backend float type.
-            logging.info(
-                "No dtype specified during loading. "
-                f"Using {keras.config.dtype_policy} as default. "
-                "This may result in type casting."
-            )
-            return keras.config.dtype_policy
+        # 2. Check whether the saved dtype is a simple float type.
+        policy_name = saved_dtype.get("config", {}).get("name")
+        if policy_name and tensor_utils.is_float_dtype(policy_name):
+            # If the saved dtype is a float, we can safely cast to the default
+            # backend float type.
+            if policy_name != keras.config.dtype_policy().name:
+                logging.info(
+                    f"Converting weights saved as {policy_name} "
+                    "to the current Keras dtype policy "
+                    f"{keras.config.dtype_policy()}"
+                )
+            return keras.config.dtype_policy()
         else:
             # 3. Otherwise, the dtype is a complex object (e.g. a
             # DTypePolicyMap for quantization), and should be used as is.
From 6b45f052b58d0eccf6027fbe318181108a005d80 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Sun, 24 Aug 2025 08:29:06 +0530 Subject: [PATCH 08/10] Fixes dtype serialization --- keras_hub/src/models/backbone.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/backbone.py b/keras_hub/src/models/backbone.py index af2d57ec08..6e5f36afaa 100644 --- a/keras_hub/src/models/backbone.py +++ b/keras_hub/src/models/backbone.py @@ -100,7 +100,7 @@ def get_config(self): if len(policy_map) > 0: dtype = policy_map - config.update({"dtype": dtype}) + config.update({"dtype": keras.dtype_policies.serialize(dtype)}) return config @classmethod From 7eb8f1e2756af01547fe3169cfd2b1c46ffcbc09 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Tue, 26 Aug 2025 09:41:06 +0530 Subject: [PATCH 09/10] improves float check + adds tests --- keras_hub/src/models/task_test.py | 2 +- keras_hub/src/utils/preset_utils.py | 7 ++++--- keras_hub/src/utils/tensor_utils.py | 24 +++++++++++++++++++++- keras_hub/src/utils/tensor_utils_test.py | 26 ++++++++++++++++++++++++ 4 files changed, 54 insertions(+), 5 deletions(-) diff --git a/keras_hub/src/models/task_test.py b/keras_hub/src/models/task_test.py index 6e0990d908..b4196887bf 100644 --- a/keras_hub/src/models/task_test.py +++ b/keras_hub/src/models/task_test.py @@ -163,7 +163,7 @@ def test_quantized_preset_loading_and_saving( layer.kernel.dtype, expected_dtype, f"Layer '{layer.name}' kernel " - "should have dtype '{expected_dtype}'", + f"should have dtype '{expected_dtype}'", ) # Ensure inference runs without errors. diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py index b392ee1580..affb2362c6 100644 --- a/keras_hub/src/utils/preset_utils.py +++ b/keras_hub/src/utils/preset_utils.py @@ -723,11 +723,12 @@ def _resolve_dtype(self, config, kwargs): saved types verbatim. Args: - config: The model configuration. - kwargs: Additional keyword arguments, potentially including `dtype`. + config: dict. The model configuration. + kwargs: dict. Additional keyword arguments, potentially including + `dtype`. Returns: - The resolved dtype. + str, dict, or DTypePolicy. The resolved dtype. """ # 1. If a user specified dtype is passed, use that. if "dtype" in kwargs and kwargs["dtype"] is not None: diff --git a/keras_hub/src/utils/tensor_utils.py b/keras_hub/src/utils/tensor_utils.py index 47305a3f01..2bc3f49199 100644 --- a/keras_hub/src/utils/tensor_utils.py +++ b/keras_hub/src/utils/tensor_utils.py @@ -310,7 +310,29 @@ def is_tensor_type(x): def is_float_dtype(dtype): - return "float" in keras.backend.standardize_dtype(dtype) + """ + Checks if a dtype is a float type by using a regex. + + This function standardizes the input dtype and then uses a regular + expression to perform an exact match. It identifies standard floats, + bfloats, and mixed-precision float types. + + For example: + - `is_float_dtype("float32")` returns `True`. + - `is_float_dtype("bfloat16")` returns `True`. + - `is_float_dtype("mixed_float16")` returns `True`. + - `is_float_dtype("int8")` returns `False`. + - `is_float_dtype("int8_from_float32")` returns `False`. + + Args: + dtype: str, DTypePolicy. The data type to check. + + Returns: + bool: `True` if the dtype is a floating-point type, `False` otherwise. 
+ """ + pattern = re.compile(r"^(mixed_)?(b)?float[0-9]*$") + standardized_dtype = keras.backend.standardize_dtype(dtype) + return pattern.match(standardized_dtype) is not None def is_int_dtype(dtype): diff --git a/keras_hub/src/utils/tensor_utils_test.py b/keras_hub/src/utils/tensor_utils_test.py index 0b6ef1f346..350a97e501 100644 --- a/keras_hub/src/utils/tensor_utils_test.py +++ b/keras_hub/src/utils/tensor_utils_test.py @@ -8,6 +8,7 @@ from keras_hub.src.utils.tensor_utils import convert_preprocessing_inputs from keras_hub.src.utils.tensor_utils import convert_preprocessing_outputs from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch +from keras_hub.src.utils.tensor_utils import is_float_dtype from keras_hub.src.utils.tensor_utils import is_tensor_type from keras_hub.src.utils.tensor_utils import preprocessing_function from keras_hub.src.utils.tensor_utils import target_gather @@ -304,3 +305,28 @@ def test_target_gather_invalid_rank(self): indices = np.array([0, 1], dtype="int32") with self.assertRaisesRegex(ValueError, "larger than 3"): _ = target_gather(targets, indices) + + +class IsFloatDtypeTest(TestCase): + def test_float_dtypes_return_true(self): + float_dtypes = [ + "float16", + "float32", + "float64", + "bfloat16", + ] + for dtype in float_dtypes: + self.assertTrue(is_float_dtype(dtype)) + + def test_non_float_dtypes_return_false(self): + non_float_dtypes = [ + "int8", + "int32", + "uint8", + "bool", + "string", + "int8_from_float32", + "int4_from_bfloat16", + ] + for dtype in non_float_dtypes: + self.assertFalse(is_float_dtype(dtype)) From aa78876975d2daa8e4467ad6f7dd7ea2b4413bf0 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Tue, 26 Aug 2025 15:19:28 +0530 Subject: [PATCH 10/10] removes types not supported by standardize_dtypes --- keras_hub/src/utils/tensor_utils_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/keras_hub/src/utils/tensor_utils_test.py b/keras_hub/src/utils/tensor_utils_test.py index 350a97e501..fba2b35c76 100644 --- a/keras_hub/src/utils/tensor_utils_test.py +++ b/keras_hub/src/utils/tensor_utils_test.py @@ -325,8 +325,6 @@ def test_non_float_dtypes_return_false(self): "uint8", "bool", "string", - "int8_from_float32", - "int4_from_bfloat16", ] for dtype in non_float_dtypes: self.assertFalse(is_float_dtype(dtype))