Skip to content

Commit

Permalink
Feat: Add debugging terminal support for CustomJob, HyperparameterTuningJob, and Custom(*)TrainingJob
Browse files Browse the repository at this point in the history
  • Loading branch information
morgandu committed Sep 30, 2021
1 parent 09e48de commit b71b3fb
Show file tree
Hide file tree
Showing 5 changed files with 766 additions and 21 deletions.
153 changes: 137 additions & 16 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@
gca_job_state.JobState.JOB_STATE_CANCELLED,
)

# One-element tuple holding the "running" job state; written as a tuple so it
# can be used in `state in <TUPLE>` membership checks.
_JOB_RUNNING_STATE = (gca_job_state.JobState.JOB_STATE_RUNNING,)
# One-element tuple holding the "active" trial state, used the same way when
# inspecting the trials of a hyperparameter tuning job.
_TRIAL_ACTIVE_STATE = (gca_study_compat.Trial.State.ACTIVE,)


class _Job(base.VertexAiResourceNounWithFutureManager):
"""Class that represents a general Job resource in Vertex AI.
Expand Down Expand Up @@ -173,6 +176,16 @@ def _dashboard_uri(self) -> Optional[str]:
url = f"https://console.cloud.google.com/ai/platform/locations/{fields.location}/{self._job_type}/{fields.id}?project={fields.project}"
return url

def _log_job_state(self):
    """Log the class name, resource name and current state of this job.

    The values are read from ``self._gca_resource`` as it currently is and
    written as a single INFO-level message.
    """
    class_name = self.__class__.__name__
    resource_name = self._gca_resource.name
    resource_state = self._gca_resource.state

    message = "%s %s current state:\n%s" % (
        class_name,
        resource_name,
        resource_state,
    )
    _LOGGER.info(message)

def _block_until_complete(self):
"""Helper method to block and check on job until complete.
Expand All @@ -190,26 +203,13 @@ def _block_until_complete(self):
while self.state not in _JOB_COMPLETE_STATES:
current_time = time.time()
if current_time - previous_time >= log_wait:
_LOGGER.info(
"%s %s current state:\n%s"
% (
self.__class__.__name__,
self._gca_resource.name,
self._gca_resource.state,
)
)
self._log_job_state()
log_wait = min(log_wait * multiplier, max_wait)
previous_time = current_time
time.sleep(wait)

_LOGGER.info(
"%s %s current state:\n%s"
% (
self.__class__.__name__,
self._gca_resource.name,
self._gca_resource.state,
)
)
self._log_job_state()

# Error is only populated when the job state is
# JOB_STATE_FAILED or JOB_STATE_CANCELLED.
if self._gca_resource.state in _JOB_ERROR_STATES:
Expand Down Expand Up @@ -1219,13 +1219,60 @@ def from_local_script(
staging_bucket=staging_bucket,
)

def web_access_uris(self):
    """Log the interactive shell links of this job when they can still exist.

    If the job state is already one of ``_JOB_COMPLETE_STATES``, the current
    state and an explanatory message are logged and nothing else is done.
    Otherwise the work is handed to ``_web_access_uris``.

    Returns:
        None in every case; the outcome is only written to the log.
    """
    if self.state not in _JOB_COMPLETE_STATES:
        self._web_access_uris()
        return

    # The job is finished: report where it ended up and why no links are shown.
    self._log_job_state()
    _LOGGER.info(
        "Access to the interactive shell terminals are only available when the job is running."
    )

def _web_access_uris(self):
    """Wait until the job is running (or done) and log its web access uris.

    ``self.state`` is checked every ``poll_interval`` seconds until it is in
    ``_JOB_RUNNING_STATE`` or ``_JOB_COMPLETE_STATES``. While waiting, the job
    state is logged; the gap between two logs doubles after each log and is
    capped at ``log_interval_cap``. Once polling stops, the entries of
    ``self._gca_resource.web_access_uris`` are logged only if the job is in
    the running state; otherwise nothing more is logged.
    """
    poll_interval = 5  # seconds slept between two state checks (never changes)
    log_interval = 5  # seconds that must pass before the next state log
    log_interval_cap = 60 * 1  # upper bound for log_interval: 1 minute
    backoff = 2  # log_interval is multiplied by this after each state log

    last_logged = time.time()
    while not (
        self.state in _JOB_RUNNING_STATE or self.state in _JOB_COMPLETE_STATES
    ):
        now = time.time()
        if now - last_logged >= log_interval:
            self._log_job_state()
            log_interval = min(log_interval * backoff, log_interval_cap)
            last_logged = now
        time.sleep(poll_interval)

    # The loop also ends when the job completed without ever being observed
    # as running; in that case there is nothing to report.
    if self.state not in _JOB_RUNNING_STATE:
        return

    class_name = self.__class__.__name__
    resource_name = self._gca_resource.name
    links = "\n".join(
        "%s:\n%s" % (worker, uri)
        for worker, uri in self._gca_resource.web_access_uris.items()
    )
    _LOGGER.info(
        "%s %s access the interactive shell terminals for this job at the following links:\n%s"
        % (class_name, resource_name, links)
    )

@base.optional_sync()
def run(
self,
service_account: Optional[str] = None,
network: Optional[str] = None,
timeout: Optional[int] = None,
restart_job_on_worker_restart: bool = False,
enable_web_access: Optional[bool] = False,
tensorboard: Optional[str] = None,
sync: bool = True,
) -> None:
Expand All @@ -1247,6 +1294,10 @@ def run(
gets restarted. This feature can be used by
distributed training jobs that are not resilient
to workers leaving and joining a job.
enable_web_access (bool):
Optional. Whether you want Vertex AI to enable interactive shell access
to training containers.
https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell
tensorboard (str):
Optional. The name of a Vertex AI
[Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
Expand Down Expand Up @@ -1280,6 +1331,9 @@ def run(
restart_job_on_worker_restart=restart_job_on_worker_restart,
)

if enable_web_access:
self._gca_resource.job_spec.enable_web_access = enable_web_access

if tensorboard:
v1beta1_gca_resource = gca_custom_job_v1beta1.CustomJob()
v1beta1_gca_resource._pb.MergeFromString(
Expand Down Expand Up @@ -1309,6 +1363,9 @@ def run(
)
)

if enable_web_access:
self._web_access_uris()

self._block_until_complete()

@property
Expand Down Expand Up @@ -1564,13 +1621,67 @@ def network(self) -> Optional[str]:
self._assert_gca_resource_is_available()
return getattr(self._gca_resource.trial_job_spec, "network")

def _web_access_uris(self):
"""Helper method to check on hp job to get web access uris.

Keeps looping until the job state is in _JOB_COMPLETE_STATES. While the job
is running and has trials, the last trial is inspected and its web access
uris are logged the first time that trial is seen in the active state. The
job state is also logged periodically.

NOTE(review): indentation is not visible in this view, so the nesting
described in the comments below is inferred from the statements
themselves -- confirm against the real file.
"""

# Id of the trial whose links were logged most recently; compared against
# the last trial's id below so the same trial is not logged twice.
current_trial_id = None
current_trial = None

# Used these numbers so failures surface fast
wait = 5 # seconds slept per loop iteration; never modified in this method
log_wait = 5 # seconds that must elapse before the next job state log
max_wait_short = 60 * 1 # 1 minute wait
max_wait_long = 60 * 5 # 5 minute wait
# Upper bound applied to log_wait; switched between the two caps above
# depending on whether the last trial is active.
max_wait = max_wait_short
multiplier = 2 # factor applied to log_wait after each job state log

previous_time = time.time()
while self.state not in _JOB_COMPLETE_STATES:

# Trials are only looked at while the job is running and has at least one.
if self.state in _JOB_RUNNING_STATE and self._gca_resource.trials:

# Only the last entry of the trials list is considered.
current_trial = self._gca_resource.trials[-1]
# Log the links once per trial: the trial must be active and must not be
# the one that was logged last.
if (
current_trial.state in _TRIAL_ACTIVE_STATE
and current_trial_id != current_trial.id
):
current_trial_id = current_trial.id
_LOGGER.info(
"%s %s access the interactive shell terminals for trial %s at the following links:\n%s"
% (
self.__class__.__name__,
self._gca_resource.name,
current_trial.id,
"\n".join(
[
"%s:\n%s" % (worker, uri)
for worker, uri in current_trial.web_access_uris.items()
]
),
)
)
# An active last trial selects the long cap for the log interval; otherwise
# the short cap is selected.
# NOTE(review): presumably `log_wait = 5` belongs to the else branch, i.e.
# the log interval is reset only when the last trial is not active -- confirm.
if current_trial.state in _TRIAL_ACTIVE_STATE:
max_wait = max_wait_long
else:
max_wait = max_wait_short
log_wait = 5

# Periodic job state log; after each log the interval is multiplied by
# `multiplier` and capped at the current max_wait.
current_time = time.time()
if current_time - previous_time >= log_wait:
self._log_job_state()
log_wait = min(log_wait * multiplier, max_wait)
previous_time = current_time
time.sleep(wait)

@base.optional_sync()
def run(
self,
service_account: Optional[str] = None,
network: Optional[str] = None,
timeout: Optional[int] = None, # seconds
restart_job_on_worker_restart: bool = False,
enable_web_access: Optional[bool] = False,
tensorboard: Optional[str] = None,
sync: bool = True,
) -> None:
Expand All @@ -1592,6 +1703,10 @@ def run(
gets restarted. This feature can be used by
distributed training jobs that are not resilient
to workers leaving and joining a job.
enable_web_access (bool):
Optional. Whether you want Vertex AI to enable interactive shell access
to training containers.
https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell
tensorboard (str):
Optional. The name of a Vertex AI
[Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
Expand Down Expand Up @@ -1625,6 +1740,9 @@ def run(
restart_job_on_worker_restart=restart_job_on_worker_restart,
)

if enable_web_access:
self._gca_resource.trial_job_spec.enable_web_access = enable_web_access

if tensorboard:
v1beta1_gca_resource = (
gca_hyperparameter_tuning_job_v1beta1.HyperparameterTuningJob()
Expand Down Expand Up @@ -1658,6 +1776,9 @@ def run(
)
)

if enable_web_access:
self._web_access_uris()

self._block_until_complete()

@property
Expand Down
Loading

0 comments on commit b71b3fb

Please sign in to comment.