Skip to content

Commit

Permalink
Feat: Add debugging terminal support for CustomJob, HyperparameterTun…
Browse files Browse the repository at this point in the history
…ingJob, and Custom(*)TrainingJob
  • Loading branch information
morgandu committed Oct 19, 2021
1 parent 208889b commit 0fbb148
Show file tree
Hide file tree
Showing 6 changed files with 802 additions and 21 deletions.
160 changes: 144 additions & 16 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,17 @@ def _dashboard_uri(self) -> Optional[str]:
url = f"https://console.cloud.google.com/ai/platform/locations/{fields.location}/{self._job_type}/{fields.id}?project={fields.project}"
return url

def _log_job_state(self):
"""Helper method to log job state."""
_LOGGER.info(
"%s %s current state:\n%s"
% (
self.__class__.__name__,
self._gca_resource.name,
self._gca_resource.state,
)
)

def _block_until_complete(self):
"""Helper method to block and check on job until complete.
Expand All @@ -190,26 +201,13 @@ def _block_until_complete(self):
while self.state not in _JOB_COMPLETE_STATES:
current_time = time.time()
if current_time - previous_time >= log_wait:
_LOGGER.info(
"%s %s current state:\n%s"
% (
self.__class__.__name__,
self._gca_resource.name,
self._gca_resource.state,
)
)
self._log_job_state()
log_wait = min(log_wait * multiplier, max_wait)
previous_time = current_time
time.sleep(wait)

_LOGGER.info(
"%s %s current state:\n%s"
% (
self.__class__.__name__,
self._gca_resource.name,
self._gca_resource.state,
)
)
self._log_job_state()

# Error is only populated when the job state is
# JOB_STATE_FAILED or JOB_STATE_CANCELLED.
if self._gca_resource.state in _JOB_ERROR_STATES:
Expand Down Expand Up @@ -845,6 +843,66 @@ def __init__(
project=project, location=location
)

self._web_access_uris = None
self._logged_web_access_uris = []

@property
def web_access_uris(self) -> Dict[str, str]:
"""Fetch the runnable job again and return the latest web access uris.
Returns:
(Dict[str, str]):
Web access uris of the runnable job.
"""

# Fetch the Job again for most up-to-date web access uris
self._sync_gca_resource()
self._get_web_access_uris()

return self._web_access_uris

@abc.abstractmethod
def _get_web_access_uris(self):
"""Helper method to get the web access uris of the runnable job"""
pass

@abc.abstractmethod
def _log_web_access_uris(self):
"""Helper method to log the web access uris of the runnable job"""
pass

def _block_until_complete(self):
"""Helper method to block and check on runnable job until complete.
Raises:
RuntimeError: If job failed or cancelled.
"""

# Used these numbers so failures surface fast
wait = 5 # start at five seconds
log_wait = 5
max_wait = 60 * 5 # 5 minute wait
multiplier = 2 # scale wait by 2 every iteration

previous_time = time.time()
while self.state not in _JOB_COMPLETE_STATES:
current_time = time.time()
if current_time - previous_time >= log_wait:
self._log_job_state()
log_wait = min(log_wait * multiplier, max_wait)
previous_time = current_time
self._log_web_access_uris()
time.sleep(wait)

self._log_job_state()

# Error is only populated when the job state is
# JOB_STATE_FAILED or JOB_STATE_CANCELLED.
if self._gca_resource.state in _JOB_ERROR_STATES:
raise RuntimeError("Job failed with:\n%s" % self._gca_resource.error)
else:
_LOGGER.log_action_completed_against_resource("run", "completed", self)

@abc.abstractmethod
def run(self) -> None:
pass
Expand Down Expand Up @@ -1046,6 +1104,29 @@ def network(self) -> Optional[str]:
self._assert_gca_resource_is_available()
return self._gca_resource.job_spec.network

def _get_web_access_uris(self):
"""Helper method to get the web access uris of the custom job"""
self._web_access_uris = self._gca_resource.web_access_uris

def _log_web_access_uris(self):
"""Helper method to log the web access uris of the custom job"""

self._get_web_access_uris()

if self._web_access_uris:
for worker, uri in self._web_access_uris.items():
if uri not in self._logged_web_access_uris:
_LOGGER.info(
"%s %s access the interactive shell terminals for the custom job:\n%s:\n%s"
% (
self.__class__.__name__,
self._gca_resource.name,
worker,
uri,
),
)
self._logged_web_access_uris.append(uri)

@classmethod
def from_local_script(
cls,
Expand Down Expand Up @@ -1226,6 +1307,7 @@ def run(
network: Optional[str] = None,
timeout: Optional[int] = None,
restart_job_on_worker_restart: bool = False,
enable_web_access: bool = False,
tensorboard: Optional[str] = None,
sync: bool = True,
) -> None:
Expand All @@ -1247,6 +1329,10 @@ def run(
gets restarted. This feature can be used by
distributed training jobs that are not resilient
to workers leaving and joining a job.
enable_web_access (bool):
Whether you want Vertex AI to enable interactive shell access
to training containers.
https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell
tensorboard (str):
Optional. The name of a Vertex AI
[Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
Expand Down Expand Up @@ -1280,6 +1366,9 @@ def run(
restart_job_on_worker_restart=restart_job_on_worker_restart,
)

if enable_web_access:
self._gca_resource.job_spec.enable_web_access = enable_web_access

if tensorboard:
v1beta1_gca_resource = gca_custom_job_v1beta1.CustomJob()
v1beta1_gca_resource._pb.MergeFromString(
Expand Down Expand Up @@ -1564,13 +1653,45 @@ def network(self) -> Optional[str]:
self._assert_gca_resource_is_available()
return getattr(self._gca_resource.trial_job_spec, "network")

def _get_web_access_uris(self):
"""Helper method to get the web access uris of the hyperparameter job"""
if self.trials:
self._web_access_uris = [
(trial.id, trial.web_access_uris)
for trial in self.trials
if trial.web_access_uris
]

def _log_web_access_uris(self):
"""Helper method to log the web access uris of the hyperparameter job"""

self._get_web_access_uris()

if self._web_access_uris:
for (trial_id, tria_web_access_uris) in self._web_access_uris:
if tria_web_access_uris:
for worker, uri in tria_web_access_uris.items():
if uri not in self._logged_web_access_uris:
_LOGGER.info(
"%s %s access the interactive shell terminals for trial - %s:\n%s:\n%s"
% (
self.__class__.__name__,
self._gca_resource.name,
trial_id,
worker,
uri,
),
)
self._logged_web_access_uris.append(uri)

@base.optional_sync()
def run(
self,
service_account: Optional[str] = None,
network: Optional[str] = None,
timeout: Optional[int] = None, # seconds
restart_job_on_worker_restart: bool = False,
enable_web_access: bool = False,
tensorboard: Optional[str] = None,
sync: bool = True,
) -> None:
Expand All @@ -1592,6 +1713,10 @@ def run(
gets restarted. This feature can be used by
distributed training jobs that are not resilient
to workers leaving and joining a job.
enable_web_access (bool):
Whether you want Vertex AI to enable interactive shell access
to training containers.
https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell
tensorboard (str):
Optional. The name of a Vertex AI
[Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
Expand Down Expand Up @@ -1625,6 +1750,9 @@ def run(
restart_job_on_worker_restart=restart_job_on_worker_restart,
)

if enable_web_access:
self._gca_resource.trial_job_spec.enable_web_access = enable_web_access

if tensorboard:
v1beta1_gca_resource = (
gca_hyperparameter_tuning_job_v1beta1.HyperparameterTuningJob()
Expand Down
Loading

0 comments on commit 0fbb148

Please sign in to comment.