diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 68d9bfecfb0..0198ca45d71 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -3711,7 +3711,8 @@ def tail_logs(self, handle: CloudVmRayResourceHandle, job_id: Optional[int], managed_job_id: Optional[int] = None, - follow: bool = True) -> int: + follow: bool = True, + tail: int = 0) -> int: """Tail the logs of a job. Args: @@ -3719,10 +3720,13 @@ def tail_logs(self, job_id: The job ID to tail the logs of. managed_job_id: The managed job ID for display purpose only. follow: Whether to follow the logs. + tail: The number of lines to display from the end of the + log file. If 0, print all lines. """ code = job_lib.JobLibCodeGen.tail_logs(job_id, managed_job_id=managed_job_id, - follow=follow) + follow=follow, + tail=tail) if job_id is None and managed_job_id is None: logger.info( 'Job ID not provided. Streaming the logs of the latest job.') diff --git a/sky/cli.py b/sky/cli.py index 29e3b2e51cf..c74b2f3ad5d 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -2022,6 +2022,12 @@ def queue(clusters: List[str], skip_finished: bool, all_users: bool): help=('Follow the logs of a job. ' 'If --no-follow is specified, print the log so far and exit. ' '[default: --follow]')) +@click.option( + '--tail', + default=0, + type=int, + help=('The number of lines to display from the end of the log file. ' + 'Default is 0, which means print all lines.')) @click.argument('cluster', required=True, type=str, @@ -2035,6 +2041,7 @@ def logs( sync_down: bool, status: bool, # pylint: disable=redefined-outer-name follow: bool, + tail: int, ): # NOTE(dev): Keep the docstring consistent between the Python API and CLI. """Tail the log of a job. @@ -2101,7 +2108,7 @@ def logs( click.secho(f'Job {id_str}not found', fg='red') sys.exit(1) - core.tail_logs(cluster, job_id, follow) + core.tail_logs(cluster, job_id, follow, tail) @cli.command() diff --git a/sky/core.py b/sky/core.py index 496b8b8ad5e..4bb12f4a21a 100644 --- a/sky/core.py +++ b/sky/core.py @@ -742,7 +742,8 @@ def cancel( @usage_lib.entrypoint def tail_logs(cluster_name: str, job_id: Optional[int], - follow: bool = True) -> None: + follow: bool = True, + tail: int = 0) -> None: # NOTE(dev): Keep the docstring consistent between the Python API and CLI. """Tail the logs of a job. @@ -775,7 +776,7 @@ def tail_logs(cluster_name: str, f'{colorama.Style.RESET_ALL}') usage_lib.record_cluster_name_for_current_operation(cluster_name) - backend.tail_logs(handle, job_id, follow=follow) + backend.tail_logs(handle, job_id, follow=follow, tail=tail) @usage_lib.entrypoint diff --git a/sky/skylet/constants.py b/sky/skylet/constants.py index a9b8013cad7..91476cf8f6f 100644 --- a/sky/skylet/constants.py +++ b/sky/skylet/constants.py @@ -79,7 +79,7 @@ # The version of the lib files that skylet/jobs use. Whenever there is an API # change for the job_lib or log_lib, we need to bump this version, so that the # user can be notified to update their SkyPilot version on the remote cluster. -SKYLET_LIB_VERSION = 1 +SKYLET_LIB_VERSION = 2 SKYLET_VERSION_FILE = '~/.sky/skylet_version' # `sky jobs dashboard`-related diff --git a/sky/skylet/job_lib.py b/sky/skylet/job_lib.py index fba015618f2..02040ac705a 100644 --- a/sky/skylet/job_lib.py +++ b/sky/skylet/job_lib.py @@ -29,6 +29,7 @@ logger = sky_logging.init_logger(__name__) +_LINUX_NEW_LINE = '\n' _JOB_STATUS_LOCK = '~/.sky/locks/.job_{}.lock' @@ -602,6 +603,7 @@ def update_job_status(job_ids: List[int], # the pending table until appearing in ray jobs. For jobs # submitted outside of the grace period, we will consider the # ray job status. + if not (pending_job['submit'] > 0 and pending_job['submit'] < ray_job_query_time - _PENDING_SUBMIT_GRACE_PERIOD): # Reset the job status to PENDING even though it may not @@ -903,14 +905,19 @@ def fail_all_jobs_in_progress(cls) -> str: def tail_logs(cls, job_id: Optional[int], managed_job_id: Optional[int], - follow: bool = True) -> str: + follow: bool = True, + tail: int = 0) -> str: # pylint: disable=line-too-long + code = [ + # We use != instead of is not because 1 is not None will print a warning: + # :1: SyntaxWarning: "is not" with a literal. Did you mean "!="? f'job_id = {job_id} if {job_id} != None else job_lib.get_latest_job_id()', 'run_timestamp = job_lib.get_run_timestamp(job_id)', f'log_dir = None if run_timestamp is None else os.path.join({constants.SKY_LOGS_DIRECTORY!r}, run_timestamp)', - f'log_lib.tail_logs(job_id=job_id, log_dir=log_dir, ' - f'managed_job_id={managed_job_id!r}, follow={follow})', + f'tail_log_kwargs = {{"job_id": job_id, "log_dir": log_dir, "managed_job_id": {managed_job_id!r}, "follow": {follow}}}', + f'{_LINUX_NEW_LINE}if getattr(constants, "SKYLET_LIB_VERSION", 1) > 1: tail_log_kwargs["tail"] = {tail}', + f'{_LINUX_NEW_LINE}log_lib.tail_logs(**tail_log_kwargs)', ] return cls._build(code) diff --git a/sky/skylet/log_lib.py b/sky/skylet/log_lib.py index eb64440077e..391fa8c4fe5 100644 --- a/sky/skylet/log_lib.py +++ b/sky/skylet/log_lib.py @@ -2,6 +2,7 @@ This is a remote utility module that provides logging functionality. """ +import collections import copy import io import multiprocessing.pool @@ -12,7 +13,8 @@ import tempfile import textwrap import time -from typing import Dict, Iterator, List, Optional, Tuple, Union +from typing import (Deque, Dict, Iterable, Iterator, List, Optional, TextIO, + Tuple, Union) import colorama @@ -26,6 +28,9 @@ _SKY_LOG_WAITING_GAP_SECONDS = 1 _SKY_LOG_WAITING_MAX_RETRY = 5 _SKY_LOG_TAILING_GAP_SECONDS = 0.2 +# Peek the head of the lines to check if we need to start +# streaming when tail > 0. +PEEK_HEAD_LINES_FOR_START_STREAM = 20 logger = sky_logging.init_logger(__name__) @@ -330,6 +335,7 @@ def run_bash_command_with_log(bash_command: str, def _follow_job_logs(file, job_id: int, + start_streaming: bool, start_streaming_at: str = '') -> Iterator[str]: """Yield each line from a file as they are written. @@ -338,7 +344,6 @@ def _follow_job_logs(file, # No need to lock the status here, as the while loop can handle # the older status. status = job_lib.get_status_no_lock(job_id) - start_streaming = False wait_last_logs = True while True: tmp = file.readline() @@ -378,10 +383,45 @@ def _follow_job_logs(file, status = job_lib.get_status_no_lock(job_id) +def _peek_head_lines(log_file: TextIO) -> List[str]: + """Peek the head of the file.""" + lines = [ + log_file.readline() for _ in range(PEEK_HEAD_LINES_FOR_START_STREAM) + ] + # Reset the file pointer to the beginning + log_file.seek(0, os.SEEK_SET) + return [line for line in lines if line] + + +def _should_stream_the_whole_tail_lines(head_lines_of_log_file: List[str], + tail_lines: Deque[str], + start_stream_at: str) -> bool: + """Check if the entire tail lines should be streamed.""" + # See comment: + # https://github.com/skypilot-org/skypilot/pull/4241#discussion_r1833611567 + # for more details. + # Case 1: If start_stream_at is found at the head of the tail lines, + # we should not stream the whole tail lines. + for index, line in enumerate(tail_lines): + if index >= PEEK_HEAD_LINES_FOR_START_STREAM: + break + if start_stream_at in line: + return False + # Case 2: If start_stream_at is found at the head of log file, but not at + # the tail lines, we need to stream the whole tail lines. + for line in head_lines_of_log_file: + if start_stream_at in line: + return True + # Case 3: If start_stream_at is not at the head, and not found at the tail + # lines, we should not stream the whole tail lines. + return False + + def tail_logs(job_id: Optional[int], log_dir: Optional[str], managed_job_id: Optional[int] = None, - follow: bool = True) -> None: + follow: bool = True, + tail: int = 0) -> None: """Tail the logs of a job. Args: @@ -390,6 +430,8 @@ def tail_logs(job_id: Optional[int], managed_job_id: The managed job id (for logging info only to avoid confusion). follow: Whether to follow the logs or print the logs so far and exit. + tail: The number of lines to display from the end of the log file, + if 0, print all lines. """ if job_id is None: # This only happens when job_lib.get_latest_job_id() returns None, @@ -430,6 +472,8 @@ def tail_logs(job_id: Optional[int], status = job_lib.update_job_status([job_id], silent=True)[0] start_stream_at = 'Waiting for task resources on ' + # Explicitly declare the type to avoid mypy warning. + lines: Iterable[str] = [] if follow and status in [ job_lib.JobStatus.SETTING_UP, job_lib.JobStatus.PENDING, @@ -440,18 +484,43 @@ def tail_logs(job_id: Optional[int], with open(log_path, 'r', newline='', encoding='utf-8') as log_file: # Using `_follow` instead of `tail -f` to streaming the whole # log and creating a new process for tail. + start_streaming = False + if tail > 0: + head_lines_of_log_file = _peek_head_lines(log_file) + lines = collections.deque(log_file, maxlen=tail) + start_streaming = _should_stream_the_whole_tail_lines( + head_lines_of_log_file, lines, start_stream_at) + for line in lines: + if start_stream_at in line: + start_streaming = True + if start_streaming: + print(line, end='') + # Flush the last n lines + print(end='', flush=True) + # Now, the cursor is at the end of the last lines + # if tail > 0 for line in _follow_job_logs(log_file, job_id=job_id, + start_streaming=start_streaming, start_streaming_at=start_stream_at): print(line, end='', flush=True) else: try: - start_stream = False - with open(log_path, 'r', encoding='utf-8') as f: - for line in f.readlines(): + start_streaming = False + with open(log_path, 'r', encoding='utf-8') as log_file: + if tail > 0: + # If tail > 0, we need to read the last n lines. + # We use double ended queue to rotate the last n lines. + head_lines_of_log_file = _peek_head_lines(log_file) + lines = collections.deque(log_file, maxlen=tail) + start_streaming = _should_stream_the_whole_tail_lines( + head_lines_of_log_file, lines, start_stream_at) + else: + lines = log_file + for line in lines: if start_stream_at in line: - start_stream = True - if start_stream: + start_streaming = True + if start_streaming: print(line, end='', flush=True) except FileNotFoundError: print(f'{colorama.Fore.RED}ERROR: Logs for job {job_id} (status:'