Skip to content

Commit

Permalink
Fix serialization / deserialization.
Browse files Browse the repository at this point in the history
- Serialization was not taking the registered name and package from the registry.
- Deserialization was selecting symbols by postfix as a fallback.

PiperOrigin-RevId: 691149084
  • Loading branch information
hertschuh authored and tensorflower-gardener committed Oct 29, 2024
1 parent c512b1a commit 2da1800
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 14 deletions.
10 changes: 6 additions & 4 deletions tf_keras/saving/object_registration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,21 +112,23 @@ def from_config(cls, config):
self.assertEqual(5, new_inst._val)

def test_serialize_custom_function(self):
@object_registration.register_keras_serializable()
@object_registration.register_keras_serializable(
package="Test", name="func"
)
def my_fn():
return 42

serialized_name = "Custom>my_fn"
serialized_name = "Test>func"
class_name = object_registration._GLOBAL_CUSTOM_NAMES[my_fn]
self.assertEqual(serialized_name, class_name)
fn_class_name = object_registration.get_registered_name(my_fn)
self.assertEqual(fn_class_name, class_name)

config = serialization_lib.serialize_keras_object(my_fn)
if tf.__internal__.tf2.enabled():
self.assertEqual("my_fn", config["config"])
self.assertEqual(serialized_name, config["config"])
else:
self.assertEqual(class_name, config)
self.assertEqual(serialized_name, config)
fn = serialization_lib.deserialize_keras_object(config)
self.assertEqual(42, fn())

Expand Down
11 changes: 1 addition & 10 deletions tf_keras/saving/serialization_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def _get_class_or_fn_config(obj):
"""Return the object's config depending on its type."""
# Functions / lambdas:
if isinstance(obj, types.FunctionType):
return obj.__name__
return object_registration.get_registered_name(obj)
# All classes:
if hasattr(obj, "get_config"):
config = obj.get_config()
Expand Down Expand Up @@ -789,15 +789,6 @@ def _retrieve_class_or_fn(
if obj is not None:
return obj

# Retrieval of registered custom function in a package
filtered_dict = {
k: v
for k, v in custom_objects.items()
if k.endswith(full_config["config"])
}
if filtered_dict:
return next(iter(filtered_dict.values()))

# Otherwise, attempt to retrieve the class object given the `module`
# and `class_name`. Import the module, find the class.
try:
Expand Down

0 comments on commit 2da1800

Please sign in to comment.