[Jobs] Allow logs for finished jobs and add sky jobs logs --refresh for restarting jobs controller #4380

Open · wants to merge 5 commits into master
13 changes: 11 additions & 2 deletions sky/cli.py
@@ -3896,16 +3896,25 @@ def jobs_cancel(name: Optional[str], job_ids: Tuple[int], all: bool, yes: bool):
default=False,
help=('Show the controller logs of this job; useful for debugging '
'launching/recoveries, etc.'))
@click.option(
'--refresh',
'-r',
default=False,
is_flag=True,
required=False,
help='Query the latest job logs, restarting the jobs controller if stopped.'
)
@click.argument('job_id', required=False, type=int)
@usage_lib.entrypoint
def jobs_logs(name: Optional[str], job_id: Optional[int], follow: bool,
controller: bool):
controller: bool, refresh: bool):
"""Tail the log of a managed job."""
try:
managed_jobs.tail_logs(name=name,
job_id=job_id,
follow=follow,
controller=controller)
controller=controller,
refresh=refresh)
except exceptions.ClusterNotUpError:
with ux_utils.print_exception_no_traceback():
raise
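The new flag is also reachable through the Python API that this command wraps; a minimal sketch, assuming sky.jobs re-exports tail_logs as in current releases (the job id is illustrative):

from sky import jobs as managed_jobs  # assumed alias, mirroring cli.py

# Tail the logs of managed job 5, restarting the jobs controller first
# if it is stopped -- the behavior that --refresh/-r toggles.
managed_jobs.tail_logs(name=None, job_id=5, follow=True,
                       controller=False, refresh=True)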
31 changes: 24 additions & 7 deletions sky/jobs/controller.py
@@ -6,7 +6,7 @@
import time
import traceback
import typing
from typing import Tuple
from typing import Optional, Tuple

import filelock

@@ -87,18 +87,26 @@ def __init__(self, job_id: int, dag_yaml: str,
task.update_envs(task_envs)

def _download_log_and_stream(
self,
handle: cloud_vm_ray_backend.CloudVmRayResourceHandle) -> None:
self, task_id: Optional[int],
handle: Optional[cloud_vm_ray_backend.CloudVmRayResourceHandle]
) -> None:
"""Downloads and streams the logs of the latest job.
[Reviewer comment] Suggested change:
- """Downloads and streams the logs of the latest job.
+ """Downloads and streams the logs of the given task id.
is this correct?


We do not stream the logs from the cluster directly, as the
download and stream should be faster, and more robust against
preemptions or ssh disconnection during the streaming.
"""
if handle is None:
logger.info(f'Cluster for job {self._job_id} is not found. '
'Skipping downloading and streaming the logs.')
return
managed_job_logs_dir = os.path.join(constants.SKY_LOGS_DIRECTORY,
'managed_jobs')
controller_utils.download_and_stream_latest_job_log(
log_file = controller_utils.download_and_stream_latest_job_log(
self._backend, handle, managed_job_logs_dir)
if log_file is not None:
managed_job_state.set_local_log_file(self._job_id, task_id,
log_file)
logger.info(f'\n== End of logs (ID: {self._job_id}) ==')

def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
@@ -213,20 +221,29 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
if job_status == job_lib.JobStatus.SUCCEEDED:
end_time = managed_job_utils.get_job_timestamp(
self._backend, cluster_name, get_end_time=True)
# The job is done.
# The job is done. Set the job to SUCCEEDED first before start
# downloading and streaming the logs to make it more responsive.
managed_job_state.set_succeeded(self._job_id,
task_id,
end_time=end_time,
callback_func=callback_func)
logger.info(
f'Managed job {self._job_id} (task: {task_id}) SUCCEEDED. '
f'Cleaning up the cluster {cluster_name}.')
clusters = backend_utils.get_clusters(
cluster_names=[cluster_name],
refresh=False,
include_controller=False)
if clusters:
handle = clusters[0].get('handle')
[Reviewer comment] Suggested change:
+ assert len(clusters) == 1
  handle = clusters[0].get('handle')

# Best effort to download and stream the logs.
self._download_log_and_stream(task_id, handle)
# Only clean up the cluster, not the storages, because tasks may
# share storages.
recovery_strategy.terminate_cluster(cluster_name=cluster_name)
return True

# For single-node jobs, nonterminated job_status indicates a
# For single-node jobs, non-terminated job_status indicates a
# healthy cluster. We can safely continue monitoring.
# For multi-node jobs, since the job may not be set to FAILED
# immediately (depending on user program) when only some of the
@@ -278,7 +295,7 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
'The user job failed. Please check the logs below.\n'
f'== Logs of the user job (ID: {self._job_id}) ==\n')

self._download_log_and_stream(handle)
self._download_log_and_stream(task_id, handle)
managed_job_status = (
managed_job_state.ManagedJobStatus.FAILED)
if job_status == job_lib.JobStatus.FAILED_SETUP:
95 changes: 60 additions & 35 deletions sky/jobs/core.py
@@ -1,6 +1,7 @@
"""SDK functions for managed jobs."""
import os
import tempfile
import typing
from typing import Any, Dict, List, Optional, Union
import uuid

@@ -29,6 +30,9 @@
from sky.utils import timeline
from sky.utils import ux_utils

if typing.TYPE_CHECKING:
from sky.backends import cloud_vm_ray_backend


@timeline.event
@usage_lib.entrypoint
Expand Down Expand Up @@ -225,6 +229,40 @@ def queue_from_kubernetes_pod(
return jobs


def _maybe_restart_controller(
refresh: bool, stopped_message: str, spinner_message: str
) -> 'cloud_vm_ray_backend.CloudVmRayResourceHandle':
"""Restart controller if refresh is True and it is stopped."""
jobs_controller_type = controller_utils.Controllers.JOBS_CONTROLLER
if refresh:
stopped_message = ''
try:
handle = backend_utils.is_controller_accessible(
controller=jobs_controller_type, stopped_message=stopped_message)
except exceptions.ClusterNotUpError as e:
if not refresh:
raise
handle = None
controller_status = e.cluster_status

if handle is not None:
return handle

sky_logging.print(f'{colorama.Fore.YELLOW}'
f'Restarting {jobs_controller_type.value.name}...'
f'{colorama.Style.RESET_ALL}')

rich_utils.force_update_status(
ux_utils.spinner_message(f'{spinner_message} - restarting '
'controller'))
handle = sky.start(jobs_controller_type.value.cluster_name)
controller_status = status_lib.ClusterStatus.UP
rich_utils.force_update_status(ux_utils.spinner_message(spinner_message))

assert handle is not None, (controller_status, refresh)
return handle


@usage_lib.entrypoint
def queue(refresh: bool, skip_finished: bool = False) -> List[Dict[str, Any]]:
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
@@ -252,34 +290,11 @@ def queue(refresh: bool, skip_finished: bool = False) -> List[Dict[str, Any]]:
does not exist.
RuntimeError: if failed to get the managed jobs with ssh.
"""
jobs_controller_type = controller_utils.Controllers.JOBS_CONTROLLER
stopped_message = ''
if not refresh:
stopped_message = 'No in-progress managed jobs.'
try:
handle = backend_utils.is_controller_accessible(
controller=jobs_controller_type, stopped_message=stopped_message)
except exceptions.ClusterNotUpError as e:
if not refresh:
raise
handle = None
controller_status = e.cluster_status

if refresh and handle is None:
sky_logging.print(f'{colorama.Fore.YELLOW}'
'Restarting controller for latest status...'
f'{colorama.Style.RESET_ALL}')

rich_utils.force_update_status(
ux_utils.spinner_message('Checking managed jobs - restarting '
'controller'))
handle = sky.start(jobs_controller_type.value.cluster_name)
controller_status = status_lib.ClusterStatus.UP
rich_utils.force_update_status(
ux_utils.spinner_message('Checking managed jobs'))

assert handle is not None, (controller_status, refresh)

handle = _maybe_restart_controller(refresh,
stopped_message='No in-progress '
'managed jobs.',
spinner_message='Checking '
'managed jobs')
backend = backend_utils.get_backend_from_handle(handle)
assert isinstance(backend, backends.CloudVmRayBackend)

@@ -371,7 +386,7 @@ def cancel(name: Optional[str] = None,

@usage_lib.entrypoint
def tail_logs(name: Optional[str], job_id: Optional[int], follow: bool,
controller: bool) -> None:
controller: bool, refresh: bool) -> None:
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
"""Tail logs of managed jobs.

@@ -382,15 +397,25 @@ def tail_logs(name: Optional[str], job_id: Optional[int], follow: bool,
sky.exceptions.ClusterNotUpError: the jobs controller is not up.
"""
# TODO(zhwu): Automatically restart the jobs controller
if name is not None and job_id is not None:
raise ValueError('Cannot specify both name and job_id.')
[Reviewer comment] ux_utils?


jobs_controller_type = controller_utils.Controllers.JOBS_CONTROLLER
handle = backend_utils.is_controller_accessible(
controller=jobs_controller_type,
job_name_or_id_str = ''
if job_id is not None:
job_name_or_id_str = str(job_id)
elif name is not None:
job_name_or_id_str = f'-n {name}'
else:
job_name_or_id_str = ''
handle = _maybe_restart_controller(
refresh,
stopped_message=(
'Please restart the jobs controller with '
f'`sky start {jobs_controller_type.value.cluster_name}`.'))
f'{jobs_controller_type.value.name.capitalize()} is stopped. To '
f'get the logs, run: {colorama.Style.BRIGHT}sky jobs logs '
f'-r {job_name_or_id_str}{colorama.Style.RESET_ALL}'),
spinner_message='Retrieving job logs')

if name is not None and job_id is not None:
raise ValueError('Cannot specify both name and job_id.')
backend = backend_utils.get_backend_from_handle(handle)
assert isinstance(backend, backends.CloudVmRayBackend), backend

26 changes: 25 additions & 1 deletion sky/jobs/state.py
@@ -66,7 +66,8 @@ def create_table(cursor, conn):
spot_job_id INTEGER,
task_id INTEGER DEFAULT 0,
task_name TEXT,
specs TEXT)""")
specs TEXT,
local_log_file TEXT DEFAULT NULL)""")
conn.commit()

db_utils.add_column_to_table(cursor, conn, 'spot', 'failure_reason', 'TEXT')
@@ -103,6 +104,8 @@ def create_table(cursor, conn):
value_to_replace_existing_entries=json.dumps({
'max_restarts_on_errors': 0,
}))
db_utils.add_column_to_table(cursor, conn, 'spot', 'local_log_file',
'TEXT DEFAULT NULL')
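For databases created before this change, the add_column_to_table call above amounts to roughly the following; a sketch, assuming db_utils.add_column_to_table issues an ALTER TABLE and tolerates an already-existing column:

import sqlite3

conn = sqlite3.connect(_DB_PATH)  # _DB_PATH as defined in this module
try:
    conn.execute(
        'ALTER TABLE spot ADD COLUMN local_log_file TEXT DEFAULT NULL')
except sqlite3.OperationalError:
    # Column already exists; the migration is effectively idempotent.
    pass
conn.commit()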

# `job_info` contains the mapping from job_id to the job_name.
# In the future, it may contain more information about each job.
@@ -157,6 +160,7 @@ def _get_db_path() -> str:
'task_id',
'task_name',
'specs',
'local_log_file',
# columns from the job_info table
'_job_info_job_id', # This should be the same as job_id
'job_name',
@@ -512,6 +516,16 @@ def set_cancelled(job_id: int, callback_func: CallbackType):
callback_func('CANCELLED')


def set_local_log_file(job_id: int, task_id: Optional[int],
local_log_file: str):
"""Set the local log file for a job."""
task_str = '' if task_id is None else f' AND task_id={task_id}'
[Reviewer comment] Not sure whether we need to pass in task_id as ?.. Actually, is there any difference?

[Reviewer comment] IMO we should just use ? everywhere to avoid having to reason about where injection may be possible.

with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
'UPDATE spot SET local_log_file=(?) '
f'WHERE spot_job_id=(?){task_str}', (local_log_file, job_id))


# ======== utility functions ========
def get_nonterminal_job_ids_by_name(name: Optional[str]) -> List[int]:
"""Get non-terminal job ids by name."""
@@ -662,3 +676,13 @@ def get_task_specs(job_id: int, task_id: int) -> Dict[str, Any]:
WHERE spot_job_id=(?) AND task_id=(?)""",
(job_id, task_id)).fetchone()
return json.loads(task_specs[0])


def get_local_log_file(job_id: int, task_id: Optional[int]) -> Optional[str]:
"""Get the local log directory for a job."""
task_str = '' if task_id is None else f' AND task_id={task_id}'
with db_utils.safe_cursor(_DB_PATH) as cursor:
local_log_file = cursor.execute(
f'SELECT local_log_file FROM spot '
f'WHERE spot_job_id=(?){task_str}', (job_id,)).fetchone()
return local_log_file[-1] if local_log_file else None
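Per the ? placeholder comments above, both of these queries could be built entirely from bound parameters. A minimal sketch of that variant for the setter (the _v2 name is hypothetical; this is not the code in this PR):

def set_local_log_file_v2(job_id: int, task_id: Optional[int],
                          local_log_file: str) -> None:
    """Sketch: build the WHERE clause from `?` placeholders only."""
    query = 'UPDATE spot SET local_log_file=(?) WHERE spot_job_id=(?)'
    params = [local_log_file, job_id]
    if task_id is not None:
        query += ' AND task_id=(?)'
        params.append(task_id)
    with db_utils.safe_cursor(_DB_PATH) as cursor:
        cursor.execute(query, params)

The same pattern would apply to get_local_log_file above.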
18 changes: 16 additions & 2 deletions sky/jobs/utils.py
@@ -327,10 +327,24 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
if managed_job_status.is_failed():
job_msg = ('\nFailure reason: '
f'{managed_job_state.get_failure_reason(job_id)}')
log_file = managed_job_state.get_local_log_file(job_id, None)
if log_file is not None:
with open(log_file, 'r', encoding='utf-8') as f:
# Stream the logs to the console without reading the whole
# file into memory.
start_streaming = False
for line in f:
if log_lib.LOG_FILE_START_STREAMING_AT in line:
start_streaming = True
if start_streaming:
print(line, end='', flush=True)
return ''
return (f'{colorama.Fore.YELLOW}'
f'Job {job_id} is already in terminal state '
f'{managed_job_status.value}. Logs will not be shown.'
f'{colorama.Style.RESET_ALL}{job_msg}')
f'{managed_job_status.value}. For more details, run: '
f'sky jobs logs --controller {job_id}'
f'{colorama.Style.RESET_ALL}'
f'{job_msg}')
backend = backends.CloudVmRayBackend()
task_id, managed_job_status = (
managed_job_state.get_latest_task_id_status(job_id))
4 changes: 3 additions & 1 deletion sky/skylet/log_lib.py
@@ -34,6 +34,8 @@

logger = sky_logging.init_logger(__name__)

LOG_FILE_START_STREAMING_AT = 'Waiting for task resources on '


class _ProcessingArgs:
"""Arguments for processing logs."""
@@ -435,7 +437,7 @@ def tail_logs(job_id: Optional[int],
time.sleep(_SKY_LOG_WAITING_GAP_SECONDS)
status = job_lib.update_job_status([job_id], silent=True)[0]

start_stream_at = 'Waiting for task resources on '
start_stream_at = LOG_FILE_START_STREAMING_AT
# Explicitly declare the type to avoid mypy warning.
lines: Iterable[str] = []
if follow and status in [
3 changes: 3 additions & 0 deletions sky/skylet/log_lib.pyi
@@ -13,6 +13,9 @@ from sky.skylet import constants as constants
from sky.skylet import job_lib as job_lib
from sky.utils import log_utils as log_utils

LOG_FILE_START_STREAMING_AT: str = ...


class _ProcessingArgs:
log_path: str
stream_logs: bool
11 changes: 9 additions & 2 deletions sky/utils/controller_utils.py
@@ -26,6 +26,7 @@
from sky.serve import constants as serve_constants
from sky.serve import serve_utils
from sky.skylet import constants
from sky.skylet import log_lib
from sky.utils import common_utils
from sky.utils import env_options
from sky.utils import rich_utils
@@ -380,11 +381,17 @@ def download_and_stream_latest_job_log(
else:
log_dir = list(log_dirs.values())[0]
log_file = os.path.join(log_dir, 'run.log')

# Print the logs to the console.
try:
with open(log_file, 'r', encoding='utf-8') as f:
print(f.read())
# Stream the logs to the console without reading the whole
# file into memory.
start_streaming = False
for line in f:
if log_lib.LOG_FILE_START_STREAMING_AT in line:
start_streaming = True
if start_streaming:
print(line, end='', flush=True)
[Reviewer comment on lines +387 to +394] Should we move this into some utils file?

except FileNotFoundError:
logger.error('Failed to find the logs for the user '
f'program at {log_file}.')
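On the "utils file" comment above: the marker-based streaming now exists both here and in sky/jobs/utils.py, so a shared helper could absorb it. A minimal sketch, assuming it would live next to LOG_FILE_START_STREAMING_AT in log_lib (the helper name is hypothetical):

from typing import TextIO

def stream_from_marker(f: TextIO, marker: str) -> None:
    """Print `f` from the first line containing `marker` onwards,
    streaming line by line instead of reading the whole file."""
    start_streaming = False
    for line in f:
        if marker in line:
            start_streaming = True
        if start_streaming:
            print(line, end='', flush=True)

Both call sites would then reduce to stream_from_marker(f, log_lib.LOG_FILE_START_STREAMING_AT).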