Misc improvements to display=plain (#1166)
- Handle multiple models and multiple tasks (display task/model to disambiguate if required)
- Pad display so columns always line up
- Throttle updates to once every second
- Condense task panel when there is no body or footer
- Call rich_initialise in PlainDisplay constructor so ansi colors, etc. are disabled globally for rich printing.
- Remove spurious plain print from rich display (since plain now has its own display handler)
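The padding and truncation described above come down to fixed-width format specifiers and left-justified, length-limited names. A rough, hypothetical illustration of the idea (not the exact status format this commit emits):

# Hypothetical sketch: the task name, widths, and numbers are invented.
# Fixed-width specifiers such as ":3d" plus a truncated, padded name keep
# successive plain-text updates lined up in columns.
def status_line(task: str, step: int, total: int, done: int, samples: int) -> str:
    name = task[:9] + "..." if len(task) > 12 else task.ljust(12)
    pct = int(100 * step / total)
    return f"{name} Steps: {step:3d}/{total} {pct:3d}% Samples: {done:3d}/{samples:3d}"

print(status_line("security-evals", 7, 100, 2, 10))
print(status_line("security-evals", 98, 100, 9, 10))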
jjallaire authored Jan 20, 2025
1 parent 56f7b6f commit ac268d1
Showing 6 changed files with 89 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@
 - Log warning when a non-fatal sample error occurs (i.e. errors permitted by the `fail_on_error` option)
 - Inspect View: allow filtering samples by compound expressions including multiple scorers. (thanks @andrei-apollo)
 - Inspect View: improve rendering performance and stability for the viewer when viewing very large eval logs or samples with a large number of steps.
+- Task display: Improved `plain` mode with periodic updates on progress, metrics, etc.
 
 ## v0.3.58 (16 January 2025)
 
10 changes: 7 additions & 3 deletions src/inspect_ai/_display/core/panel.py
@@ -50,9 +50,13 @@ def task_panel(
     table.add_row(subtitle_table)
 
     # main progress and task info
-    table.add_row()
-    table.add_row(body)
-    table.add_row()
+    if body:
+        table.add_row()
+        table.add_row(body)
+
+    # spacing if there is more content
+    if footer or log_location:
+        table.add_row()
 
     # footer if specified
     if footer:
48 changes: 37 additions & 11 deletions src/inspect_ai/_display/plain/display.py
@@ -4,6 +4,10 @@
 
 import rich
 
+from inspect_ai._display.core.rich import rich_initialise
+from inspect_ai._util.text import truncate
+from inspect_ai._util.throttle import throttle
+
 from ...util._concurrency import concurrency_status
 from ..core.config import task_config
 from ..core.display import (
@@ -28,6 +32,7 @@ def __init__(self) -> None:
         self.total_tasks: int = 0
         self.tasks: list[TaskWithResult] = []
         self.parallel = False
+        rich_initialise()
 
     def print(self, message: str) -> None:
         print(message)
@@ -48,6 +53,8 @@ async def task_screen(
         self, tasks: list[TaskSpec], parallel: bool
     ) -> AsyncIterator[TaskScreen]:
         self.total_tasks = len(tasks)
+        self.multiple_task_names = len({task.name for task in tasks}) > 1
+        self.multiple_model_names = len({str(task.model) for task in tasks}) > 1
         self.tasks = []
         self.parallel = parallel
         try:
@@ -58,7 +65,6 @@ async def task_screen(
         finally:
             # Print final results
             if self.tasks:
-                print("\nResults:")
                 self._print_results()
 
     @contextlib.contextmanager
@@ -72,13 +78,16 @@ def task(self, profile: TaskProfile) -> Iterator[TaskDisplay]:
             footer=None,
             log_location=None,
         )
-        print("Running task:")
         rich.print(panel)
 
         # Create and yield task display
         task = TaskWithResult(profile, None)
        self.tasks.append(task)
-        yield PlainTaskDisplay(task)
+        yield PlainTaskDisplay(
+            task,
+            show_task_names=self.multiple_task_names,
+            show_model_names=self.multiple_model_names,
+        )
 
     def _print_results(self) -> None:
         """Print final results using rich panels"""
@@ -100,8 +109,12 @@ def complete(self) -> None:
 
 
 class PlainTaskDisplay(TaskDisplay):
-    def __init__(self, task: TaskWithResult):
+    def __init__(
+        self, task: TaskWithResult, *, show_task_names: bool, show_model_names: bool
+    ):
         self.task = task
+        self.show_task_names = show_task_names
+        self.show_model_names = show_model_names
         self.progress_display: PlainProgress | None = None
         self.samples_complete = 0
         self.samples_total = 0
@@ -113,6 +126,10 @@ def progress(self) -> Iterator[Progress]:
         self.progress_display = PlainProgress(self.task.profile.steps)
         yield self.progress_display
 
+    @throttle(1)
+    def _print_status_throttled(self) -> None:
+        self._print_status()
+
     def _print_status(self) -> None:
         """Print status updates on new lines when there's meaningful progress"""
         if not self.progress_display:
@@ -125,16 +142,25 @@ def _print_status(self) -> None:
 
         # Only print on percentage changes to avoid too much output
         if current_progress != self.last_progress:
-            status_parts = []
+            status_parts: list[str] = []
 
+            # if this is parallel print task and model to distinguish (limit both to 12 chars)
+            MAX_NAME_WIDTH = 12
+            if self.show_task_names:
+                status_parts.append(truncate(self.task.profile.name, MAX_NAME_WIDTH))
+            if self.show_model_names:
+                status_parts.append(
+                    truncate(str(self.task.profile.model), MAX_NAME_WIDTH)
+                )
+
             # Add step progress
             status_parts.append(
-                f"Steps: {self.progress_display.current}/{self.progress_display.total} ({current_progress}%)"
+                f"Steps: {self.progress_display.current:3d}/{self.progress_display.total} {current_progress:3d}%"
             )
 
             # Add sample progress
             status_parts.append(
-                f"Samples: {self.samples_complete}/{self.samples_total}"
+                f"Samples: {self.samples_complete:3d}/{self.samples_total:3d}"
             )
 
             # Add metrics
@@ -147,7 +173,7 @@ def _print_status(self) -> None:
             # the rich formatting added in the ``task_dict`` call
             resources_dict: dict[str, str] = {}
             for model, resource in concurrency_status().items():
-                resources_dict[model] = f"{resource[0]}/{resource[1]}"
+                resources_dict[model] = f"{resource[0]:2d}/{resource[1]:2d}"
             resources = "".join(
                 [f"{key}: {value}" for key, value in resources_dict.items()]
             )
@@ -166,12 +192,12 @@ def _print_status(self) -> None:
     def sample_complete(self, complete: int, total: int) -> None:
         self.samples_complete = complete
         self.samples_total = total
-        self._print_status()
+        self._print_status_throttled()
 
     def update_metrics(self, metrics: list[TaskDisplayMetric]) -> None:
         self.current_metrics = metrics
-        self._print_status()
+        self._print_status_throttled()
 
     def complete(self, result: TaskResult) -> None:
         self.task.result = result
-        print("Task complete.")
+        self._print_status()
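The `@throttle(1)` decorator used above is imported from `inspect_ai._util.throttle`, which is not part of this diff. A minimal sketch of a time-based throttle with the same call shape, assuming it simply skips calls that arrive within the interval (the real implementation may differ):

import functools
import time
from typing import Any, Callable

def throttle(seconds: float) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    # Sketch only: run the wrapped function at most once per `seconds`,
    # returning the last result for calls that arrive too soon.
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        last_called = 0.0
        last_result: Any = None

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            nonlocal last_called, last_result
            now = time.monotonic()
            if now - last_called >= seconds:
                last_called = now
                last_result = func(*args, **kwargs)
            return last_result

        return wrapper

    return decorator

Throttling sample_complete() and update_metrics() keeps plain mode from flooding stdout when samples finish in rapid bursts, while complete() still calls the unthrottled _print_status() so the final state is always printed.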
5 changes: 0 additions & 5 deletions src/inspect_ai/_display/rich/display.py
@@ -129,11 +129,6 @@ async def task_screen(
     @override
     @contextlib.contextmanager
     def task(self, profile: TaskProfile) -> Iterator[TaskDisplay]:
-        # if there is no ansi display than all of the below will
-        # be a no-op, so we print a simple text message for the task
-        if display_type() == "plain":
-            rich.get_console().print(task_no_ansi(profile))
-
         # for typechekcer
         if self.tasks is None:
             self.tasks = []
23 changes: 23 additions & 0 deletions src/inspect_ai/_util/text.py
@@ -108,3 +108,26 @@ def str_to_float(s: str) -> float:
         exponent = 1  # Default exponent is 1 if no superscript is present
 
     return base**exponent
+
+
+def truncate(text: str, length: int, overflow: str = "...", pad: bool = True) -> str:
+    """
+    Truncate text to specified length with optional padding and overflow indicator.
+
+    Args:
+        text (str): Text to truncate
+        length (int): Maximum length including overflow indicator
+        overflow (str): String to indicate truncation (defaults to '...')
+        pad (bool): Whether to pad the result to full length (defaults to padding)
+
+    Returns:
+        Truncated string, padded if requested
+    """
+    if len(text) <= length:
+        return text + (" " * (length - len(text))) if pad else text
+
+    overflow_length = len(overflow)
+    truncated = text[: length - overflow_length] + overflow
+
+    return truncated
21 changes: 21 additions & 0 deletions tests/util/test_truncate.py
@@ -0,0 +1,21 @@
+from inspect_ai._util.text import truncate
+
+
+def test_basic_truncation():
+    assert truncate("Hello World!", 8) == "Hello..."
+    assert truncate("Short", 8) == "Short   "
+
+
+def test_custom_overflow():
+    assert truncate("Hello World!", 8, overflow=">>") == "Hello >>"
+    assert truncate("Testing", 5, overflow="~") == "Test~"
+
+
+def test_no_padding():
+    assert truncate("Hi", 8, pad=False) == "Hi"
+    assert truncate("Hello World!", 8, pad=False) == "Hello..."
+
+
+def test_exact_length():
+    assert truncate("12345678", 8) == "12345678"
+    assert truncate("1234", 4, pad=False) == "1234"

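Beyond the unit tests, the main consumer of `truncate` in this change is the plain display's disambiguating task/model prefix. A small usage sketch (the task and model names below are invented):

from inspect_ai._util.text import truncate

# With multiple tasks or models running in parallel, each status line gets a
# truncated, padded prefix so the columns that follow stay aligned.
parts = [truncate("security-evals", 12), truncate("openai/gpt-4o-mini", 12)]
print(" ".join(parts))  # -> "security-... openai/gp..."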