[UX] Support --tail parameter for sky logs (#4241)
* support -n parameter for sky logs

* format

* fix format

* format fix again

* fix pylint

* fix format

* rename the -n to --tail

* fix format

* remove -n

* resolve comment

* backward compatibility

* pass yapf

* backward compatibility

* format

* yapf

* restore change and add comment

* moving the comment closer to the place

Co-authored-by: Zhanghao Wu <[email protected]>

* resolve comment

* peek the head instead of loading the whole file to memory

* bug fix

* rephrase function name and comment

* fix

* remove readlines

* resolve comment

---------

Co-authored-by: Zhanghao Wu <[email protected]>
zpoint and Michaelvll authored Nov 9, 2024
1 parent 6fda9fd commit 42c79e1
Showing 6 changed files with 105 additions and 17 deletions.
8 changes: 6 additions & 2 deletions sky/backends/cloud_vm_ray_backend.py
@@ -3711,18 +3711,22 @@ def tail_logs(self,
                   handle: CloudVmRayResourceHandle,
                   job_id: Optional[int],
                   managed_job_id: Optional[int] = None,
-                  follow: bool = True) -> int:
+                  follow: bool = True,
+                  tail: int = 0) -> int:
         """Tail the logs of a job.

         Args:
             handle: The handle to the cluster.
             job_id: The job ID to tail the logs of.
             managed_job_id: The managed job ID for display purpose only.
             follow: Whether to follow the logs.
+            tail: The number of lines to display from the end of the
+                log file. If 0, print all lines.
         """
         code = job_lib.JobLibCodeGen.tail_logs(job_id,
                                                managed_job_id=managed_job_id,
-                                               follow=follow)
+                                               follow=follow,
+                                               tail=tail)
         if job_id is None and managed_job_id is None:
             logger.info(
                 'Job ID not provided. Streaming the logs of the latest job.')
9 changes: 8 additions & 1 deletion sky/cli.py
@@ -2022,6 +2022,12 @@ def queue(clusters: List[str], skip_finished: bool, all_users: bool):
     help=('Follow the logs of a job. '
           'If --no-follow is specified, print the log so far and exit. '
           '[default: --follow]'))
+@click.option(
+    '--tail',
+    default=0,
+    type=int,
+    help=('The number of lines to display from the end of the log file. '
+          'Default is 0, which means print all lines.'))
 @click.argument('cluster',
                 required=True,
                 type=str,
@@ -2035,6 +2041,7 @@ def logs(
     sync_down: bool,
     status: bool,  # pylint: disable=redefined-outer-name
     follow: bool,
+    tail: int,
 ):
     # NOTE(dev): Keep the docstring consistent between the Python API and CLI.
     """Tail the log of a job.
@@ -2101,7 +2108,7 @@ def logs(
         click.secho(f'Job {id_str}not found', fg='red')
         sys.exit(1)

-    core.tail_logs(cluster, job_id, follow)
+    core.tail_logs(cluster, job_id, follow, tail)


 @cli.command()
5 changes: 3 additions & 2 deletions sky/core.py
@@ -742,7 +742,8 @@ def cancel(
 @usage_lib.entrypoint
 def tail_logs(cluster_name: str,
               job_id: Optional[int],
-              follow: bool = True) -> None:
+              follow: bool = True,
+              tail: int = 0) -> None:
     # NOTE(dev): Keep the docstring consistent between the Python API and CLI.
     """Tail the logs of a job.

@@ -775,7 +776,7 @@ def tail_logs(cluster_name: str,
         f'{colorama.Style.RESET_ALL}')

     usage_lib.record_cluster_name_for_current_operation(cluster_name)
-    backend.tail_logs(handle, job_id, follow=follow)
+    backend.tail_logs(handle, job_id, follow=follow, tail=tail)


 @usage_lib.entrypoint
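For context, a minimal usage sketch of the extended Python API (the cluster name and job ID below are illustrative, not part of this PR):

    from sky import core

    # CLI equivalent: sky logs --tail 100 my-cluster 5
    # Print the last 100 lines of job 5's log on 'my-cluster', then keep
    # following new output; tail=0 (the default) prints the whole log.
    core.tail_logs('my-cluster', job_id=5, follow=True, tail=100)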
2 changes: 1 addition & 1 deletion sky/skylet/constants.py
@@ -79,7 +79,7 @@
 # The version of the lib files that skylet/jobs use. Whenever there is an API
 # change for the job_lib or log_lib, we need to bump this version, so that the
 # user can be notified to update their SkyPilot version on the remote cluster.
-SKYLET_LIB_VERSION = 1
+SKYLET_LIB_VERSION = 2
 SKYLET_VERSION_FILE = '~/.sky/skylet_version'

 # `sky jobs dashboard`-related
13 changes: 10 additions & 3 deletions sky/skylet/job_lib.py
@@ -29,6 +29,7 @@

 logger = sky_logging.init_logger(__name__)

+_LINUX_NEW_LINE = '\n'
 _JOB_STATUS_LOCK = '~/.sky/locks/.job_{}.lock'


@@ -602,6 +603,7 @@ def update_job_status(job_ids: List[int],
             # the pending table until appearing in ray jobs. For jobs
             # submitted outside of the grace period, we will consider the
             # ray job status.
+
             if not (pending_job['submit'] > 0 and pending_job['submit'] <
                     ray_job_query_time - _PENDING_SUBMIT_GRACE_PERIOD):
                 # Reset the job status to PENDING even though it may not
@@ -903,14 +905,19 @@ def fail_all_jobs_in_progress(cls) -> str:
     def tail_logs(cls,
                   job_id: Optional[int],
                   managed_job_id: Optional[int],
-                  follow: bool = True) -> str:
+                  follow: bool = True,
+                  tail: int = 0) -> str:
         # pylint: disable=line-too-long

         code = [
             # We use != instead of is not because 1 is not None will print a warning:
             # <stdin>:1: SyntaxWarning: "is not" with a literal. Did you mean "!="?
             f'job_id = {job_id} if {job_id} != None else job_lib.get_latest_job_id()',
             'run_timestamp = job_lib.get_run_timestamp(job_id)',
             f'log_dir = None if run_timestamp is None else os.path.join({constants.SKY_LOGS_DIRECTORY!r}, run_timestamp)',
-            f'log_lib.tail_logs(job_id=job_id, log_dir=log_dir, '
-            f'managed_job_id={managed_job_id!r}, follow={follow})',
+            f'tail_log_kwargs = {{"job_id": job_id, "log_dir": log_dir, "managed_job_id": {managed_job_id!r}, "follow": {follow}}}',
+            f'{_LINUX_NEW_LINE}if getattr(constants, "SKYLET_LIB_VERSION", 1) > 1: tail_log_kwargs["tail"] = {tail}',
+            f'{_LINUX_NEW_LINE}log_lib.tail_logs(**tail_log_kwargs)',
         ]
         return cls._build(code)
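For illustration, with job_id=42, tail=100, and no managed job, the code generated for the remote cluster evaluates roughly like the sketch below (values are illustrative; the real string is assembled by cls._build, which chains the statements onto one line, so the if statement needs the _LINUX_NEW_LINE prefix to be valid Python):

    job_id = 42 if 42 != None else job_lib.get_latest_job_id()
    run_timestamp = job_lib.get_run_timestamp(job_id)
    # Assuming constants.SKY_LOGS_DIRECTORY == '~/sky_logs'.
    log_dir = None if run_timestamp is None else os.path.join('~/sky_logs', run_timestamp)
    tail_log_kwargs = {"job_id": job_id, "log_dir": log_dir, "managed_job_id": None, "follow": True}
    # Backward compatibility: only pass `tail` when the remote skylet lib
    # advertises a version that accepts it; an old remote (version 1, or no
    # SKYLET_LIB_VERSION attribute at all) keeps the old signature.
    if getattr(constants, "SKYLET_LIB_VERSION", 1) > 1:
        tail_log_kwargs["tail"] = 100
    log_lib.tail_logs(**tail_log_kwargs)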

85 changes: 77 additions & 8 deletions sky/skylet/log_lib.py
@@ -2,6 +2,7 @@

 This is a remote utility module that provides logging functionality.
 """
 import collections
+import copy
 import io
 import multiprocessing.pool
@@ -12,7 +13,8 @@
 import tempfile
 import textwrap
 import time
-from typing import Dict, Iterator, List, Optional, Tuple, Union
+from typing import (Deque, Dict, Iterable, Iterator, List, Optional, TextIO,
+                    Tuple, Union)

 import colorama

@@ -26,6 +28,9 @@
 _SKY_LOG_WAITING_GAP_SECONDS = 1
 _SKY_LOG_WAITING_MAX_RETRY = 5
 _SKY_LOG_TAILING_GAP_SECONDS = 0.2
+# Peek the head of the lines to check if we need to start
+# streaming when tail > 0.
+PEEK_HEAD_LINES_FOR_START_STREAM = 20

 logger = sky_logging.init_logger(__name__)

@@ -330,6 +335,7 @@ def run_bash_command_with_log(bash_command: str,

 def _follow_job_logs(file,
                      job_id: int,
+                     start_streaming: bool,
                      start_streaming_at: str = '') -> Iterator[str]:
     """Yield each line from a file as they are written.
@@ -338,7 +344,6 @@ def _follow_job_logs(file,
     # No need to lock the status here, as the while loop can handle
     # the older status.
     status = job_lib.get_status_no_lock(job_id)
-    start_streaming = False
     wait_last_logs = True
     while True:
         tmp = file.readline()
@@ -378,10 +383,45 @@
         status = job_lib.get_status_no_lock(job_id)


+def _peek_head_lines(log_file: TextIO) -> List[str]:
+    """Peek the head of the file."""
+    lines = [
+        log_file.readline() for _ in range(PEEK_HEAD_LINES_FOR_START_STREAM)
+    ]
+    # Reset the file pointer to the beginning
+    log_file.seek(0, os.SEEK_SET)
+    return [line for line in lines if line]
+
+
+def _should_stream_the_whole_tail_lines(head_lines_of_log_file: List[str],
+                                        tail_lines: Deque[str],
+                                        start_stream_at: str) -> bool:
+    """Check if the entire tail lines should be streamed."""
+    # See comment:
+    # https://github.com/skypilot-org/skypilot/pull/4241#discussion_r1833611567
+    # for more details.
+    # Case 1: If start_stream_at is found at the head of the tail lines,
+    # we should not stream the whole tail lines.
+    for index, line in enumerate(tail_lines):
+        if index >= PEEK_HEAD_LINES_FOR_START_STREAM:
+            break
+        if start_stream_at in line:
+            return False
+    # Case 2: If start_stream_at is found at the head of log file, but not at
+    # the tail lines, we need to stream the whole tail lines.
+    for line in head_lines_of_log_file:
+        if start_stream_at in line:
+            return True
+    # Case 3: If start_stream_at is not at the head, and not found at the tail
+    # lines, we should not stream the whole tail lines.
+    return False
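A toy sketch of how the two helpers above cooperate (it assumes both helpers and PEEK_HEAD_LINES_FOR_START_STREAM are in scope; the log contents are made up):

    import collections

    start = 'Waiting for task resources on '
    # A log whose start marker has scrolled out of the last-5-lines window:
    log = ['setup\n', start + '1 node\n'] + [f'output {i}\n' for i in range(100)]

    head = log[:PEEK_HEAD_LINES_FOR_START_STREAM]  # what _peek_head_lines returns
    tail_lines = collections.deque(log, maxlen=5)  # the last 5 lines

    # The marker sits in the head but no longer in the tail window, so the
    # whole tail should be streamed (Case 2): prints True.
    print(_should_stream_the_whole_tail_lines(head, tail_lines, start))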


 def tail_logs(job_id: Optional[int],
               log_dir: Optional[str],
               managed_job_id: Optional[int] = None,
-              follow: bool = True) -> None:
+              follow: bool = True,
+              tail: int = 0) -> None:
     """Tail the logs of a job.

     Args:
@@ -390,6 +430,8 @@ def tail_logs(job_id: Optional[int],
         managed_job_id: The managed job id (for logging info only to avoid
             confusion).
         follow: Whether to follow the logs or print the logs so far and exit.
+        tail: The number of lines to display from the end of the log file,
+            if 0, print all lines.
     """
     if job_id is None:
         # This only happens when job_lib.get_latest_job_id() returns None,
@@ -430,6 +472,8 @@ def tail_logs(job_id: Optional[int],
     status = job_lib.update_job_status([job_id], silent=True)[0]

     start_stream_at = 'Waiting for task resources on '
+    # Explicitly declare the type to avoid mypy warning.
+    lines: Iterable[str] = []
     if follow and status in [
             job_lib.JobStatus.SETTING_UP,
             job_lib.JobStatus.PENDING,
@@ -440,18 +484,43 @@
         with open(log_path, 'r', newline='', encoding='utf-8') as log_file:
             # Using `_follow` instead of `tail -f` to streaming the whole
             # log and creating a new process for tail.
+            start_streaming = False
+            if tail > 0:
+                head_lines_of_log_file = _peek_head_lines(log_file)
+                lines = collections.deque(log_file, maxlen=tail)
+                start_streaming = _should_stream_the_whole_tail_lines(
+                    head_lines_of_log_file, lines, start_stream_at)
+                for line in lines:
+                    if start_stream_at in line:
+                        start_streaming = True
+                    if start_streaming:
+                        print(line, end='')
+                # Flush the last n lines
+                print(end='', flush=True)
+                # Now, the cursor is at the end of the last lines
+                # if tail > 0
             for line in _follow_job_logs(log_file,
                                          job_id=job_id,
+                                         start_streaming=start_streaming,
                                          start_streaming_at=start_stream_at):
                 print(line, end='', flush=True)
     else:
         try:
-            start_stream = False
-            with open(log_path, 'r', encoding='utf-8') as f:
-                for line in f.readlines():
+            start_streaming = False
+            with open(log_path, 'r', encoding='utf-8') as log_file:
+                if tail > 0:
+                    # If tail > 0, we need to read the last n lines.
+                    # We use double ended queue to rotate the last n lines.
+                    head_lines_of_log_file = _peek_head_lines(log_file)
+                    lines = collections.deque(log_file, maxlen=tail)
+                    start_streaming = _should_stream_the_whole_tail_lines(
+                        head_lines_of_log_file, lines, start_stream_at)
+                else:
+                    lines = log_file
+                for line in lines:
                     if start_stream_at in line:
-                        start_stream = True
-                    if start_stream:
+                        start_streaming = True
+                    if start_streaming:
                         print(line, end='', flush=True)
         except FileNotFoundError:
             print(f'{colorama.Fore.RED}ERROR: Logs for job {job_id} (status:'
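The memory-friendly core of the change is the deque idiom used in both branches above: iterating an open file through collections.deque(..., maxlen=n) streams the file line by line while retaining only the final n lines, instead of materializing the whole log with readlines(). A standalone sketch (the path is illustrative):

    import collections

    # Keep only the last 10 lines; earlier lines are discarded as the deque
    # rotates, so peak memory stays proportional to maxlen, not file size.
    with open('/var/log/example.log', 'r', encoding='utf-8') as f:
        last_lines = collections.deque(f, maxlen=10)
    for line in last_lines:
        print(line, end='')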
