diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index 275f654683..c1ba9739cc 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -173,6 +173,17 @@ def _dashboard_uri(self) -> Optional[str]: url = f"https://console.cloud.google.com/ai/platform/locations/{fields.location}/{self._job_type}/{fields.id}?project={fields.project}" return url + def _log_job_state(self): + """Helper method to log job state.""" + _LOGGER.info( + "%s %s current state:\n%s" + % ( + self.__class__.__name__, + self._gca_resource.name, + self._gca_resource.state, + ) + ) + def _block_until_complete(self): """Helper method to block and check on job until complete. @@ -190,26 +201,13 @@ def _block_until_complete(self): while self.state not in _JOB_COMPLETE_STATES: current_time = time.time() if current_time - previous_time >= log_wait: - _LOGGER.info( - "%s %s current state:\n%s" - % ( - self.__class__.__name__, - self._gca_resource.name, - self._gca_resource.state, - ) - ) + self._log_job_state() log_wait = min(log_wait * multiplier, max_wait) previous_time = current_time time.sleep(wait) - _LOGGER.info( - "%s %s current state:\n%s" - % ( - self.__class__.__name__, - self._gca_resource.name, - self._gca_resource.state, - ) - ) + self._log_job_state() + # Error is only populated when the job state is # JOB_STATE_FAILED or JOB_STATE_CANCELLED. if self._gca_resource.state in _JOB_ERROR_STATES: @@ -845,6 +843,63 @@ def __init__( project=project, location=location ) + self._logged_web_access_uris = set() + + @property + def web_access_uris(self) -> Dict[str, Union[str, Dict[str, str]]]: + """Fetch the runnable job again and return the latest web access uris. + + Returns: + (Dict[str, Union[str, Dict[str, str]]]): + Web access uris of the runnable job. + """ + + # Fetch the Job again for most up-to-date web access uris + self._sync_gca_resource() + return self._get_web_access_uris() + + @abc.abstractmethod + def _get_web_access_uris(self): + """Helper method to get the web access uris of the runnable job""" + pass + + @abc.abstractmethod + def _log_web_access_uris(self): + """Helper method to log the web access uris of the runnable job""" + pass + + def _block_until_complete(self): + """Helper method to block and check on runnable job until complete. + + Raises: + RuntimeError: If job failed or cancelled. + """ + + # Used these numbers so failures surface fast + wait = 5 # start at five seconds + log_wait = 5 + max_wait = 60 * 5 # 5 minute wait + multiplier = 2 # scale wait by 2 every iteration + + previous_time = time.time() + while self.state not in _JOB_COMPLETE_STATES: + current_time = time.time() + if current_time - previous_time >= log_wait: + self._log_job_state() + log_wait = min(log_wait * multiplier, max_wait) + previous_time = current_time + self._log_web_access_uris() + time.sleep(wait) + + self._log_job_state() + + # Error is only populated when the job state is + # JOB_STATE_FAILED or JOB_STATE_CANCELLED. + if self._gca_resource.state in _JOB_ERROR_STATES: + raise RuntimeError("Job failed with:\n%s" % self._gca_resource.error) + else: + _LOGGER.log_action_completed_against_resource("run", "completed", self) + @abc.abstractmethod def run(self) -> None: pass @@ -1046,6 +1101,26 @@ def network(self) -> Optional[str]: self._assert_gca_resource_is_available() return self._gca_resource.job_spec.network + def _get_web_access_uris(self) -> Dict[str, str]: + """Helper method to get the web access uris of the custom job + + Returns: + (Dict[str, str]): + Web access uris of the custom job. + """ + return dict(self._gca_resource.web_access_uris) + + def _log_web_access_uris(self): + """Helper method to log the web access uris of the custom job""" + + for worker, uri in self._get_web_access_uris().items(): + if uri not in self._logged_web_access_uris: + _LOGGER.info( + "%s %s access the interactive shell terminals for the custom job:\n%s:\n%s" + % (self.__class__.__name__, self._gca_resource.name, worker, uri,), + ) + self._logged_web_access_uris.add(uri) + @classmethod def from_local_script( cls, @@ -1250,6 +1325,7 @@ def run( network: Optional[str] = None, timeout: Optional[int] = None, restart_job_on_worker_restart: bool = False, + enable_web_access: bool = False, tensorboard: Optional[str] = None, sync: bool = True, ) -> None: @@ -1271,6 +1347,10 @@ def run( gets restarted. This feature can be used by distributed training jobs that are not resilient to workers leaving and joining a job. + enable_web_access (bool): + Whether you want Vertex AI to enable interactive shell access + to training containers. + https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell tensorboard (str): Optional. The name of a Vertex AI [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard] @@ -1304,6 +1384,9 @@ def run( restart_job_on_worker_restart=restart_job_on_worker_restart, ) + if enable_web_access: + self._gca_resource.job_spec.enable_web_access = enable_web_access + if tensorboard: v1beta1_gca_resource = gca_custom_job_v1beta1.CustomJob() v1beta1_gca_resource._pb.MergeFromString( @@ -1588,6 +1671,38 @@ def network(self) -> Optional[str]: self._assert_gca_resource_is_available() return getattr(self._gca_resource.trial_job_spec, "network") + def _get_web_access_uris(self) -> Dict[str, Dict[str, str]]: + """Helper method to get the web access uris of the hyperparameter job + + Returns: + (Dict[str, Dict[str, str]]): + Web access uris of the hyperparameter job. + """ + web_access_uris = dict() + for trial in self.trials: + web_access_uris[trial.id] = web_access_uris.get(trial.id, dict()) + for worker, uri in trial.web_access_uris.items(): + web_access_uris[trial.id][worker] = uri + return web_access_uris + + def _log_web_access_uris(self): + """Helper method to log the web access uris of the hyperparameter job""" + + for trial_id, trial_web_access_uris in self._get_web_access_uris().items(): + for worker, uri in trial_web_access_uris.items(): + if uri not in self._logged_web_access_uris: + _LOGGER.info( + "%s %s access the interactive shell terminals for trial - %s:\n%s:\n%s" + % ( + self.__class__.__name__, + self._gca_resource.name, + trial_id, + worker, + uri, + ), + ) + self._logged_web_access_uris.add(uri) + @base.optional_sync() def run( self, @@ -1595,6 +1710,7 @@ def run( network: Optional[str] = None, timeout: Optional[int] = None, # seconds restart_job_on_worker_restart: bool = False, + enable_web_access: bool = False, tensorboard: Optional[str] = None, sync: bool = True, ) -> None: @@ -1616,6 +1732,10 @@ def run( gets restarted. This feature can be used by distributed training jobs that are not resilient to workers leaving and joining a job. + enable_web_access (bool): + Whether you want Vertex AI to enable interactive shell access + to training containers. + https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell tensorboard (str): Optional. The name of a Vertex AI [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard] @@ -1649,6 +1769,9 @@ def run( restart_job_on_worker_restart=restart_job_on_worker_restart, ) + if enable_web_access: + self._gca_resource.trial_job_spec.enable_web_access = enable_web_access + if tensorboard: v1beta1_gca_resource = ( gca_hyperparameter_tuning_job_v1beta1.HyperparameterTuningJob() diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index eb7680946b..d750d29b3a 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -27,6 +27,7 @@ from google.cloud.aiplatform import datasets from google.cloud.aiplatform import initializer from google.cloud.aiplatform import models +from google.cloud.aiplatform import jobs from google.cloud.aiplatform import schema from google.cloud.aiplatform import utils from google.cloud.aiplatform.utils import console_utils @@ -1251,6 +1252,7 @@ def __init__( # once Custom Job is known we log the console uri and the tensorboard uri # this flags keeps that state so we don't log it multiple times self._has_logged_custom_job = False + self._logged_web_access_uris = set() @property def network(self) -> Optional[str]: @@ -1382,6 +1384,7 @@ def _prepare_training_task_inputs_and_output_dir( base_output_dir: Optional[str] = None, service_account: Optional[str] = None, network: Optional[str] = None, + enable_web_access: bool = False, tensorboard: Optional[str] = None, ) -> Tuple[Dict, str]: """Prepares training task inputs and output directory for custom job. @@ -1400,6 +1403,10 @@ def _prepare_training_task_inputs_and_output_dir( should be peered. For example, projects/12345/global/networks/myVPC. Private services access must already be configured for the network. If left unspecified, the job is not peered with any network. + enable_web_access (bool): + Whether you want Vertex AI to enable interactive shell access + to training containers. + https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell tensorboard (str): Optional. The name of a Vertex AI [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard] @@ -1437,9 +1444,43 @@ def _prepare_training_task_inputs_and_output_dir( training_task_inputs["network"] = network if tensorboard: training_task_inputs["tensorboard"] = tensorboard + if enable_web_access: + training_task_inputs["enable_web_access"] = enable_web_access return training_task_inputs, base_output_dir + @property + def web_access_uris(self) -> Dict[str, str]: + """Get the web access uris of the backing custom job. + + Returns: + (Dict[str, str]): + Web access uris of the backing custom job. + """ + web_access_uris = dict() + if ( + self._gca_resource.training_task_metadata + and self._gca_resource.training_task_metadata.get("backingCustomJob") + ): + custom_job_resource_name = self._gca_resource.training_task_metadata.get( + "backingCustomJob" + ) + custom_job = jobs.CustomJob.get(resource_name=custom_job_resource_name) + + web_access_uris = dict(custom_job.web_access_uris) + + return web_access_uris + + def _log_web_access_uris(self): + """Helper method to log the web access uris of the backing custom job""" + for worker, uri in self.web_access_uris.items(): + if uri not in self._logged_web_access_uris: + _LOGGER.info( + "%s %s access the interactive shell terminals for the backing custom job:\n%s:\n%s" + % (self.__class__.__name__, self._gca_resource.name, worker, uri,), + ) + self._logged_web_access_uris.add(uri) + def _wait_callback(self): if ( self._gca_resource.training_task_metadata @@ -1453,6 +1494,9 @@ def _wait_callback(self): self._has_logged_custom_job = True + if self._gca_resource.training_task_inputs.get("enable_web_access"): + self._log_web_access_uris() + def _custom_job_console_uri(self) -> str: """Helper method to compose the dashboard uri where custom job can be viewed.""" custom_job_resource_name = self._gca_resource.training_task_metadata.get( @@ -1755,6 +1799,7 @@ def run( test_filter_split: Optional[str] = None, predefined_split_column_name: Optional[str] = None, timestamp_split_column_name: Optional[str] = None, + enable_web_access: bool = False, tensorboard: Optional[str] = None, sync=True, ) -> Optional[models.Model]: @@ -1974,6 +2019,10 @@ def run( that piece is ignored by the pipeline. Supported only for tabular and time series Datasets. + enable_web_access (bool): + Whether you want Vertex AI to enable interactive shell access + to training containers. + https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell tensorboard (str): Optional. The name of a Vertex AI [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard] @@ -2036,6 +2085,7 @@ def run( test_filter_split=test_filter_split, predefined_split_column_name=predefined_split_column_name, timestamp_split_column_name=timestamp_split_column_name, + enable_web_access=enable_web_access, tensorboard=tensorboard, reduction_server_container_uri=reduction_server_container_uri if reduction_server_replica_count > 0 @@ -2072,6 +2122,7 @@ def _run( test_filter_split: Optional[str] = None, predefined_split_column_name: Optional[str] = None, timestamp_split_column_name: Optional[str] = None, + enable_web_access: bool = False, tensorboard: Optional[str] = None, reduction_server_container_uri: Optional[str] = None, sync=True, @@ -2191,6 +2242,10 @@ def _run( that piece is ignored by the pipeline. Supported only for tabular and time series Datasets. + enable_web_access (bool): + Whether you want Vertex AI to enable interactive shell access + to training containers. + https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell tensorboard (str): Optional. The name of a Vertex AI [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard] @@ -2259,6 +2314,7 @@ def _run( base_output_dir=base_output_dir, service_account=service_account, network=network, + enable_web_access=enable_web_access, tensorboard=tensorboard, ) @@ -2547,6 +2603,7 @@ def run( test_filter_split: Optional[str] = None, predefined_split_column_name: Optional[str] = None, timestamp_split_column_name: Optional[str] = None, + enable_web_access: bool = False, tensorboard: Optional[str] = None, sync=True, ) -> Optional[models.Model]: @@ -2759,6 +2816,10 @@ def run( that piece is ignored by the pipeline. Supported only for tabular and time series Datasets. + enable_web_access (bool): + Whether you want Vertex AI to enable interactive shell access + to training containers. + https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell tensorboard (str): Optional. The name of a Vertex AI [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard] @@ -2820,6 +2881,7 @@ def run( test_filter_split=test_filter_split, predefined_split_column_name=predefined_split_column_name, timestamp_split_column_name=timestamp_split_column_name, + enable_web_access=enable_web_access, tensorboard=tensorboard, reduction_server_container_uri=reduction_server_container_uri if reduction_server_replica_count > 0 @@ -2855,6 +2917,7 @@ def _run( test_filter_split: Optional[str] = None, predefined_split_column_name: Optional[str] = None, timestamp_split_column_name: Optional[str] = None, + enable_web_access: bool = False, tensorboard: Optional[str] = None, reduction_server_container_uri: Optional[str] = None, sync=True, @@ -2970,6 +3033,10 @@ def _run( that piece is ignored by the pipeline. Supported only for tabular and time series Datasets. + enable_web_access (bool): + Whether you want Vertex AI to enable interactive shell access + to training containers. + https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell tensorboard (str): Optional. The name of a Vertex AI [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard] @@ -3032,6 +3099,7 @@ def _run( base_output_dir=base_output_dir, service_account=service_account, network=network, + enable_web_access=enable_web_access, tensorboard=tensorboard, ) @@ -5310,6 +5378,7 @@ def run( test_filter_split: Optional[str] = None, predefined_split_column_name: Optional[str] = None, timestamp_split_column_name: Optional[str] = None, + enable_web_access: bool = False, tensorboard: Optional[str] = None, sync=True, ) -> Optional[models.Model]: @@ -5522,6 +5591,10 @@ def run( that piece is ignored by the pipeline. Supported only for tabular and time series Datasets. + enable_web_access (bool): + Whether you want Vertex AI to enable interactive shell access + to training containers. + https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell tensorboard (str): Optional. The name of a Vertex AI [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard] @@ -5578,6 +5651,7 @@ def run( predefined_split_column_name=predefined_split_column_name, timestamp_split_column_name=timestamp_split_column_name, bigquery_destination=bigquery_destination, + enable_web_access=enable_web_access, tensorboard=tensorboard, reduction_server_container_uri=reduction_server_container_uri if reduction_server_replica_count > 0 @@ -5613,6 +5687,7 @@ def _run( predefined_split_column_name: Optional[str] = None, timestamp_split_column_name: Optional[str] = None, bigquery_destination: Optional[str] = None, + enable_web_access: bool = False, tensorboard: Optional[str] = None, reduction_server_container_uri: Optional[str] = None, sync=True, @@ -5715,6 +5790,10 @@ def _run( that piece is ignored by the pipeline. Supported only for tabular and time series Datasets. + enable_web_access (bool): + Whether you want Vertex AI to enable interactive shell access + to training containers. + https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell tensorboard (str): Optional. The name of a Vertex AI [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard] @@ -5777,6 +5856,7 @@ def _run( base_output_dir=base_output_dir, service_account=service_account, network=network, + enable_web_access=enable_web_access, tensorboard=tensorboard, ) diff --git a/tests/system/aiplatform/test_e2e_tabular.py b/tests/system/aiplatform/test_e2e_tabular.py index 651c737555..5e7b416a8c 100644 --- a/tests/system/aiplatform/test_e2e_tabular.py +++ b/tests/system/aiplatform/test_e2e_tabular.py @@ -109,6 +109,7 @@ def test_end_to_end_tabular(self, shared_state): ds, replica_count=1, model_display_name=self._make_display_name("custom-housing-model"), + enable_web_access=True, sync=False, ) diff --git a/tests/unit/aiplatform/test_custom_job.py b/tests/unit/aiplatform/test_custom_job.py index 040ce3e69a..5fc3243451 100644 --- a/tests/unit/aiplatform/test_custom_job.py +++ b/tests/unit/aiplatform/test_custom_job.py @@ -15,6 +15,7 @@ # import pytest +import logging import copy from importlib import reload @@ -51,7 +52,8 @@ _TEST_CUSTOM_JOB_NAME = f"{_TEST_PARENT}/customJobs/{_TEST_ID}" _TEST_TENSORBOARD_NAME = f"{_TEST_PARENT}/tensorboards/{_TEST_ID}" - +_TEST_ENABLE_WEB_ACCESS = True +_TEST_WEB_ACCESS_URIS = {"workerpool0-0": "uri"} _TEST_TRAINING_CONTAINER_IMAGE = "gcr.io/test-training/container:image" _TEST_RUN_ARGS = ["-v", "0.1", "--test=arg"] @@ -128,6 +130,18 @@ def _get_custom_job_proto(state=None, name=None, error=None, version="v1"): return custom_job_proto +def _get_custom_job_proto_with_enable_web_access( + state=None, name=None, error=None, version="v1" +): + custom_job_proto = _get_custom_job_proto( + state=state, name=name, error=error, version=version + ) + custom_job_proto.job_spec.enable_web_access = _TEST_ENABLE_WEB_ACCESS + if state == gca_job_state_compat.JobState.JOB_STATE_RUNNING: + custom_job_proto.web_access_uris = _TEST_WEB_ACCESS_URIS + return custom_job_proto + + @pytest.fixture def get_custom_job_mock(): with patch.object( @@ -178,6 +192,48 @@ def get_custom_job_mock_with_fail(): yield get_custom_job_mock +@pytest.fixture +def get_custom_job_mock_with_enable_web_access(): + with patch.object( + job_service_client.JobServiceClient, "get_custom_job" + ) as get_custom_job_mock: + get_custom_job_mock.side_effect = [ + _get_custom_job_proto_with_enable_web_access( + name=_TEST_CUSTOM_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_PENDING, + ), + _get_custom_job_proto_with_enable_web_access( + name=_TEST_CUSTOM_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_RUNNING, + ), + _get_custom_job_proto_with_enable_web_access( + name=_TEST_CUSTOM_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_RUNNING, + ), + _get_custom_job_proto_with_enable_web_access( + name=_TEST_CUSTOM_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_RUNNING, + ), + _get_custom_job_proto_with_enable_web_access( + name=_TEST_CUSTOM_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED, + ), + ] + yield get_custom_job_mock + + +@pytest.fixture +def get_custom_job_mock_with_enable_web_access_succeeded(): + with mock.patch.object( + job_service_client.JobServiceClient, "get_custom_job" + ) as get_custom_job_mock: + get_custom_job_mock.return_value = _get_custom_job_proto_with_enable_web_access( + name=_TEST_CUSTOM_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED, + ) + yield get_custom_job_mock + + @pytest.fixture def create_custom_job_mock(): with mock.patch.object( @@ -190,6 +246,18 @@ def create_custom_job_mock(): yield create_custom_job_mock +@pytest.fixture +def create_custom_job_mock_with_enable_web_access(): + with mock.patch.object( + job_service_client.JobServiceClient, "create_custom_job" + ) as create_custom_job_mock: + create_custom_job_mock.return_value = _get_custom_job_proto_with_enable_web_access( + name=_TEST_CUSTOM_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_PENDING, + ) + yield create_custom_job_mock + + @pytest.fixture def create_custom_job_mock_fail(): with mock.patch.object( @@ -434,6 +502,72 @@ def test_create_from_local_script_raises_with_no_staging_bucket( container_uri=_TEST_TRAINING_CONTAINER_IMAGE, ) + @pytest.mark.parametrize("sync", [True, False]) + def test_create_custom_job_with_enable_web_access( + self, + create_custom_job_mock_with_enable_web_access, + get_custom_job_mock_with_enable_web_access, + sync, + caplog, + ): + caplog.set_level(logging.INFO) + + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_STAGING_BUCKET, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = aiplatform.CustomJob( + display_name=_TEST_DISPLAY_NAME, + worker_pool_specs=_TEST_WORKER_POOL_SPEC, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + labels=_TEST_LABELS, + ) + + job.run( + enable_web_access=_TEST_ENABLE_WEB_ACCESS, + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + timeout=_TEST_TIMEOUT, + restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, + sync=sync, + ) + + job.wait_for_resource_creation() + + job.wait() + + assert "workerpool0-0" in caplog.text + + assert job.resource_name == _TEST_CUSTOM_JOB_NAME + + expected_custom_job = _get_custom_job_proto_with_enable_web_access() + + create_custom_job_mock_with_enable_web_access.assert_called_once_with( + parent=_TEST_PARENT, custom_job=expected_custom_job + ) + + assert job.job_spec == expected_custom_job.job_spec + assert ( + job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED + ) + caplog.clear() + + def test_get_web_access_uris(self, get_custom_job_mock_with_enable_web_access): + job = aiplatform.CustomJob.get(_TEST_CUSTOM_JOB_NAME) + while True: + if job.web_access_uris: + assert job.web_access_uris == _TEST_WEB_ACCESS_URIS + break + + def test_get_web_access_uris_job_succeeded( + self, get_custom_job_mock_with_enable_web_access_succeeded + ): + job = aiplatform.CustomJob.get(_TEST_CUSTOM_JOB_NAME) + assert not job.web_access_uris + @pytest.mark.parametrize("sync", [True, False]) def test_create_custom_job_with_tensorboard( self, create_custom_job_v1beta1_mock, get_custom_job_mock, sync diff --git a/tests/unit/aiplatform/test_hyperparameter_tuning_job.py b/tests/unit/aiplatform/test_hyperparameter_tuning_job.py index d82071db4f..1ac9630882 100644 --- a/tests/unit/aiplatform/test_hyperparameter_tuning_job.py +++ b/tests/unit/aiplatform/test_hyperparameter_tuning_job.py @@ -21,19 +21,18 @@ from unittest import mock from unittest.mock import patch +import logging from google.rpc import status_pb2 from google.cloud import aiplatform from google.cloud.aiplatform import hyperparameter_tuning as hpt -from google.cloud.aiplatform.compat.types import job_state as gca_job_state_compat from google.cloud.aiplatform.compat.types import ( encryption_spec as gca_encryption_spec_compat, -) -from google.cloud.aiplatform.compat.types import ( hyperparameter_tuning_job as gca_hyperparameter_tuning_job_compat, hyperparameter_tuning_job_v1beta1 as gca_hyperparameter_tuning_job_v1beta1, + job_state as gca_job_state_compat, + study as gca_study_compat, ) -from google.cloud.aiplatform.compat.types import study as gca_study_compat from google.cloud.aiplatform_v1.services.job_service import client as job_service_client from google.cloud.aiplatform_v1beta1.services.job_service import ( client as job_service_client_v1beta1, @@ -128,6 +127,8 @@ encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, ) +_TEST_BASE_TRIAL_PROTO = gca_study_compat.Trial() + def _get_hyperparameter_tuning_job_proto( state=None, name=None, error=None, version="v1" @@ -154,6 +155,29 @@ def _get_hyperparameter_tuning_job_proto( return hyperparameter_tuning_job_proto +def _get_trial_proto(id=None, state=None): + trial_proto = copy.deepcopy(_TEST_BASE_TRIAL_PROTO) + trial_proto.id = id + trial_proto.state = state + if state == gca_study_compat.Trial.State.ACTIVE: + trial_proto.web_access_uris = test_custom_job._TEST_WEB_ACCESS_URIS + return trial_proto + + +def _get_hyperparameter_tuning_job_proto_with_enable_web_access( + state=None, name=None, error=None, version="v1", trials=[] +): + hyperparameter_tuning_job_proto = _get_hyperparameter_tuning_job_proto( + state=state, name=name, error=error, version=version + ) + hyperparameter_tuning_job_proto.trial_job_spec.enable_web_access = ( + test_custom_job._TEST_ENABLE_WEB_ACCESS + ) + if state == gca_job_state_compat.JobState.JOB_STATE_RUNNING: + hyperparameter_tuning_job_proto.trials = trials + return hyperparameter_tuning_job_proto + + @pytest.fixture def get_hyperparameter_tuning_job_mock(): with patch.object( @@ -180,6 +204,86 @@ def get_hyperparameter_tuning_job_mock(): yield get_hyperparameter_tuning_job_mock +@pytest.fixture +def get_hyperparameter_tuning_job_mock_with_enable_web_access(): + with patch.object( + job_service_client.JobServiceClient, "get_hyperparameter_tuning_job" + ) as get_hyperparameter_tuning_job_mock: + get_hyperparameter_tuning_job_mock.side_effect = [ + _get_hyperparameter_tuning_job_proto_with_enable_web_access( + name=_TEST_HYPERPARAMETERTUNING_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_PENDING, + ), + _get_hyperparameter_tuning_job_proto_with_enable_web_access( + name=_TEST_HYPERPARAMETERTUNING_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_RUNNING, + trials=[ + _get_trial_proto( + id="1", state=gca_study_compat.Trial.State.REQUESTED + ), + ], + ), + _get_hyperparameter_tuning_job_proto_with_enable_web_access( + name=_TEST_HYPERPARAMETERTUNING_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_RUNNING, + trials=[ + _get_trial_proto(id="1", state=gca_study_compat.Trial.State.ACTIVE), + ], + ), + _get_hyperparameter_tuning_job_proto_with_enable_web_access( + name=_TEST_HYPERPARAMETERTUNING_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_RUNNING, + trials=[ + _get_trial_proto(id="1", state=gca_study_compat.Trial.State.ACTIVE), + ], + ), + _get_hyperparameter_tuning_job_proto_with_enable_web_access( + name=_TEST_HYPERPARAMETERTUNING_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_RUNNING, + trials=[ + _get_trial_proto(id="1", state=gca_study_compat.Trial.State.ACTIVE), + ], + ), + _get_hyperparameter_tuning_job_proto_with_enable_web_access( + name=_TEST_HYPERPARAMETERTUNING_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_RUNNING, + trials=[ + _get_trial_proto( + id="1", state=gca_study_compat.Trial.State.SUCCEEDED + ), + ], + ), + _get_hyperparameter_tuning_job_proto_with_enable_web_access( + name=_TEST_HYPERPARAMETERTUNING_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED, + trials=[ + _get_trial_proto( + id="1", state=gca_study_compat.Trial.State.SUCCEEDED + ), + ], + ), + _get_hyperparameter_tuning_job_proto_with_enable_web_access( + name=_TEST_HYPERPARAMETERTUNING_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED, + trials=[ + _get_trial_proto( + id="1", state=gca_study_compat.Trial.State.SUCCEEDED + ), + ], + ), + _get_hyperparameter_tuning_job_proto_with_enable_web_access( + name=_TEST_HYPERPARAMETERTUNING_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED, + trials=[ + _get_trial_proto( + id="1", state=gca_study_compat.Trial.State.SUCCEEDED + ), + ], + ), + ] + yield get_hyperparameter_tuning_job_mock + + @pytest.fixture def get_hyperparameter_tuning_job_mock_with_fail(): with patch.object( @@ -215,6 +319,18 @@ def create_hyperparameter_tuning_job_mock(): yield create_hyperparameter_tuning_job_mock +@pytest.fixture +def create_hyperparameter_tuning_job_mock_with_enable_web_access(): + with mock.patch.object( + job_service_client.JobServiceClient, "create_hyperparameter_tuning_job" + ) as create_hyperparameter_tuning_job_mock: + create_hyperparameter_tuning_job_mock.return_value = _get_hyperparameter_tuning_job_proto_with_enable_web_access( + name=_TEST_HYPERPARAMETERTUNING_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_PENDING, + ) + yield create_hyperparameter_tuning_job_mock + + @pytest.fixture def create_hyperparameter_tuning_job_mock_fail(): with mock.patch.object( @@ -554,3 +670,76 @@ def test_create_hyperparameter_tuning_job_with_tensorboard( assert ( job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_hyperparameter_tuning_job_with_enable_web_access( + self, + create_hyperparameter_tuning_job_mock_with_enable_web_access, + get_hyperparameter_tuning_job_mock_with_enable_web_access, + sync, + caplog, + ): + caplog.set_level(logging.INFO) + + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_STAGING_BUCKET, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + custom_job = aiplatform.CustomJob( + display_name=test_custom_job._TEST_DISPLAY_NAME, + worker_pool_specs=test_custom_job._TEST_WORKER_POOL_SPEC, + base_output_dir=test_custom_job._TEST_BASE_OUTPUT_DIR, + ) + + job = aiplatform.HyperparameterTuningJob( + display_name=_TEST_DISPLAY_NAME, + custom_job=custom_job, + metric_spec={_TEST_METRIC_SPEC_KEY: _TEST_METRIC_SPEC_VALUE}, + parameter_spec={ + "lr": hpt.DoubleParameterSpec(min=0.001, max=0.1, scale="log"), + "units": hpt.IntegerParameterSpec(min=4, max=1028, scale="linear"), + "activation": hpt.CategoricalParameterSpec( + values=["relu", "sigmoid", "elu", "selu", "tanh"] + ), + "batch_size": hpt.DiscreteParameterSpec( + values=[16, 32], scale="linear" + ), + }, + parallel_trial_count=_TEST_PARALLEL_TRIAL_COUNT, + max_trial_count=_TEST_MAX_TRIAL_COUNT, + max_failed_trial_count=_TEST_MAX_FAILED_TRIAL_COUNT, + search_algorithm=_TEST_SEARCH_ALGORITHM, + measurement_selection=_TEST_MEASUREMENT_SELECTION, + labels=_TEST_LABELS, + ) + + job.run( + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + timeout=_TEST_TIMEOUT, + restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, + enable_web_access=test_custom_job._TEST_ENABLE_WEB_ACCESS, + sync=sync, + ) + + job.wait() + + assert "workerpool0-0" in caplog.text + + expected_hyperparameter_tuning_job = ( + _get_hyperparameter_tuning_job_proto_with_enable_web_access() + ) + + create_hyperparameter_tuning_job_mock_with_enable_web_access.assert_called_once_with( + parent=_TEST_PARENT, + hyperparameter_tuning_job=expected_hyperparameter_tuning_job, + ) + + assert job.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED + assert job.network == _TEST_NETWORK + assert job.trials == [] + + caplog.clear() diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py index 34b7b63617..cf4b4e7ca8 100644 --- a/tests/unit/aiplatform/test_training_jobs.py +++ b/tests/unit/aiplatform/test_training_jobs.py @@ -16,8 +16,10 @@ # from distutils import core +import copy import functools import importlib +import logging import pathlib import pytest import subprocess @@ -41,6 +43,7 @@ from google.cloud.aiplatform import schema from google.cloud.aiplatform import training_jobs +from google.cloud.aiplatform_v1.services.job_service import client as job_service_client from google.cloud.aiplatform_v1.services.model_service import ( client as model_service_client, ) @@ -49,10 +52,12 @@ ) from google.cloud.aiplatform_v1.types import ( + custom_job as gca_custom_job, dataset as gca_dataset, encryption_spec as gca_encryption_spec, env_var as gca_env_var, io as gca_io, + job_state as gca_job_state, model as gca_model, pipeline_state as gca_pipeline_state, training_pipeline as gca_training_pipeline, @@ -174,6 +179,23 @@ _TEST_MODEL_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( kms_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME ) +_TEST_ENABLE_WEB_ACCESS = True +_TEST_WEB_ACCESS_URIS = {"workerpool0-0": "uri"} + +_TEST_BASE_CUSTOM_JOB_PROTO = gca_custom_job.CustomJob( + job_spec=gca_custom_job.CustomJobSpec(), +) + + +def _get_custom_job_proto_with_enable_web_access(state=None, name=None, version="v1"): + custom_job_proto = copy.deepcopy(_TEST_BASE_CUSTOM_JOB_PROTO) + custom_job_proto.name = name + custom_job_proto.state = state + + custom_job_proto.job_spec.enable_web_access = _TEST_ENABLE_WEB_ACCESS + if state == gca_job_state.JobState.JOB_STATE_RUNNING: + custom_job_proto.web_access_uris = _TEST_WEB_ACCESS_URIS + return custom_job_proto def local_copy_method(path): @@ -246,6 +268,40 @@ def blob_side_effect(name, mock_blob, bucket): yield mock_client_bucket, MockBlob +@pytest.fixture +def mock_get_backing_custom_job_with_enable_web_access(): + with patch.object( + job_service_client.JobServiceClient, "get_custom_job" + ) as get_custom_job_mock: + get_custom_job_mock.side_effect = [ + _get_custom_job_proto_with_enable_web_access( + name=_TEST_CUSTOM_JOB_RESOURCE_NAME, + state=gca_job_state.JobState.JOB_STATE_PENDING, + ), + _get_custom_job_proto_with_enable_web_access( + name=_TEST_CUSTOM_JOB_RESOURCE_NAME, + state=gca_job_state.JobState.JOB_STATE_RUNNING, + ), + _get_custom_job_proto_with_enable_web_access( + name=_TEST_CUSTOM_JOB_RESOURCE_NAME, + state=gca_job_state.JobState.JOB_STATE_RUNNING, + ), + _get_custom_job_proto_with_enable_web_access( + name=_TEST_CUSTOM_JOB_RESOURCE_NAME, + state=gca_job_state.JobState.JOB_STATE_RUNNING, + ), + _get_custom_job_proto_with_enable_web_access( + name=_TEST_CUSTOM_JOB_RESOURCE_NAME, + state=gca_job_state.JobState.JOB_STATE_SUCCEEDED, + ), + _get_custom_job_proto_with_enable_web_access( + name=_TEST_CUSTOM_JOB_RESOURCE_NAME, + state=gca_job_state.JobState.JOB_STATE_SUCCEEDED, + ), + ] + yield get_custom_job_mock + + class TestTrainingScriptPythonPackagerHelpers: def setup_method(self): importlib.reload(initializer) @@ -475,6 +531,19 @@ def make_training_pipeline_with_no_model_upload(state): ) +def make_training_pipeline_with_enable_web_access(state): + training_pipeline = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=state, + training_task_inputs={"enable_web_access": _TEST_ENABLE_WEB_ACCESS}, + ) + if state == gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING: + training_pipeline.training_task_metadata = { + "backingCustomJob": _TEST_CUSTOM_JOB_RESOURCE_NAME + } + return training_pipeline + + @pytest.fixture def mock_pipeline_service_get(): with mock.patch.object( @@ -517,6 +586,35 @@ def mock_pipeline_service_get(): yield mock_get_training_pipeline +@pytest.fixture +def mock_pipeline_service_get_with_enable_web_access(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.side_effect = [ + make_training_pipeline_with_enable_web_access( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_PENDING, + ), + make_training_pipeline_with_enable_web_access( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, + ), + make_training_pipeline_with_enable_web_access( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, + ), + make_training_pipeline_with_enable_web_access( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, + ), + make_training_pipeline_with_enable_web_access( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + ), + make_training_pipeline_with_enable_web_access( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + ), + ] + + yield mock_get_training_pipeline + + @pytest.fixture def mock_pipeline_service_cancel(): with mock.patch.object( @@ -537,6 +635,17 @@ def mock_pipeline_service_create_with_no_model_to_upload(): yield mock_create_training_pipeline +@pytest.fixture +def mock_pipeline_service_create_with_enable_web_access(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = make_training_pipeline_with_enable_web_access( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_PENDING, + ) + yield mock_create_training_pipeline + + @pytest.fixture def mock_pipeline_service_get_with_no_model_to_upload(): with mock.patch.object( @@ -1234,6 +1343,50 @@ def test_run_call_pipeline_service_create_with_no_dataset( assert model_from_job._gca_resource is mock_model_service_get.return_value + @pytest.mark.usefixtures( + "mock_pipeline_service_create_with_enable_web_access", + "mock_pipeline_service_get_with_enable_web_access", + "mock_get_backing_custom_job_with_enable_web_access", + "mock_python_package_to_gcs", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_enable_web_access( + self, sync, caplog + ): + + caplog.set_level(logging.INFO) + + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + credentials=_TEST_CREDENTIALS, + ) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + job.run( + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + enable_web_access=_TEST_ENABLE_WEB_ACCESS, + sync=sync, + ) + + if not sync: + job.wait() + + print(caplog.text) + assert "workerpool0-0" in caplog.text + assert job._gca_resource == make_training_pipeline_with_enable_web_access( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ) + @pytest.mark.usefixtures( "mock_pipeline_service_create_with_no_model_to_upload", "mock_pipeline_service_get_with_no_model_to_upload", @@ -2577,6 +2730,49 @@ def test_run_call_pipeline_service_create_with_no_dataset( assert model_from_job._gca_resource is mock_model_service_get.return_value + @pytest.mark.usefixtures( + "mock_pipeline_service_create_with_enable_web_access", + "mock_pipeline_service_get_with_enable_web_access", + "mock_get_backing_custom_job_with_enable_web_access", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_enable_web_access( + self, sync, caplog + ): + + caplog.set_level(logging.INFO) + + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + credentials=_TEST_CREDENTIALS, + ) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + ) + + job.run( + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + enable_web_access=_TEST_ENABLE_WEB_ACCESS, + sync=sync, + ) + + if not sync: + job.wait() + + print(caplog.text) + assert "workerpool0-0" in caplog.text + assert job._gca_resource == make_training_pipeline_with_enable_web_access( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ) + @pytest.mark.parametrize("sync", [True, False]) def test_run_returns_none_if_no_model_to_upload( self, @@ -4199,6 +4395,50 @@ def test_run_call_pipeline_service_create_with_no_dataset( assert model_from_job._gca_resource is mock_model_service_get.return_value + @pytest.mark.usefixtures( + "mock_pipeline_service_create_with_enable_web_access", + "mock_pipeline_service_get_with_enable_web_access", + "mock_get_backing_custom_job_with_enable_web_access", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_enable_web_access( + self, sync, caplog + ): + + caplog.set_level(logging.INFO) + + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + credentials=_TEST_CREDENTIALS, + ) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + job.run( + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + enable_web_access=_TEST_ENABLE_WEB_ACCESS, + sync=sync, + ) + + if not sync: + job.wait() + + print(caplog.text) + assert "workerpool0-0" in caplog.text + assert job._gca_resource == make_training_pipeline_with_enable_web_access( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ) + @pytest.mark.usefixtures( "mock_pipeline_service_create_with_no_model_to_upload", "mock_pipeline_service_get_with_no_model_to_upload",