Skip to content

Commit

Permalink
feat: Also support unhashable objects to be serialized with extra args
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 577998940
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Oct 31, 2023
1 parent 1e4a4ec commit 77a741e
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 21 deletions.
6 changes: 4 additions & 2 deletions tests/unit/vertexai/test_remote_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,9 @@ def test_remote_training_sklearn_with_remote_configs(
_TEST_TRAINING_CONFIG_CONTAINER_URI
)
model.fit.vertex.remote_config.machine_type = _TEST_TRAINING_CONFIG_MACHINE_TYPE
model.fit.vertex.remote_config.serializer_args = {model: {"extra_params": 1}}
model.fit.vertex.remote_config.serializer_args[model] = {"extra_params": 1}
# X_TRAIN is a numpy array that is not hashable.
model.fit.vertex.remote_config.serializer_args[_X_TRAIN] = {"extra_params": 2}

model.fit(_X_TRAIN, _Y_TRAIN)

Expand All @@ -991,7 +993,7 @@ def test_remote_training_sklearn_with_remote_configs(
mock_any_serializer_sklearn.return_value.serialize.assert_any_call(
to_serialize=_X_TRAIN,
gcs_path=os.path.join(remote_job_base_path, "input/X"),
**{},
**{"extra_params": 2},
)
mock_any_serializer_sklearn.return_value.serialize.assert_any_call(
to_serialize=_Y_TRAIN,
Expand Down
64 changes: 64 additions & 0 deletions tests/unit/vertexai/test_serializers_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from vertexai.preview._workflow.serialization_engine import (
serializers_base,
)


class TestSerializerArgs:
    """Checks how ``SerializerArgs`` stores and looks up per-object arguments."""

    def test_object_id_is_saved(self):
        """A key given at construction time is stored under its ``id``."""

        class Key:
            pass

        key = Key()
        args_by_object = serializers_base.SerializerArgs({key: {"a": 1, "b": 2}})

        # Membership is checked against the stored ids, not the objects.
        assert id(key) in args_by_object
        assert key not in args_by_object

    def test_getitem_support_original_object(self):
        """Subscription with the original object returns its arguments."""

        class Key:
            pass

        key = Key()
        args_by_object = serializers_base.SerializerArgs({key: {"a": 1, "b": 2}})

        looked_up = args_by_object[key]
        assert looked_up == {"a": 1, "b": 2}

    def test_get_support_original_object(self):
        """``get()`` with the original object returns its arguments."""

        class Key:
            pass

        key = Key()
        args_by_object = serializers_base.SerializerArgs({key: {"a": 1, "b": 2}})

        looked_up = args_by_object.get(key)
        assert looked_up == {"a": 1, "b": 2}

    def test_unhashable_obj_saved_successfully(self):
        """An unhashable key (a list) can be added through item assignment."""
        list_key = [1, 2, 3]
        args_by_object = serializers_base.SerializerArgs()

        args_by_object[list_key] = {"a": 1, "b": 2}

        assert id(list_key) in args_by_object

    def test_getitem_support_original_unhashable(self):
        """Subscription works with the original unhashable key."""
        list_key = [1, 2, 3]
        args_by_object = serializers_base.SerializerArgs()

        args_by_object[list_key] = {"a": 1, "b": 2}

        looked_up = args_by_object[list_key]
        assert looked_up == {"a": 1, "b": 2}

    def test_get_support_original_unhashable(self):
        """``get()`` works with the original unhashable key."""
        list_key = [1, 2, 3]
        args_by_object = serializers_base.SerializerArgs()

        args_by_object[list_key] = {"a": 1, "b": 2}

        looked_up = args_by_object.get(list_key)
        assert looked_up == {"a": 1, "b": 2}
12 changes: 5 additions & 7 deletions vertexai/preview/_workflow/executor/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import re
import sys
import time
from typing import Any, Dict, List, Optional, Set, Tuple, Union, Hashable
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import warnings

from google.api_core import exceptions as api_exceptions
Expand Down Expand Up @@ -495,6 +495,8 @@ def remote_training(invokable: shared._Invokable, rewrapper: Any):
bound_args = invokable.bound_arguments
config = invokable.vertex_config.remote_config
serializer_args = invokable.vertex_config.remote_config.serializer_args
if not isinstance(serializer_args, serializers_base.SerializerArgs):
raise ValueError("serializer_args must be an instance of SerializerArgs.")

autolog = vertexai.preview.global_config.autolog
service_account = _get_service_account(config, autolog=autolog)
Expand Down Expand Up @@ -609,17 +611,13 @@ def remote_training(invokable: shared._Invokable, rewrapper: Any):
to_serialize=arg_value,
gcs_path=os.path.join(remote_job_input_path, f"{arg_name}"),
framework=detected_framework,
**serializer_args.get(arg_value, {})
if isinstance(arg_value, Hashable)
else {},
**serializer_args.get(arg_value, {}),
)
else:
serialization_metadata = serializer.serialize(
to_serialize=arg_value,
gcs_path=os.path.join(remote_job_input_path, f"{arg_name}"),
**serializer_args.get(arg_value, {})
if isinstance(arg_value, Hashable)
else {},
**serializer_args.get(arg_value, {}),
)
# serializer.get_dependencies() must be run after serializer.serialize()
requirements += serialization_metadata[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from typing import Any, Dict, List, Optional, Type, TypeVar, Union

from google.cloud.aiplatform.utils import gcs_utils

from vertexai.preview._workflow.shared import data_structures

T = TypeVar("T")
SERIALIZATION_METADATA_FILENAME = "serialization_metadata"
Expand All @@ -34,6 +34,9 @@
SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY = "custom_commands"


SerializerArgs = data_structures.IdAsKeyDict


@dataclasses.dataclass
class SerializationMetadata:
"""Metadata of Serializer classes.
Expand Down
44 changes: 33 additions & 11 deletions vertexai/preview/_workflow/shared/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
# limitations under the License.
#
import dataclasses
from typing import List, Optional, Dict, Any
from typing import List, Optional
from vertexai.preview._workflow.serialization_engine import (
serializers_base,
)


@dataclasses.dataclass
Expand Down Expand Up @@ -72,16 +75,33 @@ class RemoteConfig(_BaseConfig):
]
# Specify the extra parameters needed for serializing objects.
model.train.vertex.remote_config.serializer_args = {
model: {
"extra_serializer_param1_for_model": param1_value,
"extra_serializer_param2_for_model": param2_value,
from vertexai.preview.developer import SerializerArgs
# You can put all the hashable objects with their arguments in the
# SerializerArgs all at once in a dict. Here we assume "model" is
# hashable.
model.train.vertex.remote_config.serializer_args = SerializerArgs({
model: {
"extra_serializer_param1_for_model": param1_value,
"extra_serializer_param2_for_model": param2_value,
},
hashable_obj2: {
"extra_serializer_param1_for_hashable2": param1_value,
"extra_serializer_param2_for_hashable2": param2_value,
},
})
# Or, if the objects to be serialized are unhashable, put them into
# serializer_args one by one. If this is the only use case, there is
# no need to import `SerializerArgs`. Here we assume "X_train" and
# "y_train" are not hashable.
model.train.vertex.remote_config.serializer_args[X_train] = {
"extra_serializer_param1_for_X_train": param1_value,
"extra_serializer_param2_for_X_train": param2_value,
},
X_train: {
"extra_serializer_param1": param1_value,
"extra_serializer_param2": param2_value,
model.train.vertex.remote_config.serializer_args[y_train] = {
"extra_serializer_param1_for_y_train": param1_value,
"extra_serializer_param2_for_y_train": param2_value,
}
}
# Train the model as usual
model.train(X_train, y_train)
Expand Down Expand Up @@ -132,7 +152,7 @@ class RemoteConfig(_BaseConfig):
custom_commands (List[str]):
List of custom commands to be run in the remote job environment.
These commands will be run before the requirements are installed.
serializer_args (Dict[Any, Dict[str, Any]]):
serializer_args (serializers_base.SerializerArgs):
Map from object to extra arguments when serializing the object. The extra
arguments are a dictionary mapping argument names to argument values.
"""
Expand All @@ -143,7 +163,9 @@ class RemoteConfig(_BaseConfig):
service_account: Optional[str] = None
requirements: List[str] = dataclasses.field(default_factory=list)
custom_commands: List[str] = dataclasses.field(default_factory=list)
serializer_args: Dict[Any, Dict[str, Any]] = dataclasses.field(default_factory=dict)
serializer_args: serializers_base.SerializerArgs = dataclasses.field(
default_factory=serializers_base.SerializerArgs
)


@dataclasses.dataclass
Expand Down
68 changes: 68 additions & 0 deletions vertexai/preview/_workflow/shared/data_structures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


class IdAsKeyDict(dict):
    """Customized dict that maps each key to its id before storing the data.

    This subclass of dict still allows one to use the original key during
    subscription ([] operator) or via the `get()` method. Under the hood,
    the stored keys are the ids of the original keys, which is what makes
    unhashable objects usable as keys.

    Because ``id()`` values are only unique among objects that are alive at
    the same time, every key object passed through ``__init__`` or
    ``__setitem__`` is also kept referenced by the dict. Otherwise the id of
    a garbage-collected key could be reused by an unrelated object, which
    would then silently match the stale entry.

    Only ``__init__``, ``__getitem__``, ``__setitem__`` and ``get()`` translate
    keys to ids. Every other inherited dict operation (``in``, ``del``,
    ``pop()``, ``update()``, iteration, ...) works on the stored ids directly.

    Example:
        # add some hashable objects (key1 and key2) to the dict
        id_as_key_dict = IdAsKeyDict({key1: value1, key2: value2})

        # add an unhashable object (key3) to the dict
        id_as_key_dict[key3] = value3

        # can access the value via subscription using the original key
        assert id_as_key_dict[key1] == value1
        assert id_as_key_dict[key2] == value2
        assert id_as_key_dict[key3] == value3

        # can access the value via get method using the original key
        assert id_as_key_dict.get(key1) == value1
        assert id_as_key_dict.get(key2) == value2
        assert id_as_key_dict.get(key3) == value3

        # but the original keys are not in the dict - the ids are
        assert id(key1) in id_as_key_dict
        assert id(key2) in id_as_key_dict
        assert id(key3) in id_as_key_dict
        assert key1 not in id_as_key_dict
        assert key2 not in id_as_key_dict
        assert key3 not in id_as_key_dict
    """

    def __init__(self, *args, **kwargs):
        """Builds the dict from mappings, iterables of pairs and keyword args.

        Args:
            *args: Any number of mappings (objects exposing ``items()``) or
                iterables of ``(key, value)`` pairs. Later sources overwrite
                earlier ones when they contain the same key object.
            **kwargs: Extra entries; each one is stored under the id of its
                keyword-name string.
        """
        super().__init__()
        for source in args:
            # Mirror dict(): accept a mapping as well as an iterable of pairs.
            pairs = source.items() if hasattr(source, "items") else source
            for key, value in pairs:
                self[key] = value
        for key, value in kwargs.items():
            self[key] = value

    def __getitem__(self, _key):
        """Returns the value stored for the object `_key`.

        Raises:
            KeyError: If no entry is stored under ``id(_key)``.
        """
        internal_key = id(_key)
        return super().__getitem__(internal_key)

    def __setitem__(self, _key, _value):
        """Stores `_value` under ``id(_key)`` and keeps `_key` alive."""
        internal_key = id(_key)
        # Hold a reference to the key so that its id cannot be recycled while
        # the entry exists. The holder is created lazily so that instances
        # built without running __init__ can still be assigned to.
        self.__dict__.setdefault("_retained_keys", {})[internal_key] = _key
        return super().__setitem__(internal_key, _value)

    def get(self, key, default=None):
        """Returns the value stored for the object `key`, or `default`."""
        internal_key = id(key)
        return super().get(internal_key, default)
2 changes: 2 additions & 0 deletions vertexai/preview/developer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
PersistentResourceConfig = configs.PersistentResourceConfig
Serializer = serializers_base.Serializer
SerializationMetadata = serializers_base.SerializationMetadata
SerializerArgs = serializers_base.SerializerArgs
RemoteConfig = configs.RemoteConfig
WorkerPoolSpec = remote_specs.WorkerPoolSpec
WorkerPoolSepcs = remote_specs.WorkerPoolSpecs
Expand All @@ -41,6 +42,7 @@
"PersistentResourceConfig",
"register_serializer",
"Serializer",
"SerializerArgs",
"SerializationMetadata",
"RemoteConfig",
"WorkerPoolSpec",
Expand Down

0 comments on commit 77a741e

Please sign in to comment.