diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index 6ebcc602ce..5291ab36db 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -1108,7 +1108,7 @@ def _get_web_access_uris(self) -> Dict[str, str]: (Dict[str, str]): Web access uris of the custom job. """ - return self._gca_resource.web_access_uris + return dict(self._gca_resource.web_access_uris) def _log_web_access_uris(self): """Helper method to log the web access uris of the custom job""" @@ -1678,9 +1678,9 @@ def _get_web_access_uris(self) -> Dict[str, Dict[str, str]]: (Dict[str, Dict[str, str]]): Web access uris of the hyperparameter job. """ - web_access_uris = {} + web_access_uris = dict() for trial in self.trials: - web_access_uris[trial.id] = web_access_uris.get(trial.id, {}) + web_access_uris[trial.id] = web_access_uris.get(trial.id, dict()) for worker, uri in trial.web_access_uris.items(): web_access_uris[trial.id][worker] = uri return web_access_uris diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index 94b952d454..d750d29b3a 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -1457,7 +1457,7 @@ def web_access_uris(self) -> Dict[str, str]: (Dict[str, str]): Web access uris of the backing custom job. """ - web_access_uris = {} + web_access_uris = dict() if ( self._gca_resource.training_task_metadata and self._gca_resource.training_task_metadata.get("backingCustomJob") @@ -1467,7 +1467,7 @@ def web_access_uris(self) -> Dict[str, str]: ) custom_job = jobs.CustomJob.get(resource_name=custom_job_resource_name) - web_access_uris = custom_job.web_access_uris + web_access_uris = dict(custom_job.web_access_uris) return web_access_uris diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py index 810b9ef0a9..cf4b4e7ca8 100644 --- a/tests/unit/aiplatform/test_training_jobs.py +++ b/tests/unit/aiplatform/test_training_jobs.py @@ -524,11 +524,13 @@ def make_training_pipeline(state, add_training_task_metadata=True): else None, ) + def make_training_pipeline_with_no_model_upload(state): return gca_training_pipeline.TrainingPipeline( name=_TEST_PIPELINE_RESOURCE_NAME, state=state, ) + def make_training_pipeline_with_enable_web_access(state): training_pipeline = gca_training_pipeline.TrainingPipeline( name=_TEST_PIPELINE_RESOURCE_NAME, @@ -541,6 +543,7 @@ def make_training_pipeline_with_enable_web_access(state): } return training_pipeline + @pytest.fixture def mock_pipeline_service_get(): with mock.patch.object(