Skip to content

Commit

Permalink
Feat: Add debugging terminal support for CustomJob, HyperparameterTun… (
Browse files Browse the repository at this point in the history
  • Loading branch information
morgandu authored Oct 22, 2021
1 parent 6c7e1fa commit 2deb505
Show file tree
Hide file tree
Showing 6 changed files with 788 additions and 21 deletions.
155 changes: 139 additions & 16 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,17 @@ def _dashboard_uri(self) -> Optional[str]:
url = f"https://console.cloud.google.com/ai/platform/locations/{fields.location}/{self._job_type}/{fields.id}?project={fields.project}"
return url

def _log_job_state(self):
    """Emit one log line with this job's class name, resource name and state.

    The message is built from ``self.__class__.__name__``,
    ``self._gca_resource.name`` and ``self._gca_resource.state`` and sent
    through ``_LOGGER.info``.
    """
    # Keep %-formatting so the state is rendered with str(), exactly as before.
    state_report = "%s %s current state:\n%s" % (
        self.__class__.__name__,
        self._gca_resource.name,
        self._gca_resource.state,
    )
    _LOGGER.info(state_report)

def _block_until_complete(self):
"""Helper method to block and check on job until complete.
Expand All @@ -190,26 +201,13 @@ def _block_until_complete(self):
while self.state not in _JOB_COMPLETE_STATES:
current_time = time.time()
if current_time - previous_time >= log_wait:
_LOGGER.info(
"%s %s current state:\n%s"
% (
self.__class__.__name__,
self._gca_resource.name,
self._gca_resource.state,
)
)
self._log_job_state()
log_wait = min(log_wait * multiplier, max_wait)
previous_time = current_time
time.sleep(wait)

_LOGGER.info(
"%s %s current state:\n%s"
% (
self.__class__.__name__,
self._gca_resource.name,
self._gca_resource.state,
)
)
self._log_job_state()

# Error is only populated when the job state is
# JOB_STATE_FAILED or JOB_STATE_CANCELLED.
if self._gca_resource.state in _JOB_ERROR_STATES:
Expand Down Expand Up @@ -845,6 +843,63 @@ def __init__(
project=project, location=location
)

self._logged_web_access_uris = set()

@property
def web_access_uris(self) -> Dict[str, Union[str, Dict[str, str]]]:
    """Return the latest web access uris of the runnable job.

    The job resource is re-synced first (``_sync_gca_resource``) so that the
    value produced by ``_get_web_access_uris`` reflects the freshly fetched
    resource.

    Returns:
        (Dict[str, Union[str, Dict[str, str]]]):
            Web access uris of the runnable job.
    """
    # Refresh before reading: the uris are read from the synced resource.
    self._sync_gca_resource()
    latest_uris = self._get_web_access_uris()
    return latest_uris

@abc.abstractmethod
def _get_web_access_uris(self):
"""Helper method to get the web access uris of the runnable job"""
pass

@abc.abstractmethod
def _log_web_access_uris(self):
"""Helper method to log the web access uris of the runnable job"""
pass

def _block_until_complete(self):
"""Helper method to block and check on runnable job until complete.

Polls ``self.state`` until it is one of ``_JOB_COMPLETE_STATES``, sleeping
``wait`` seconds between polls. While waiting it logs the job state and the
web access uris; once the loop ends it logs the state one more time.

Raises:
RuntimeError: If job failed or cancelled.
"""

# Used these numbers so failures surface fast
wait = 5 # seconds slept between polls; never modified below
log_wait = 5 # seconds that must elapse before the next state log
max_wait = 60 * 5 # upper bound for log_wait (5 minutes)
multiplier = 2 # factor applied to log_wait (not wait) after a state log

previous_time = time.time()
while self.state not in _JOB_COMPLETE_STATES:
current_time = time.time()
# Log the state only once log_wait seconds have passed since the last log,
# then back the logging interval off (capped at max_wait).
if current_time - previous_time >= log_wait:
self._log_job_state()
log_wait = min(log_wait * multiplier, max_wait)
previous_time = current_time
# NOTE(review): indentation is lost in this view, so it is unclear whether
# the call below runs on every poll or only when the state is logged —
# confirm against the real file.
self._log_web_access_uris()
time.sleep(wait)

# Final state log after leaving the polling loop.
self._log_job_state()

# Error is only populated when the job state is
# JOB_STATE_FAILED or JOB_STATE_CANCELLED.
# The same "Job failed with" message is raised for every state in
# _JOB_ERROR_STATES.
if self._gca_resource.state in _JOB_ERROR_STATES:
raise RuntimeError("Job failed with:\n%s" % self._gca_resource.error)
else:
_LOGGER.log_action_completed_against_resource("run", "completed", self)

@abc.abstractmethod
def run(self) -> None:
    """Run the job; abstract, so every concrete job type must implement it."""
Expand Down Expand Up @@ -1046,6 +1101,26 @@ def network(self) -> Optional[str]:
self._assert_gca_resource_is_available()
return self._gca_resource.job_spec.network

def _get_web_access_uris(self) -> Dict[str, str]:
"""Helper method to get the web access uris of the custom job
Returns:
(Dict[str, str]):
Web access uris of the custom job.
"""
return dict(self._gca_resource.web_access_uris)

def _log_web_access_uris(self):
"""Helper method to log the web access uris of the custom job"""

for worker, uri in self._get_web_access_uris().items():
if uri not in self._logged_web_access_uris:
_LOGGER.info(
"%s %s access the interactive shell terminals for the custom job:\n%s:\n%s"
% (self.__class__.__name__, self._gca_resource.name, worker, uri,),
)
self._logged_web_access_uris.add(uri)

@classmethod
def from_local_script(
cls,
Expand Down Expand Up @@ -1250,6 +1325,7 @@ def run(
network: Optional[str] = None,
timeout: Optional[int] = None,
restart_job_on_worker_restart: bool = False,
enable_web_access: bool = False,
tensorboard: Optional[str] = None,
sync: bool = True,
) -> None:
Expand All @@ -1271,6 +1347,10 @@ def run(
gets restarted. This feature can be used by
distributed training jobs that are not resilient
to workers leaving and joining a job.
enable_web_access (bool):
Whether you want Vertex AI to enable interactive shell access
to training containers.
https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell
tensorboard (str):
Optional. The name of a Vertex AI
[Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
Expand Down Expand Up @@ -1304,6 +1384,9 @@ def run(
restart_job_on_worker_restart=restart_job_on_worker_restart,
)

if enable_web_access:
self._gca_resource.job_spec.enable_web_access = enable_web_access

if tensorboard:
v1beta1_gca_resource = gca_custom_job_v1beta1.CustomJob()
v1beta1_gca_resource._pb.MergeFromString(
Expand Down Expand Up @@ -1588,13 +1671,46 @@ def network(self) -> Optional[str]:
self._assert_gca_resource_is_available()
return getattr(self._gca_resource.trial_job_spec, "network")

def _get_web_access_uris(self) -> Dict[str, Dict[str, str]]:
"""Helper method to get the web access uris of the hyperparameter job
Returns:
(Dict[str, Dict[str, str]]):
Web access uris of the hyperparameter job.
"""
web_access_uris = dict()
for trial in self.trials:
web_access_uris[trial.id] = web_access_uris.get(trial.id, dict())
for worker, uri in trial.web_access_uris.items():
web_access_uris[trial.id][worker] = uri
return web_access_uris

def _log_web_access_uris(self):
"""Helper method to log the web access uris of the hyperparameter job"""

for trial_id, trial_web_access_uris in self._get_web_access_uris().items():
for worker, uri in trial_web_access_uris.items():
if uri not in self._logged_web_access_uris:
_LOGGER.info(
"%s %s access the interactive shell terminals for trial - %s:\n%s:\n%s"
% (
self.__class__.__name__,
self._gca_resource.name,
trial_id,
worker,
uri,
),
)
self._logged_web_access_uris.add(uri)

@base.optional_sync()
def run(
self,
service_account: Optional[str] = None,
network: Optional[str] = None,
timeout: Optional[int] = None, # seconds
restart_job_on_worker_restart: bool = False,
enable_web_access: bool = False,
tensorboard: Optional[str] = None,
sync: bool = True,
) -> None:
Expand All @@ -1616,6 +1732,10 @@ def run(
gets restarted. This feature can be used by
distributed training jobs that are not resilient
to workers leaving and joining a job.
enable_web_access (bool):
Whether you want Vertex AI to enable interactive shell access
to training containers.
https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell
tensorboard (str):
Optional. The name of a Vertex AI
[Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
Expand Down Expand Up @@ -1649,6 +1769,9 @@ def run(
restart_job_on_worker_restart=restart_job_on_worker_restart,
)

if enable_web_access:
self._gca_resource.trial_job_spec.enable_web_access = enable_web_access

if tensorboard:
v1beta1_gca_resource = (
gca_hyperparameter_tuning_job_v1beta1.HyperparameterTuningJob()
Expand Down
Loading

0 comments on commit 2deb505

Please sign in to comment.