Skip to content

Commit 37db80d

Browse files
validate quantized dtypes in test
1 parent a3b6f48 commit 37db80d

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

keras_hub/src/models/task_test.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def test_summary_without_preprocessor(self):
107107
model.summary(print_fn=lambda x, line_break=False: summary.append(x))
108108
self.assertNotRegex("\n".join(summary), "Preprocessor:")
109109

110-
@pytest.mark.large
110+
# @pytest.mark.large
111111
def test_save_to_preset_with_quantization(self):
112112
save_dir = self.get_temp_dir()
113113
task = TextClassifier.from_preset("bert_tiny_en_uncased", num_classes=2)
@@ -135,6 +135,15 @@ def test_save_to_preset_with_quantization(self):
135135
# Try loading the model from preset directory.
136136
restored_task = TextClassifier.from_preset(save_dir, num_classes=2)
137137

138+
# Validate that the kernels of quantized layers are stored in lower precision (int8).
139+
for layer in restored_task._flatten_layers():
140+
if isinstance(layer, keras.layers.Dense) and layer.name != "logits":
141+
self.assertEqual(
142+
layer.kernel.dtype,
143+
"int8",
144+
f"{layer.name=} should be in lower precision (int8)",
145+
)
146+
138147
# Test whether inference works.
139148
data = ["the quick brown fox.", "the slow brown fox."]
140149

0 commit comments

Comments
 (0)