Skip to content

Commit

Permalink
cast web_access_uris to dict type
Browse files Browse the repository at this point in the history
  • Loading branch information
morgandu committed Oct 21, 2021
1 parent 26d4977 commit f8b67ea
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
10 changes: 5 additions & 5 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,12 +1108,12 @@ def _get_web_access_uris(self) -> Dict[str, str]:
(Dict[str, str]):
Web access uris of the custom job.
"""
return self._gca_resource.web_access_uris
return dict(self._gca_resource.web_access_uris)

def _log_web_access_uris(self):
"""Helper method to log the web access uris of the custom job"""

for worker, uri in self.web_access_uris.items():
for worker, uri in self._get_web_access_uris().items():
if uri not in self._logged_web_access_uris:
_LOGGER.info(
"%s %s access the interactive shell terminals for the custom job:\n%s:\n%s"
Expand Down Expand Up @@ -1678,17 +1678,17 @@ def _get_web_access_uris(self) -> Dict[str, Dict[str, str]]:
(Dict[str, Dict[str, str]]):
Web access uris of the hyperparameter job.
"""
web_access_uris = {}
web_access_uris = dict()
for trial in self.trials:
web_access_uris[trial.id] = web_access_uris.get(trial.id, {})
web_access_uris[trial.id] = web_access_uris.get(trial.id, dict())
for worker, uri in trial.web_access_uris.items():
web_access_uris[trial.id][worker] = uri
return web_access_uris

def _log_web_access_uris(self):
"""Helper method to log the web access uris of the hyperparameter job"""

for (trial_id, trial_web_access_uris) in self.web_access_uris.items():
for trial_id, trial_web_access_uris in self._get_web_access_uris().items():
for worker, uri in trial_web_access_uris.items():
if uri not in self._logged_web_access_uris:
_LOGGER.info(
Expand Down
4 changes: 2 additions & 2 deletions google/cloud/aiplatform/training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1457,7 +1457,7 @@ def web_access_uris(self) -> Dict[str, str]:
(Dict[str, str]):
Web access uris of the backing custom job.
"""
web_access_uris = {}
web_access_uris = dict()
if (
self._gca_resource.training_task_metadata
and self._gca_resource.training_task_metadata.get("backingCustomJob")
Expand All @@ -1467,7 +1467,7 @@ def web_access_uris(self) -> Dict[str, str]:
)
custom_job = jobs.CustomJob.get(resource_name=custom_job_resource_name)

web_access_uris = custom_job.web_access_uris
web_access_uris = dict(custom_job.web_access_uris)

return web_access_uris

Expand Down
3 changes: 3 additions & 0 deletions tests/unit/aiplatform/test_training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,11 +524,13 @@ def make_training_pipeline(state, add_training_task_metadata=True):
else None,
)


def make_training_pipeline_with_no_model_upload(state):
return gca_training_pipeline.TrainingPipeline(
name=_TEST_PIPELINE_RESOURCE_NAME, state=state,
)


def make_training_pipeline_with_enable_web_access(state):
training_pipeline = gca_training_pipeline.TrainingPipeline(
name=_TEST_PIPELINE_RESOURCE_NAME,
Expand All @@ -541,6 +543,7 @@ def make_training_pipeline_with_enable_web_access(state):
}
return training_pipeline


@pytest.fixture
def mock_pipeline_service_get():
with mock.patch.object(
Expand Down

0 comments on commit f8b67ea

Please sign in to comment.