diff --git a/tests/unit/vertexai/test_any_serializer.py b/tests/unit/vertexai/test_any_serializer.py index 9ea44e9895..d675634438 100644 --- a/tests/unit/vertexai/test_any_serializer.py +++ b/tests/unit/vertexai/test_any_serializer.py @@ -23,6 +23,7 @@ import os from typing import Any +import vertexai from vertexai.preview import developer from vertexai.preview._workflow.serialization_engine import ( any_serializer, @@ -323,6 +324,29 @@ def test_any_serializer_register_predefined_serializers(self, caplog): serializer_instance._serialization_scheme == _TEST_SERIALIZATION_SCHEME ) + def test_any_serializer_with_wrapped_class(self): + # Reset the serializer instances + serializers_base.Serializer._instances = {} + + # Wrap a ML class that we have predefined serializer + unwrapped_keras_class = keras.models.Model + keras.models.Model = vertexai.preview.remote(keras.models.Model) + + try: + # Assert that AnySerializer still registered the original class + serializer_instance = any_serializer.AnySerializer() + assert keras.models.Model not in serializer_instance._serialization_scheme + assert unwrapped_keras_class in serializer_instance._serialization_scheme + assert ( + serializer_instance._serialization_scheme[unwrapped_keras_class] + == serializers.KerasModelSerializer + ) + except Exception as e: + raise e + finally: + # Revert the class after testing + keras.models.Model = unwrapped_keras_class + def test_any_serializer_global_metadata_created( self, mock_cloudpickle_serialize, any_serializer_instance, tmp_path ): diff --git a/vertexai/preview/_workflow/driver/remote.py b/vertexai/preview/_workflow/driver/remote.py index 303e8ef135..e8364e9221 100644 --- a/vertexai/preview/_workflow/driver/remote.py +++ b/vertexai/preview/_workflow/driver/remote.py @@ -23,6 +23,9 @@ from vertexai.preview._workflow.executor import ( training, ) +from vertexai.preview._workflow.serialization_engine import ( + any_serializer, +) from vertexai.preview._workflow.shared import ( supported_frameworks, ) @@ -72,6 +75,10 @@ def remote(cls_or_method: Any) -> Any: Returns: A class or method that can be executed remotely. """ + # Make sure AnySerializer has been instantiated before wrapping any classes. + if any_serializer.AnySerializer not in any_serializer.AnySerializer._instances: + any_serializer.AnySerializer() + if inspect.isclass(cls_or_method): return remote_class_decorator(cls_or_method) else: