Skip to content

Commit

Permalink
chore: make sure AnySerializer is instantiated before wrapping classes
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 580564310
  • Loading branch information
jaycee-li authored and copybara-github committed Nov 8, 2023
1 parent 6f40f1b commit 21686ae
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
24 changes: 24 additions & 0 deletions tests/unit/vertexai/test_any_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import os
from typing import Any

import vertexai
from vertexai.preview import developer
from vertexai.preview._workflow.serialization_engine import (
any_serializer,
Expand Down Expand Up @@ -323,6 +324,29 @@ def test_any_serializer_register_predefined_serializers(self, caplog):
serializer_instance._serialization_scheme == _TEST_SERIALIZATION_SCHEME
)

def test_any_serializer_with_wrapped_class(self):
# Reset the serializer instances
serializers_base.Serializer._instances = {}

# Wrap a ML class that we have predefined serializer
unwrapped_keras_class = keras.models.Model
keras.models.Model = vertexai.preview.remote(keras.models.Model)

try:
# Assert that AnySerializer still registered the original class
serializer_instance = any_serializer.AnySerializer()
assert keras.models.Model not in serializer_instance._serialization_scheme
assert unwrapped_keras_class in serializer_instance._serialization_scheme
assert (
serializer_instance._serialization_scheme[unwrapped_keras_class]
== serializers.KerasModelSerializer
)
except Exception as e:
raise e
finally:
# Revert the class after testing
keras.models.Model = unwrapped_keras_class

def test_any_serializer_global_metadata_created(
self, mock_cloudpickle_serialize, any_serializer_instance, tmp_path
):
Expand Down
7 changes: 7 additions & 0 deletions vertexai/preview/_workflow/driver/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from vertexai.preview._workflow.executor import (
training,
)
from vertexai.preview._workflow.serialization_engine import (
any_serializer,
)
from vertexai.preview._workflow.shared import (
supported_frameworks,
)
Expand Down Expand Up @@ -72,6 +75,10 @@ def remote(cls_or_method: Any) -> Any:
Returns:
A class or method that can be executed remotely.
"""
# Make sure AnySerializer has been instantiated before wrapping any classes.
if any_serializer.AnySerializer not in any_serializer.AnySerializer._instances:
any_serializer.AnySerializer()

if inspect.isclass(cls_or_method):
return remote_class_decorator(cls_or_method)
else:
Expand Down

0 comments on commit 21686ae

Please sign in to comment.