From 8b0add169ebd0683b56dbe3b643d533ebbd5e1ca Mon Sep 17 00:00:00 2001 From: Sara Robinson Date: Wed, 13 Sep 2023 15:15:17 -0700 Subject: [PATCH] feat: add Custom Job support to from_pretrained PiperOrigin-RevId: 565175389 --- google/cloud/aiplatform/jobs.py | 13 ++ tests/unit/vertexai/conftest.py | 1 + tests/unit/vertexai/test_model_utils.py | 202 ++++++++++++++++++ tests/unit/vertexai/test_remote_training.py | 13 ++ vertexai/preview/_workflow/driver/__init__.py | 8 +- .../preview/_workflow/executor/__init__.py | 7 +- .../preview/_workflow/executor/prediction.py | 6 +- .../executor/remote_container_training.py | 22 +- .../preview/_workflow/executor/training.py | 49 ++++- .../_workflow/executor/training_script.py | 6 +- .../preview/_workflow/launcher/__init__.py | 10 +- .../preview/_workflow/shared/model_utils.py | 178 ++++++++++++--- 12 files changed, 457 insertions(+), 58 deletions(-) diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index 6c0afcec27..1636430018 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -87,6 +87,19 @@ gca_job_state_v1beta1.JobState.JOB_STATE_CANCELLED, ) +_JOB_PENDING_STATES = ( + gca_job_state.JobState.JOB_STATE_QUEUED, + gca_job_state.JobState.JOB_STATE_PENDING, + gca_job_state.JobState.JOB_STATE_RUNNING, + gca_job_state.JobState.JOB_STATE_CANCELLING, + gca_job_state.JobState.JOB_STATE_UPDATING, + gca_job_state_v1beta1.JobState.JOB_STATE_QUEUED, + gca_job_state_v1beta1.JobState.JOB_STATE_PENDING, + gca_job_state_v1beta1.JobState.JOB_STATE_RUNNING, + gca_job_state_v1beta1.JobState.JOB_STATE_CANCELLING, + gca_job_state_v1beta1.JobState.JOB_STATE_UPDATING, +) + # _block_until_complete wait times _JOB_WAIT_TIME = 5 # start at five seconds _LOG_WAIT_TIME = 5 diff --git a/tests/unit/vertexai/conftest.py b/tests/unit/vertexai/conftest.py index 2113267a49..4994248eed 100644 --- a/tests/unit/vertexai/conftest.py +++ b/tests/unit/vertexai/conftest.py @@ -80,6 +80,7 @@ 
output_uri_prefix=_TEST_BASE_OUTPUT_DIR ), }, + labels={"trained_by_vertex_ai": "true"}, ) diff --git a/tests/unit/vertexai/test_model_utils.py b/tests/unit/vertexai/test_model_utils.py index 8e0ba751f8..7c0ebf548a 100644 --- a/tests/unit/vertexai/test_model_utils.py +++ b/tests/unit/vertexai/test_model_utils.py @@ -22,9 +22,19 @@ import vertexai from vertexai.preview._workflow.serialization_engine import ( any_serializer, + serializers_base, +) +from google.cloud.aiplatform.compat.services import job_service_client +from google.cloud.aiplatform.compat.types import ( + job_state as gca_job_state, + custom_job as gca_custom_job, + io as gca_io, ) import pytest +import cloudpickle +import numpy as np +import sklearn from sklearn.linear_model import _logistic import tensorflow import torch @@ -45,6 +55,9 @@ _MODEL_RESOURCE_NAME = "projects/123/locations/us-central1/models/456" _REWRAPPER = "rewrapper" +# customJob constants +_TEST_CUSTOM_JOB_RESOURCE_NAME = "projects/123/locations/us-central1/customJobs/456" + @pytest.fixture def mock_serialize_model(): @@ -123,6 +136,126 @@ def mock_deserialize_model_exception(): yield mock_deserialize_model_exception +@pytest.fixture +def mock_any_serializer_serialize_sklearn(): + with mock.patch.object( + any_serializer.AnySerializer, + "serialize", + side_effect=[ + { + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [ + f"scikit-learn=={sklearn.__version__}" + ] + }, + { + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [ + f"numpy=={np.__version__}", + f"cloudpickle=={cloudpickle.__version__}", + ] + }, + { + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [ + f"numpy=={np.__version__}", + f"cloudpickle=={cloudpickle.__version__}", + ] + }, + { + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [ + f"numpy=={np.__version__}", + f"cloudpickle=={cloudpickle.__version__}", + ] + }, + ], + ) as mock_any_serializer_serialize: + yield mock_any_serializer_serialize + + +_TEST_PROJECT = 
"test-project" +_TEST_LOCATION = "us-central1" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" +_TEST_DISPLAY_NAME = f"{_TEST_PARENT}/customJobs/12345" +_TEST_BUCKET_NAME = "gs://test_bucket" +_TEST_BASE_OUTPUT_DIR = f"{_TEST_BUCKET_NAME}/test_base_output_dir" + +_TEST_INPUTS = [ + "--arg_0=string_val_0", + "--arg_1=string_val_1", + "--arg_2=int_val_0", + "--arg_3=int_val_1", +] +_TEST_IMAGE_URI = "test_image_uri" +_TEST_MACHINE_TYPE = "test_machine_type" +_TEST_WORKER_POOL_SPEC = [ + { + "machine_spec": { + "machine_type": _TEST_MACHINE_TYPE, + }, + "replica_count": 1, + "container_spec": { + "image_uri": _TEST_IMAGE_URI, + "args": _TEST_INPUTS, + }, + } +] +_TEST_CUSTOM_JOB_PROTO = gca_custom_job.CustomJob( + display_name=_TEST_DISPLAY_NAME, + job_spec={ + "worker_pool_specs": _TEST_WORKER_POOL_SPEC, + "base_output_directory": gca_io.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + }, + labels={"trained_by_vertex_ai": "true"}, +) + + +@pytest.fixture +def mock_get_custom_job_pending(): + with mock.patch.object( + job_service_client.JobServiceClient, "get_custom_job" + ) as mock_get_custom_job: + + mock_get_custom_job.side_effect = [ + gca_custom_job.CustomJob( + name=_TEST_CUSTOM_JOB_RESOURCE_NAME, + state=gca_job_state.JobState.JOB_STATE_RUNNING, + display_name=_TEST_DISPLAY_NAME, + job_spec={ + "worker_pool_specs": _TEST_WORKER_POOL_SPEC, + "base_output_directory": gca_io.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + }, + labels={"trained_by_vertex_ai": "true"}, + ), + gca_custom_job.CustomJob( + name=_TEST_CUSTOM_JOB_RESOURCE_NAME, + state=gca_job_state.JobState.JOB_STATE_SUCCEEDED, + display_name=_TEST_DISPLAY_NAME, + job_spec={ + "worker_pool_specs": _TEST_WORKER_POOL_SPEC, + "base_output_directory": gca_io.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + }, + labels={"trained_by_vertex_ai": "true"}, + ), + ] + yield mock_get_custom_job + + +@pytest.fixture +def 
mock_get_custom_job_failed(): + with mock.patch.object( + job_service_client.JobServiceClient, "get_custom_job" + ) as mock_get_custom_job: + custom_job_proto = _TEST_CUSTOM_JOB_PROTO + custom_job_proto.name = _TEST_CUSTOM_JOB_RESOURCE_NAME + custom_job_proto.state = gca_job_state.JobState.JOB_STATE_FAILED + mock_get_custom_job.return_value = custom_job_proto + yield mock_get_custom_job + + @pytest.mark.usefixtures("google_auth_mock") class TestModelUtils: def setup_method(self): @@ -289,3 +422,72 @@ def test_local_model_from_pretrained_fail(self): with pytest.raises(ValueError): vertexai.preview.from_pretrained(model_name=_MODEL_RESOURCE_NAME) + + @pytest.mark.usefixtures( + "mock_get_vertex_model", + "mock_get_custom_job_succeeded", + ) + def test_custom_job_from_pretrained_succeed(self, mock_deserialize_model): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET, + ) + + local_model = vertexai.preview.from_pretrained( + custom_job_name=_TEST_CUSTOM_JOB_RESOURCE_NAME + ) + assert local_model == _SKLEARN_MODEL + assert 2 == mock_deserialize_model.call_count + + mock_deserialize_model.assert_has_calls( + calls=[ + mock.call( + f"{_TEST_BASE_OUTPUT_DIR}/output/output_estimator", + ), + ], + any_order=True, + ) + + @pytest.mark.usefixtures( + "mock_get_vertex_model", + "mock_get_custom_job_pending", + "mock_cloud_logging_list_entries", + ) + def test_custom_job_from_pretrained_logs_and_blocks_until_complete_on_pending_job( + self, mock_deserialize_model + ): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET, + ) + + local_model = vertexai.preview.from_pretrained( + custom_job_name=_TEST_CUSTOM_JOB_RESOURCE_NAME + ) + assert local_model == _SKLEARN_MODEL + assert 2 == mock_deserialize_model.call_count + + mock_deserialize_model.assert_has_calls( + calls=[ + mock.call( + f"{_TEST_BASE_OUTPUT_DIR}/output/output_estimator", + ), + ], + any_order=True, + ) + + 
@pytest.mark.usefixtures("mock_get_vertex_model", "mock_get_custom_job_failed") + def test_custom_job_from_pretrained_fails_on_errored_job(self): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET, + ) + + with pytest.raises(ValueError) as err_msg: + vertexai.preview.from_pretrained( + custom_job_name=_TEST_CUSTOM_JOB_RESOURCE_NAME + ) + assert "did not complete" in err_msg diff --git a/tests/unit/vertexai/test_remote_training.py b/tests/unit/vertexai/test_remote_training.py index faae5d2e92..b8b0874078 100644 --- a/tests/unit/vertexai/test_remote_training.py +++ b/tests/unit/vertexai/test_remote_training.py @@ -388,6 +388,7 @@ def _get_custom_job_proto( env.append( {"name": metadata_constants.ENV_EXPERIMENT_RUN_KEY, "value": experiment_run} ) + job.labels = ({"trained_by_vertex_ai": "true"},) return job @@ -480,6 +481,12 @@ def mock_any_serializer_serialize_sklearn(): f"cloudpickle=={cloudpickle.__version__}", ] }, + { + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [ + f"numpy=={np.__version__}", + f"cloudpickle=={cloudpickle.__version__}", + ] + }, ], ) as mock_any_serializer_serialize: yield mock_any_serializer_serialize @@ -557,6 +564,12 @@ def mock_any_serializer_serialize_keras(): f"cloudpickle=={cloudpickle.__version__}", ] }, + { + serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [ + f"numpy=={np.__version__}", + f"cloudpickle=={cloudpickle.__version__}", + ] + }, ], ) as mock_any_serializer_serialize: yield mock_any_serializer_serialize diff --git a/vertexai/preview/_workflow/driver/__init__.py b/vertexai/preview/_workflow/driver/__init__.py index 68d35b9bd8..1ee530c9e2 100644 --- a/vertexai/preview/_workflow/driver/__init__.py +++ b/vertexai/preview/_workflow/driver/__init__.py @@ -241,7 +241,7 @@ def invoke(self, invokable: shared._Invokable) -> Any: ): rewrapper = _unwrapper(invokable.instance) - result = self._launch(invokable) + result = self._launch(invokable, rewrapper) # rewrap 
the original instance if rewrapper and invokable.instance is not None: @@ -255,12 +255,14 @@ def invoke(self, invokable: shared._Invokable) -> Any: return result - def _launch(self, invokable: shared._Invokable) -> Any: + def _launch(self, invokable: shared._Invokable, rewrapper: Any) -> Any: """ Launches an invokable. """ return self._launcher.launch( - invokable=invokable, global_remote=vertexai.preview.global_config.remote + invokable=invokable, + global_remote=vertexai.preview.global_config.remote, + rewrapper=rewrapper, ) diff --git a/vertexai/preview/_workflow/executor/__init__.py b/vertexai/preview/_workflow/executor/__init__.py index 32001a602c..28d01815b1 100644 --- a/vertexai/preview/_workflow/executor/__init__.py +++ b/vertexai/preview/_workflow/executor/__init__.py @@ -37,7 +37,7 @@ def local_execute(self, invokable: shared._Invokable) -> Any: *invokable.bound_arguments.args, **invokable.bound_arguments.kwargs ) - def remote_execute(self, invokable: shared._Invokable) -> Any: + def remote_execute(self, invokable: shared._Invokable, rewrapper: Any) -> Any: if invokable.remote_executor not in ( remote_container_training.train, training.remote_training, @@ -45,7 +45,10 @@ def remote_execute(self, invokable: shared._Invokable) -> Any: ): raise ValueError(f"{invokable.remote_executor} is not supported.") - return invokable.remote_executor(invokable) + if invokable.remote_executor == remote_container_training.train: + return invokable.remote_executor(invokable) + else: + return invokable.remote_executor(invokable, rewrapper=rewrapper) _workflow_executor = _WorkflowExecutor() diff --git a/vertexai/preview/_workflow/executor/prediction.py b/vertexai/preview/_workflow/executor/prediction.py index 0f011bda26..b64f2221ca 100644 --- a/vertexai/preview/_workflow/executor/prediction.py +++ b/vertexai/preview/_workflow/executor/prediction.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License.
# +from typing import Any + from vertexai.preview._workflow import ( shared, ) @@ -20,9 +22,9 @@ ) -def remote_prediction(invokable: shared._Invokable): +def remote_prediction(invokable: shared._Invokable, rewrapper: Any): """Wrapper function that makes a method executable by Vertex CustomJob.""" - predictions = training.remote_training(invokable=invokable) + predictions = training.remote_training(invokable=invokable, rewrapper=rewrapper) return predictions diff --git a/vertexai/preview/_workflow/executor/remote_container_training.py b/vertexai/preview/_workflow/executor/remote_container_training.py index cb5e3f0983..e89fe0bc3b 100644 --- a/vertexai/preview/_workflow/executor/remote_container_training.py +++ b/vertexai/preview/_workflow/executor/remote_container_training.py @@ -22,11 +22,7 @@ import vertexai from vertexai.preview._workflow import shared from vertexai.preview.developer import remote_specs - - -_CUSTOM_JOB_DIR = "custom_job" -_INPUT_DIR = "input" -_OUTPUT_DIR = "output" +from vertexai.preview._workflow.shared import model_utils # job_dir container argument name _JOB_DIR = "job_dir" @@ -141,8 +137,8 @@ def train(invokable: shared._Invokable): "No default staging bucket set. " "Please call `vertexai.init(staging_bucket='gs://my-bucket')." ) - input_dir = remote_specs._gen_gcs_path(staging_bucket, _INPUT_DIR) - output_dir = remote_specs._gen_gcs_path(staging_bucket, _OUTPUT_DIR) + input_dir = remote_specs._gen_gcs_path(staging_bucket, model_utils._INPUT_DIR) + output_dir = remote_specs._gen_gcs_path(staging_bucket, model_utils._OUTPUT_DIR) # Creates a complete set of binding. instance_binding = invokable.instance._binding @@ -153,7 +149,9 @@ def train(invokable: shared._Invokable): # If a container accepts a job_dir argument and the user does not specify # it, set job_dir based on the staging bucket. 
if _JOB_DIR in binding and not binding[_JOB_DIR]: - binding[_JOB_DIR] = remote_specs._gen_gcs_path(staging_bucket, _CUSTOM_JOB_DIR) + binding[_JOB_DIR] = remote_specs._gen_gcs_path( + staging_bucket, model_utils._CUSTOM_JOB_DIR + ) # Formats arguments. formatted_args = {} @@ -200,8 +198,12 @@ def train(invokable: shared._Invokable): display_name=f"{invokable.instance.__class__.__name__}-{display_name}" f"-{uuid.uuid4()}", worker_pool_specs=worker_pool_specs, - base_output_dir=remote_specs._gen_gcs_path(staging_bucket, _CUSTOM_JOB_DIR), - staging_bucket=remote_specs._gen_gcs_path(staging_bucket, _CUSTOM_JOB_DIR), + base_output_dir=remote_specs._gen_gcs_path( + staging_bucket, model_utils._CUSTOM_JOB_DIR + ), + staging_bucket=remote_specs._gen_gcs_path( + staging_bucket, model_utils._CUSTOM_JOB_DIR + ), ) job.run() diff --git a/vertexai/preview/_workflow/executor/training.py b/vertexai/preview/_workflow/executor/training.py index 1197c76e58..50b1a5121c 100644 --- a/vertexai/preview/_workflow/executor/training.py +++ b/vertexai/preview/_workflow/executor/training.py @@ -42,6 +42,7 @@ from vertexai.preview._workflow.shared import ( supported_frameworks, ) +from vertexai.preview._workflow.shared import model_utils from vertexai.preview.developer import remote_specs from packaging import version @@ -460,7 +461,31 @@ def _get_remote_logs_until_complete( ) -def remote_training(invokable: shared._Invokable): +def _set_job_labels(method_name: str) -> Dict[str, str]: + """Helper method to set the label for the CustomJob. + + Remote training, feature transform, and prediction jobs should each have + different labels. + + Args: + method_name (str): + Required. The method name used to invoke the remote job. + + Returns: + The label key/value dict for the CustomJob, or None for an unrecognized method.
+ """ + + if method_name in supported_frameworks.REMOTE_TRAINING_STATEFUL_OVERRIDE_LIST: + return {"trained_by_vertex_ai": "true"} + + if method_name in supported_frameworks.REMOTE_TRAINING_FUNCTIONAL_OVERRIDE_LIST: + return {"feature_transformed_by_vertex_ai": "true"} + + if method_name in supported_frameworks.REMOTE_PREDICTION_OVERRIDE_LIST: + return {"predicted_by_vertex_ai": "true"} + + +def remote_training(invokable: shared._Invokable, rewrapper: Any): """Wrapper function that makes a method executable by Vertex CustomJob.""" self = invokable.instance @@ -521,7 +546,9 @@ def remote_training(invokable: shared._Invokable): remote_job = f"remote-job-{utils.timestamped_unique_name()}" remote_job_base_path = os.path.join(staging_bucket, remote_job) remote_job_input_path = os.path.join(remote_job_base_path, "input") - remote_job_output_path = os.path.join(remote_job_base_path, "output") + remote_job_output_path = model_utils._generate_remote_job_output_path( + remote_job_base_path + ) detected_framework = None if supported_frameworks._is_sklearn(self): @@ -655,6 +682,15 @@ def remote_training(invokable: shared._Invokable): command = ["sh", "-c", " ".join(command)] + labels = _set_job_labels(method_name) + + # serialize rewrapper, this is needed to load a model from a CustomJob + filepath = os.path.join( + remote_job_output_path, + model_utils._REWRAPPER_NAME, + ) + serializer.serialize(rewrapper, filepath) + # create & run the CustomJob # disable CustomJob logs @@ -667,6 +703,7 @@ def remote_training(invokable: shared._Invokable): worker_pool_specs=_get_worker_pool_specs(config, container_uri, command), base_output_dir=remote_job_base_path, staging_bucket=remote_job_base_path, + labels=labels, ) job.submit( @@ -698,7 +735,7 @@ def remote_training(invokable: shared._Invokable): # retrieve the result from gcs to local if method_name in supported_frameworks.REMOTE_TRAINING_STATEFUL_OVERRIDE_LIST: estimator = serializer.deserialize(
os.path.join(remote_job_output_path, "output_estimator"), + os.path.join(remote_job_output_path, model_utils._OUTPUT_ESTIMATOR_DIR), ) if supported_frameworks._is_sklearn(self): @@ -715,7 +752,9 @@ def remote_training(invokable: shared._Invokable): _update_lightning_trainer_inplace(self, estimator) # deserialize and update the trained model as well trained_model = serializer.deserialize( - os.path.join(remote_job_output_path, "output_estimator", "model") + os.path.join( + remote_job_output_path, model_utils._OUTPUT_ESTIMATOR_DIR, "model" + ) ) _update_torch_model_inplace(serialized_args["model"], trained_model) else: @@ -727,7 +766,7 @@ def remote_training(invokable: shared._Invokable): if method_name in supported_frameworks.REMOTE_PREDICTION_OVERRIDE_LIST: predictions = serializer.deserialize( - os.path.join(remote_job_output_path, "output_predictions") + os.path.join(remote_job_output_path, model_utils._OUTPUT_PREDICTIONS_DIR) ) return predictions diff --git a/vertexai/preview/_workflow/executor/training_script.py b/vertexai/preview/_workflow/executor/training_script.py index 981e5d7eb7..76822748ac 100644 --- a/vertexai/preview/_workflow/executor/training_script.py +++ b/vertexai/preview/_workflow/executor/training_script.py @@ -28,6 +28,7 @@ from vertexai.preview._workflow.shared import ( constants, supported_frameworks, + model_utils, ) from vertexai.preview.developer import remote_specs @@ -205,7 +206,7 @@ def main(argv): # for distributed training, chief saves output to specified output # directory while non-chief workers save output to temp directory. 
output_path = remote_specs._get_output_path_for_distributed_training( - _OUTPUT_PATH.value, "output_estimator" + _OUTPUT_PATH.value, model_utils._OUTPUT_ESTIMATOR_DIR ) serializer.serialize(estimator, output_path) @@ -216,7 +217,8 @@ def main(argv): # for remote prediction if _METHOD_NAME.value in supported_frameworks.REMOTE_PREDICTION_OVERRIDE_LIST: serializer.serialize( - output, os.path.join(_OUTPUT_PATH.value, "output_predictions") + output, + os.path.join(_OUTPUT_PATH.value, model_utils._OUTPUT_PREDICTIONS_DIR), ) output_path = remote_specs._get_output_path_for_distributed_training( diff --git a/vertexai/preview/_workflow/launcher/__init__.py b/vertexai/preview/_workflow/launcher/__init__.py index 709dbd017e..4ddf9980d6 100644 --- a/vertexai/preview/_workflow/launcher/__init__.py +++ b/vertexai/preview/_workflow/launcher/__init__.py @@ -24,12 +24,12 @@ class _WorkflowLauncher: """Launches workflows either locally or remotely.""" - def launch(self, invokable: shared._Invokable, global_remote: bool): + def launch(self, invokable: shared._Invokable, global_remote: bool, rewrapper: Any): local_remote = invokable.vertex_config.remote if local_remote or (local_remote is None and global_remote): - result = self._remote_launch(invokable) + result = self._remote_launch(invokable, rewrapper) else: for _, arg in invokable.bound_arguments.arguments.items(): if "bigframes" in repr(type(arg)): @@ -39,8 +39,10 @@ def launch(self, invokable: shared._Invokable, global_remote: bool): result = self._local_launch(invokable) return result - def _remote_launch(self, invokable: shared._Invokable) -> Any: - result = executor._workflow_executor.remote_execute(invokable) + def _remote_launch(self, invokable: shared._Invokable, rewrapper: Any) -> Any: + result = executor._workflow_executor.remote_execute( + invokable, rewrapper=rewrapper + ) # TODO(b/277343861) workflow tracking goes here # E.g., initializer.global_config.workflow.add_remote_step(invokable, result) diff --git 
a/vertexai/preview/_workflow/shared/model_utils.py b/vertexai/preview/_workflow/shared/model_utils.py index f901766262..858c43486d 100644 --- a/vertexai/preview/_workflow/shared/model_utils.py +++ b/vertexai/preview/_workflow/shared/model_utils.py @@ -22,22 +22,107 @@ """ import os -from typing import Any, Union +from typing import Any, Optional, Union from google.cloud import aiplatform +from google.cloud.aiplatform import base from google.cloud.aiplatform import utils +from google.cloud.aiplatform import jobs as aiplatform_jobs import vertexai from vertexai.preview._workflow import driver from vertexai.preview._workflow.serialization_engine import ( any_serializer, serializers_base, ) +from vertexai.preview._workflow.executor import training +from google.cloud.aiplatform.compat.types import job_state as gca_job_state + _SKLEARN_FILE_NAME = "model.pkl" _TF_DIR_NAME = "saved_model" _PYTORCH_FILE_NAME = "model.mar" _REWRAPPER_NAME = "rewrapper" +_CUSTOM_JOB_DIR = "custom_job" +_INPUT_DIR = "input" +_OUTPUT_DIR = "output" +_OUTPUT_ESTIMATOR_DIR = "output_estimator" +_OUTPUT_PREDICTIONS_DIR = "output_predictions" + + +_LOGGER = base.Logger("vertexai.remote_execution") + + +def _get_model_file_from_image_uri(container_image_uri: str) -> str: + """Gets the model file from the container image URI. + + Args: + container_image_uri (str): + The image URI of the container from the training job. + + Returns: + str: + The model file name. + """ + + # sklearn, TF, PyTorch model extensions for retraining. + # PyTorch serv will need model.mar + if "tf" in container_image_uri: + return "" + elif "sklearn" in container_image_uri: + return _SKLEARN_FILE_NAME + elif "pytorch" in container_image_uri: + # Assume the pretrained model will be pulled for uptraining. 
+ return _PYTORCH_FILE_NAME + else: + raise ValueError("Support loading PyTorch, scikit-learn and TensorFlow only.") + + +def _verify_custom_job(job: aiplatform.CustomJob) -> None: + """Verifies the provided CustomJob was created with SDK 2.0. + + Args: + job (aiplatform.CustomJob): + The CustomJob resource + + Raises: + ValueError: If the provided job wasn't created with SDK 2.0. + """ + + if ( + not job.labels.get("trained_by_vertex_ai") + or job.labels.get("trained_by_vertex_ai") != "true" + ): + raise ValueError( + "This job wasn't created with SDK remote training, or it was created with a Vertex SDK version <= 1.32.0" + ) + + +def _generate_remote_job_output_path(base_gcs_dir: str) -> str: + """Generates the GCS output path of the remote training job. + + Args: + base_gcs_dir (str): + The base GCS directory for the remote training job. + """ + return os.path.join(base_gcs_dir, _OUTPUT_DIR) + + +def _get_model_from_successful_custom_job( + job_dir: str, +) -> Union["sklearn.base.BaseEstimator", "tf.Module", "torch.nn.Module"]: # noqa: F821 + + serializer = any_serializer.AnySerializer() + + model = serializer.deserialize( + os.path.join(_generate_remote_job_output_path(job_dir), _OUTPUT_ESTIMATOR_DIR) + ) + rewrapper = serializer.deserialize( + os.path.join(_generate_remote_job_output_path(job_dir), _REWRAPPER_NAME) + ) + rewrapper(model) + return model + def _register_sklearn_model( model: "sklearn.base.BaseEstimator", # noqa: F821 @@ -212,53 +297,86 @@ def register( def from_pretrained( *, - model_name: str, + model_name: Optional[str] = None, + custom_job_name: Optional[str] = None, ) -> Union["sklearn.base.BaseEstimator", "tf.Module", "torch.nn.Module"]: # noqa: F821 - """Pulls a model from Model Registry for retraining. + """Pulls a model from Model Registry or from a CustomJob ID for retraining. + + The returned model is wrapped with a Vertex wrapper for running remote jobs on Vertex, + unless an unwrapped model was registered to Model Registry.
Args: model_name (str): - Required. The resource ID or fully qualified resource name of a registered model. + Optional. The resource ID or fully qualified resource name of a registered model. Format: "12345678910" or - "projects/123/locations/us-central1/models/12345678910@1". + "projects/123/locations/us-central1/models/12345678910@1". One of `model_name` or + `custom_job_name` is required. + custom_job_name (str): + Optional. The resource ID or fully qualified resource name of a CustomJob created + with Vertex SDK remote training. If the job has completed successfully, this will load + the trained model created in the CustomJob. One of `model_name` or + `custom_job_name` is required. Returns: model: local model for uptraining. Raises: - ValueError: If registered model is not registered through `vertexai.preview.register` + ValueError: + If registered model is not registered through `vertexai.preview.register` + If custom job was not created with Vertex SDK remote training + If both or neither model_name or custom_job_name are provided """ + if not model_name and not custom_job_name or (model_name and custom_job_name): + raise ValueError("Exactly one of `model_name` or `custom_job_name` should be provided.") project = vertexai.preview.global_config.project location = vertexai.preview.global_config.location credentials = vertexai.preview.global_config.credentials - vertex_model = aiplatform.Model( - model_name, project=project, location=location, credentials=credentials - ) - if vertex_model.labels.get("registered_by_vertex_ai") != "true": - raise ValueError( - f"The model {model_name} is not registered through `vertexai.preview.register`." + if model_name: + + vertex_model = aiplatform.Model( + model_name, project=project, location=location, credentials=credentials ) - artifact_uri = vertex_model.uri + if vertex_model.labels.get("registered_by_vertex_ai") != "true": + raise ValueError( + f"The model {model_name} is not registered through `vertexai.preview.register`."
+ ) - # sklearn, TF, PyTorch model extensions for retraining. - # PyTorch serv will need model.mar - if "tf" in vertex_model.container_spec.image_uri: - model_file = "" - elif "sklearn" in vertex_model.container_spec.image_uri: - model_file = _SKLEARN_FILE_NAME - elif "pytorch" in vertex_model.container_spec.image_uri: - # Assume the pretrained model will be pulled for uptraining. - model_file = _PYTORCH_FILE_NAME - else: - raise ValueError("Support loading PyTorch, scikit-learn and TensorFlow only.") + artifact_uri = vertex_model.uri + model_file = _get_model_file_from_image_uri( + vertex_model.container_spec.image_uri + ) - serializer = any_serializer.AnySerializer() - model = serializer.deserialize(os.path.join(artifact_uri, model_file)) + serializer = any_serializer.AnySerializer() + model = serializer.deserialize(os.path.join(artifact_uri, model_file)) - rewrapper = serializer.deserialize(os.path.join(artifact_uri, _REWRAPPER_NAME)) + rewrapper = serializer.deserialize(os.path.join(artifact_uri, _REWRAPPER_NAME)) - # Rewrap model (in-place) for following remote training. - rewrapper(model) - return model + # Rewrap model (in-place) for following remote training. + rewrapper(model) + return model + + if custom_job_name: + job = aiplatform.CustomJob.get( + custom_job_name, project=project, location=location, credentials=credentials + ) + job_state = job.state + + _verify_custom_job(job) + job_dir = job.job_spec.base_output_directory.output_uri_prefix + + if job_state in aiplatform_jobs._JOB_PENDING_STATES: + _LOGGER.info( + f"The CustomJob {job.name} is still running. When the job has completed successfully, your model will be returned." + ) + training._get_remote_logs_until_complete(job) + # Get the new job state after it has completed + job_state = job.state + + if job_state == gca_job_state.JobState.JOB_STATE_SUCCEEDED: + return _get_model_from_successful_custom_job(job_dir) + else: + raise ValueError( + "The provided job did not complete successfully. Please provide
a pending or successful customJob ID." + )