diff --git a/python/sglang/srt/debug_utils/comparator/display.py b/python/sglang/srt/debug_utils/comparator/display.py index ba4179c0bc6f..32ae5fbc396e 100644 --- a/python/sglang/srt/debug_utils/comparator/display.py +++ b/python/sglang/srt/debug_utils/comparator/display.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Optional import polars as pl +import rich.table if TYPE_CHECKING: from rich.table import Table @@ -68,6 +69,19 @@ def _build_rich_table(df: pl.DataFrame, *, title: Optional[str] = None) -> "Tabl return table +def _render_polars_as_rich_table( + df: pl.DataFrame, *, title: Optional[str] = None +) -> "rich.table.Table": + from rich.table import Table + + table = Table(title=title) + for col in df.columns: + table.add_column(col) + for row in df.iter_rows(): + table.add_row(*[str(v) for v in row]) + return table + + def _collect_rank_info( df: pl.DataFrame, dump_dir: Path ) -> Optional[list[dict[str, Any]]]: diff --git a/python/sglang/srt/debug_utils/comparator/output_formatter.py b/python/sglang/srt/debug_utils/comparator/output_formatter.py index dfa25368fa74..cd93a3d2c726 100644 --- a/python/sglang/srt/debug_utils/comparator/output_formatter.py +++ b/python/sglang/srt/debug_utils/comparator/output_formatter.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal from rich.console import Group from rich.markup import escape @@ -38,12 +38,16 @@ _TableRecord, ) +Verbosity = Literal["minimal", "normal", "verbose"] + # ── Record-level rendering (body + logs) ───────────────────────────── -def _render_record_rich(record: _OutputRecord) -> RenderableType: - body: RenderableType = record._format_rich_body() +def _render_record_rich( + record: _OutputRecord, *, verbosity: Verbosity = "normal" +) -> RenderableType: + body: RenderableType = record._format_rich_body(verbosity=verbosity) log_lines: list[str] = _format_log_lines_rich( errors=record.errors, infos=record.infos @@ 
-100,7 +104,9 @@ def _format_config_body(record: ConfigRecord) -> str: return f"Config: {record.config}" -def _format_config_rich_body(record: ConfigRecord) -> RenderableType: +def _format_config_rich_body( + record: ConfigRecord, verbosity: Verbosity = "normal" +) -> RenderableType: lines: list[str] = [f" [bold]{k}[/] : {v}" for k, v in record.config.items()] return Panel("\n".join(lines), title="Comparator Config", border_style="cyan") @@ -112,7 +118,9 @@ def _format_skip_body(record: SkipComparisonRecord) -> str: return f"Skip: {record.name}{record._format_location_suffix()} ({record.reason})" -def _format_skip_rich_body(record: SkipComparisonRecord) -> RenderableType: +def _format_skip_rich_body( + record: SkipComparisonRecord, verbosity: Verbosity = "normal" +) -> RenderableType: suffix: str = record._format_location_suffix() return ( f"[dim]⊘ {escape(record.name)}{suffix} ── skipped ({escape(record.reason)})[/]" @@ -132,7 +140,9 @@ def _format_table_body(record: _TableRecord) -> str: ) -def _format_table_rich_body(record: _TableRecord) -> RenderableType: +def _format_table_rich_body( + record: _TableRecord, verbosity: Verbosity = "normal" +) -> RenderableType: import polars as pl from sglang.srt.debug_utils.comparator.display import ( @@ -157,13 +167,15 @@ def _format_tensor_comparison_body(record: TensorComparisonRecord) -> str: def _format_tensor_comparison_rich_body( - record: TensorComparisonRecord, + record: TensorComparisonRecord, verbosity: Verbosity = "normal" ) -> RenderableType: from sglang.srt.debug_utils.comparator.tensor_comparator.formatter import ( format_comparison_rich, ) - return record._format_location_prefix_rich() + format_comparison_rich(record=record) + return record._format_location_prefix_rich() + format_comparison_rich( + record=record, verbosity=verbosity + ) # ── NonTensorComparisonRecord ──────────────────────────────────────── @@ -181,7 +193,7 @@ def _format_non_tensor_body(record: NonTensorComparisonRecord) -> str: def 
_format_non_tensor_rich_body( - record: NonTensorComparisonRecord, + record: NonTensorComparisonRecord, verbosity: Verbosity = "normal" ) -> RenderableType: suffix: str = record._format_location_suffix() name: str = escape(record.name) @@ -210,7 +222,9 @@ def _format_summary_body(record: SummaryRecord) -> str: ) -def _format_summary_rich_body(record: SummaryRecord) -> RenderableType: +def _format_summary_rich_body( + record: SummaryRecord, verbosity: Verbosity = "normal" +) -> RenderableType: text: str = ( f"[bold green]{record.passed} passed[/] │ " f"[bold red]{record.failed} failed[/] │ " diff --git a/python/sglang/srt/debug_utils/comparator/output_types.py b/python/sglang/srt/debug_utils/comparator/output_types.py index db3722986c0b..e07281e7e895 100644 --- a/python/sglang/srt/debug_utils/comparator/output_types.py +++ b/python/sglang/srt/debug_utils/comparator/output_types.py @@ -7,6 +7,26 @@ from rich.console import RenderableType from rich.markup import escape +from sglang.srt.debug_utils.comparator.output_formatter import ( # noqa: F401 — re-export + _format_aligner_plan as _format_aligner_plan, +) +from sglang.srt.debug_utils.comparator.output_formatter import ( + _format_config_body, + _format_config_rich_body, + _format_log_body, + _format_non_tensor_body, + _format_non_tensor_rich_body, + _format_skip_body, + _format_skip_rich_body, + _format_summary_body, + _format_summary_rich_body, + _format_table_body, + _format_table_rich_body, + _format_tensor_comparison_body, + _format_tensor_comparison_rich_body, + _render_record_rich, + _render_record_text, +) from sglang.srt.debug_utils.comparator.tensor_comparator.types import ( DiffInfo, TensorComparisonInfo, @@ -17,7 +37,6 @@ from sglang.srt.debug_utils.comparator.aligner.entrypoint.traced_types import ( TracedAlignerPlan, ) - from sglang.srt.debug_utils.comparator.aligner.entrypoint.types import AlignerPlan from sglang.srt.debug_utils.comparator.report_sink import Verbosity @@ -83,21 +102,13 @@ class 
_OutputRecord(_StrictBase): @abstractmethod def _format_body(self) -> str: ... - def _format_rich_body(self) -> RenderableType: + def _format_rich_body(self, verbosity: Verbosity = "normal") -> RenderableType: return self._format_body() - def to_rich(self) -> RenderableType: - from sglang.srt.debug_utils.comparator.output_formatter import ( - _render_record_rich, - ) - - return _render_record_rich(self) + def to_rich(self, verbosity: Verbosity = "normal") -> RenderableType: + return _render_record_rich(self, verbosity=verbosity) def to_text(self) -> str: - from sglang.srt.debug_utils.comparator.output_formatter import ( - _render_record_text, - ) - return _render_record_text(self) @@ -124,46 +135,15 @@ def _format_location_suffix(self) -> str: return "" -class RecordLocation(_StrictBase): - step: Optional[int] = None - - -class _BaseComparisonRecord(_OutputRecord): - location: RecordLocation = Field(default_factory=RecordLocation) - - def _format_location_prefix(self) -> str: - if self.location.step is not None: - return f"[step={self.location.step}] " - return "" - - def _format_location_prefix_rich(self) -> str: - if self.location.step is not None: - return escape(f"[step={self.location.step}]") + " " - return "" - - def _format_location_suffix(self) -> str: - if self.location.step is not None: - return f" (step={self.location.step})" - return "" - - class ConfigRecord(_OutputRecord): type: Literal["config"] = "config" config: dict[str, Any] def _format_body(self) -> str: - from sglang.srt.debug_utils.comparator.output_formatter import ( - _format_config_body, - ) - return _format_config_body(self) - def _format_rich_body(self) -> RenderableType: - from sglang.srt.debug_utils.comparator.output_formatter import ( - _format_config_rich_body, - ) - - return _format_config_rich_body(self) + def _format_rich_body(self, verbosity: Verbosity = "normal") -> RenderableType: + return _format_config_rich_body(self, verbosity=verbosity) class 
SkipComparisonRecord(_BaseComparisonRecord): @@ -178,18 +158,10 @@ def category(self) -> str: return "skipped" def _format_body(self) -> str: - from sglang.srt.debug_utils.comparator.output_formatter import ( - _format_skip_body, - ) - return _format_skip_body(self) - def _format_rich_body(self) -> RenderableType: - from sglang.srt.debug_utils.comparator.output_formatter import ( - _format_skip_rich_body, - ) - - return _format_skip_rich_body(self) + def _format_rich_body(self, verbosity: Verbosity = "normal") -> RenderableType: + return _format_skip_rich_body(self, verbosity=verbosity) class _TableRecord(_OutputRecord): @@ -200,9 +172,10 @@ class _TableRecord(_OutputRecord): def _table_title(self) -> str: ... def _format_body(self) -> str: - from sglang.srt.debug_utils.comparator.output_formatter import ( - _format_table_body, - ) + return _format_table_body(self) + + def _format_rich_body(self, verbosity: Verbosity = "normal") -> RenderableType: + return _format_table_rich_body(self, verbosity=verbosity) return _format_table_body(self) @@ -245,18 +218,10 @@ def category(self) -> str: return "passed" if self.diff is not None and self.diff.passed else "failed" def _format_body(self) -> str: - from sglang.srt.debug_utils.comparator.output_formatter import ( - _format_tensor_comparison_body, - ) - return _format_tensor_comparison_body(self) - def _format_rich_body(self) -> RenderableType: - from sglang.srt.debug_utils.comparator.output_formatter import ( - _format_tensor_comparison_rich_body, - ) - - return _format_tensor_comparison_rich_body(self) + def _format_rich_body(self, verbosity: Verbosity = "normal") -> RenderableType: + return _format_tensor_comparison_rich_body(self, verbosity=verbosity) class NonTensorComparisonRecord(_BaseComparisonRecord): @@ -275,9 +240,10 @@ def category(self) -> str: return "passed" if self.values_equal else "failed" def _format_body(self) -> str: - from sglang.srt.debug_utils.comparator.output_formatter import ( - 
_format_non_tensor_body, - ) + return _format_non_tensor_body(self) + + def _format_rich_body(self, verbosity: Verbosity = "normal") -> RenderableType: + return _format_non_tensor_rich_body(self, verbosity=verbosity) return _format_non_tensor_body(self) @@ -306,40 +272,20 @@ def _validate_totals(self) -> "SummaryRecord": return self def _format_body(self) -> str: - from sglang.srt.debug_utils.comparator.output_formatter import ( - _format_summary_body, - ) - return _format_summary_body(self) - def _format_rich_body(self) -> RenderableType: - from sglang.srt.debug_utils.comparator.output_formatter import ( - _format_summary_rich_body, - ) - - return _format_summary_rich_body(self) + def _format_rich_body(self, verbosity: Verbosity = "normal") -> RenderableType: + return _format_summary_rich_body(self, verbosity=verbosity) + return _format_summary_body(self) class LogRecord(_OutputRecord): type: Literal["log"] = "log" def _format_body(self) -> str: - from sglang.srt.debug_utils.comparator.output_formatter import ( - _format_log_body, - ) - return _format_log_body(self) -# Re-export _format_aligner_plan for backward compatibility (used by tests) -def _format_aligner_plan(traced_plan: TracedAlignerPlan) -> str: - from sglang.srt.debug_utils.comparator.output_formatter import ( - _format_aligner_plan as _impl, - ) - - return _impl(traced_plan) - - AnyRecord = Annotated[ Union[ ConfigRecord, diff --git a/python/sglang/srt/debug_utils/comparator/report_sink.py b/python/sglang/srt/debug_utils/comparator/report_sink.py index cab8f5cd81e1..61f9e9ac55ce 100644 --- a/python/sglang/srt/debug_utils/comparator/report_sink.py +++ b/python/sglang/srt/debug_utils/comparator/report_sink.py @@ -80,7 +80,7 @@ def _print_to_stdout(self, record: _OutputRecord) -> None: print(record.model_dump_json()) else: console: Console = self._get_console() - console.print(record.to_rich()) + console.print(record.to_rich(verbosity=self._verbosity)) console.print() # blank line between records diff 
--git a/python/sglang/srt/debug_utils/comparator/tensor_comparator/formatter.py b/python/sglang/srt/debug_utils/comparator/tensor_comparator/formatter.py index ece8fd21416a..47dd09efbb20 100644 --- a/python/sglang/srt/debug_utils/comparator/tensor_comparator/formatter.py +++ b/python/sglang/srt/debug_utils/comparator/tensor_comparator/formatter.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Literal, Optional from rich.markup import escape @@ -26,6 +26,8 @@ ) from sglang.srt.debug_utils.comparator.utils import Pair +Verbosity = Literal["minimal", "normal", "verbose"] + def _esc_shape(shape: Optional[list[int]]) -> str: return escape(str(shape)) @@ -64,7 +66,7 @@ def _category_marker(category: str) -> tuple[bool, str, str]: # --------------------------------------------------------------------------- -# Stats formatting helpers +# Stats formatting helpers (shared between compact / verbose) # --------------------------------------------------------------------------- @@ -204,11 +206,36 @@ def _format_diff(diff: DiffInfo, prefix_text: str = "") -> list[str]: # --------------------------------------------------------------------------- -def format_comparison_rich(record: TensorComparisonRecord) -> str: - return _format_comparison_normal(record=record) +def format_comparison_rich( + record: TensorComparisonRecord, + verbosity: Verbosity = "normal", +) -> str: + if verbosity == "minimal": + return _format_comparison_minimal(record) + + return _format_comparison_normal_or_verbose( + record=record, + verbose=(verbosity == "verbose"), + ) -def _format_comparison_normal(*, record: TensorComparisonRecord) -> str: +def _format_comparison_minimal(record: TensorComparisonRecord) -> str: + passed, color, marker = _category_marker(record.category) + + name_part: str = f"[bold {color}]{escape(record.name):30s}[/]" + if record.diff is not None: + return f"{marker} {name_part} 
rel_diff={_fmt_val(record.diff.rel_diff)}" + elif record.shape_mismatch: + return f"{marker} {name_part} [yellow]shape mismatch[/]" + else: + return f"{marker} {name_part}" + + +def _format_comparison_normal_or_verbose( + *, + record: TensorComparisonRecord, + verbose: bool, +) -> str: passed, color, marker = _category_marker(record.category) baseline: TensorInfo = record.baseline @@ -254,12 +281,19 @@ def _format_comparison_normal(*, record: TensorComparisonRecord) -> str: # Bundle section if record.raw_bundle_info is not None: lines.append(" [dim]Bundle[/]") - lines.extend(_format_bundle_section(bundle_info=record.raw_bundle_info)) + lines.extend( + _format_bundle_section(bundle_info=record.raw_bundle_info, verbose=verbose) + ) # Plan section if record.traced_plan is not None: lines.append(" [dim]Plan[/]") - lines.extend(_format_plan_section_rich(traced_plan=record.traced_plan)) + lines.extend( + _format_plan_section_rich( + traced_plan=record.traced_plan, + verbose=verbose, + ) + ) # Aligned section lines.append(" [dim]Aligned[/]") @@ -270,22 +304,28 @@ def _format_comparison_normal(*, record: TensorComparisonRecord) -> str: # Stats section lines.append(" [dim]Stats[/]") - lines.extend(_format_stats_rich(baseline=baseline.stats, target=target.stats)) + lines.extend( + _format_stats_rich( + baseline=baseline.stats, target=target.stats, verbose=verbose + ) + ) - # Abs diff percentiles (show when failed) - if not passed and record.diff is not None and record.diff.abs_diff_percentiles: + show_detail: bool = verbose or not passed + + # Abs diff percentiles + if show_detail and record.diff is not None and record.diff.abs_diff_percentiles: lines.append(" [dim]Abs Diff Percentiles[/]") lines.append(" " + _format_abs_diff_percentiles_rich(record.diff)) - # Samples (show when failed) - if not passed and baseline.sample is not None: + # Samples + if show_detail and baseline.sample is not None: lines.append(" [dim]Samples[/]") lines.append(f" baseline 
{escape(baseline.sample)}") if target.sample is not None: lines.append(f" target {escape(target.sample)}") - # Replicated checks (show when failed) - if not passed and record.replicated_checks: + # Replicated checks + if show_detail and record.replicated_checks: lines.append(" [dim]Replicated Checks[/]") for check in record.replicated_checks: chk_marker: str = "[green]✅[/]" if check.passed else "[red]❌[/]" @@ -305,7 +345,9 @@ def _format_comparison_normal(*, record: TensorComparisonRecord) -> str: return "\n".join(lines) -def _format_bundle_section(bundle_info: Pair[BundleSideInfo]) -> list[str]: +def _format_bundle_section( + bundle_info: Pair[BundleSideInfo], *, verbose: bool = False +) -> list[str]: lines: list[str] = [] for label, side in [("baseline", bundle_info.x), ("target", bundle_info.y)]: @@ -315,19 +357,37 @@ def _format_bundle_section(bundle_info: Pair[BundleSideInfo]) -> list[str]: dtype_desc: str = _strip_torch_prefix(side.files[0].dtype) - shapes: list[list[int]] = [f.shape for f in side.files] - unique_shapes: set[str] = {str(s) for s in shapes} - shape_desc: str - if len(unique_shapes) == 1: - shape_desc = _esc_shape(shapes[0]) + if verbose: + dims_part: str = f" dims: {side.dims}" if side.dims else "" + lines.append( + f" {label} [cyan]{side.num_files} files[/]" + f" {dtype_desc}{dims_part}" + ) + + for idx, f in enumerate(side.files): + rank_part: str = f"rank={f.rank}" if f.rank is not None else "" + par_part: str = "" + if f.parallel_info: + par_part = " " + " ".join( + f"{k}={v}" for k, v in f.parallel_info.items() + ) + lines.append( + f" [{idx}] {_esc_shape(f.shape)} {rank_part}{par_part}" + ) else: - shape_desc = "mixed shapes" + shapes: list[list[int]] = [f.shape for f in side.files] + unique_shapes: set[str] = {str(s) for s in shapes} + shape_desc: str + if len(unique_shapes) == 1: + shape_desc = _esc_shape(shapes[0]) + else: + shape_desc = "mixed shapes" - dims_part: str = f" [dim]dims: {side.dims}[/]" if side.dims else "" - 
lines.append( - f" {label} [cyan]{side.num_files} files[/]" - f" × {shape_desc} {dtype_desc}{dims_part}" - ) + dims_part = f" [dim]dims: {side.dims}[/]" if side.dims else "" + lines.append( + f" {label} [cyan]{side.num_files} files[/]" + f" × {shape_desc} {dtype_desc}{dims_part}" + ) return lines @@ -335,6 +395,7 @@ def _format_bundle_section(bundle_info: Pair[BundleSideInfo]) -> list[str]: def _format_plan_section_rich( *, traced_plan: TracedAlignerPlan, + verbose: bool = False, ) -> list[str]: lines: list[str] = [] @@ -406,19 +467,35 @@ def _format_stats_rich( *, baseline: TensorStats, target: TensorStats, + verbose: bool = False, ) -> list[str]: lines: list[str] = [] - # Compact: mean, std, range (min/max combined) - for stat_name in ("mean", "std"): - val_b: float = getattr(baseline, stat_name) - val_t: float = getattr(target, stat_name) - lines.append(_format_stat_line(stat_name, val_b, val_t, val_t - val_b)) - - # Range line: combine min/max (escape brackets to avoid Rich markup) - range_baseline: str = escape(f"[{baseline.min:.4f}, {baseline.max:.4f}]") - range_target: str = escape(f"[{target.min:.4f}, {target.max:.4f}]") - lines.append(f" [blue]{'range':10s}[/] {range_baseline} vs {range_target}") + if verbose: + # All stat fields + for stat_name in TensorStats.model_fields: + if stat_name == "percentiles": + continue + val_b: float = getattr(baseline, stat_name) + val_t: float = getattr(target, stat_name) + lines.append(_format_stat_line(stat_name, val_b, val_t, val_t - val_b)) + + # Percentiles + for p in sorted(set(baseline.percentiles) & set(target.percentiles)): + val_b = baseline.percentiles[p] + val_t = target.percentiles[p] + lines.append(_format_stat_line(f"p{p}", val_b, val_t, val_t - val_b)) + else: + # Compact: mean, std, range (min/max combined) + for stat_name in ("mean", "std"): + val_b = getattr(baseline, stat_name) + val_t = getattr(target, stat_name) + lines.append(_format_stat_line(stat_name, val_b, val_t, val_t - val_b)) + + # Range 
line: combine min/max (escape brackets to avoid Rich markup) + range_baseline: str = escape(f"[{baseline.min:.4f}, {baseline.max:.4f}]") + range_target: str = escape(f"[{target.min:.4f}, {target.max:.4f}]") + lines.append(f" [blue]{'range':10s}[/] {range_baseline} vs {range_target}") return lines diff --git a/test/registered/debug_utils/comparator/tensor_comparator/test_formatter.py b/test/registered/debug_utils/comparator/tensor_comparator/test_formatter.py index a9f8a87afe75..24329bf1c987 100644 --- a/test/registered/debug_utils/comparator/tensor_comparator/test_formatter.py +++ b/test/registered/debug_utils/comparator/tensor_comparator/test_formatter.py @@ -409,18 +409,61 @@ def _make_traced_plan( # --------------------------------------------------------------------------- -# Rich format snapshot tests (normal mode only) +# Rich format snapshot tests # --------------------------------------------------------------------------- +class TestFormatComparisonRichMinimal: + """format_comparison_rich() with verbosity='minimal'.""" + + def test_passed(self) -> None: + record: TensorComparisonRecord = _make_comparison_record( + diff=_make_diff(rel_diff=1e-4, passed=True), + ) + result: str = format_comparison_rich(record, verbosity="minimal") + + assert result == ( + "[green]✅[/] [bold green]hidden_states [/] " + "rel_diff=1.00e-04" + ) + + def test_failed(self) -> None: + record: TensorComparisonRecord = _make_comparison_record( + diff=_make_diff(rel_diff=0.5, passed=False), + ) + result: str = format_comparison_rich(record, verbosity="minimal") + + assert result == ( + "[red]❌[/] [bold red]hidden_states [/] " + "rel_diff=5.00e-01" + ) + + def test_shape_mismatch(self) -> None: + record: TensorComparisonRecord = _make_comparison_record( + shape_mismatch=True, + ) + result: str = format_comparison_rich(record, verbosity="minimal") + + assert result == ( + "[red]❌[/] [bold red]hidden_states [/] " + "[yellow]shape mismatch[/]" + ) + + def test_no_diff(self) -> None: + 
record: TensorComparisonRecord = _make_comparison_record() + result: str = format_comparison_rich(record, verbosity="minimal") + + assert result == ("[red]❌[/] [bold red]hidden_states [/]") + + class TestFormatComparisonRichNormal: - """format_comparison_rich() snapshot tests.""" + """format_comparison_rich() with verbosity='normal'.""" def test_passed(self) -> None: record: TensorComparisonRecord = _make_comparison_record( diff=_make_diff(rel_diff=1e-4, passed=True), ) - result: str = format_comparison_rich(record) + result: str = format_comparison_rich(record, verbosity="normal") assert result == ( "[green]✅[/] [bold green]hidden_states[/] [dim cyan]── float32 [4, 8][/]\n" @@ -439,7 +482,7 @@ def test_failed(self) -> None: rel_diff=0.5, max_abs_diff=1.0, mean_abs_diff=0.3, passed=False ), ) - result: str = format_comparison_rich(record) + result: str = format_comparison_rich(record, verbosity="normal") assert result == ( "[red]❌[/] [bold red]hidden_states[/] [dim cyan]── float32 [4, 8][/]\n" @@ -459,7 +502,7 @@ def test_shape_mismatch(self) -> None: record: TensorComparisonRecord = _make_comparison_record( shape_mismatch=True, ) - result: str = format_comparison_rich(record) + result: str = format_comparison_rich(record, verbosity="normal") assert result == ( "[red]❌[/] [bold red]hidden_states[/] [dim cyan]── float32 [4, 8][/]\n" @@ -478,7 +521,7 @@ def test_with_downcast(self) -> None: diff_downcast=_make_diff(rel_diff=1e-5, passed=True), downcast_dtype="torch.bfloat16", ) - result: str = format_comparison_rich(record) + result: str = format_comparison_rich(record, verbosity="normal") assert result == ( "[red]❌[/] [bold red]hidden_states[/] [dim cyan]── float32 [4, 8][/]\n" @@ -504,7 +547,7 @@ def test_with_bundle_info(self) -> None: diff=_make_diff(passed=True), raw_bundle_info=bundle_info, ) - result: str = format_comparison_rich(record) + result: str = format_comparison_rich(record, verbosity="normal") assert result == ( "[green]✅[/] [bold 
green]hidden_states[/] [dim cyan]── float32 [4, 8][/]\n" @@ -526,7 +569,7 @@ def test_with_plan(self) -> None: diff=_make_diff(passed=True), traced_plan=_make_traced_plan(plan), ) - result: str = format_comparison_rich(record) + result: str = format_comparison_rich(record, verbosity="normal") assert result == ( "[green]✅[/] [bold green]hidden_states[/] [dim cyan]── float32 [4, 8][/]\n" @@ -543,6 +586,113 @@ def test_with_plan(self) -> None: ) +class TestFormatComparisonRichVerbose: + """format_comparison_rich() with verbosity='verbose'.""" + + def test_passed_full_detail(self) -> None: + record: TensorComparisonRecord = _make_comparison_record( + diff=_make_diff(rel_diff=1e-4, passed=True), + sample="tensor([0.1, 0.2, ...])", + ) + result: str = format_comparison_rich(record, verbosity="verbose") + + assert result == ( + "[green]✅[/] [bold green]hidden_states[/] [dim cyan]── float32 [4, 8][/]\n" + " [green]rel_diff=1.00e-04[/] max_abs=5.00e-04 mean_abs=2.00e-04\n" + " [dim]Aligned[/]\n" + " [4, 8] vs [4, 8] torch.float32 vs torch.float32\n" + " [dim]Stats[/]\n" + " [blue]mean [/] 0.0000 vs 0.0000 Δ [dim]+0.00e+00[/]\n" + " [blue]abs_mean [/] 0.8000 vs 0.8000 Δ [dim]+0.00e+00[/]\n" + " [blue]std [/] 1.0000 vs 1.0000 Δ [dim]+0.00e+00[/]\n" + " [blue]min [/] -2.0000 vs -2.0000 Δ [dim]+0.00e+00[/]\n" + " [blue]max [/] 2.0000 vs 2.0000 Δ [dim]+0.00e+00[/]\n" + " [blue]p1 [/] -1.8000 vs -1.8000 Δ [dim]+0.00e+00[/]\n" + " [blue]p5 [/] -1.5000 vs -1.5000 Δ [dim]+0.00e+00[/]\n" + " [blue]p50 [/] 0.0000 vs 0.0000 Δ [dim]+0.00e+00[/]\n" + " [blue]p95 [/] 1.5000 vs 1.5000 Δ [dim]+0.00e+00[/]\n" + " [blue]p99 [/] 1.8000 vs 1.8000 Δ [dim]+0.00e+00[/]\n" + " [dim]Abs Diff Percentiles[/]\n" + " p1=1.00e-04 p5=1.00e-04 p50=2.00e-04 p95=4.00e-04 p99=5.00e-04\n" + " [dim]Samples[/]\n" + " baseline tensor([0.1, 0.2, ...])\n" + " target tensor([0.1, 0.2, ...])" + ) + + def test_with_bundle_verbose(self) -> None: + bundle_info: Pair[BundleSideInfo] = Pair( + 
x=_make_bundle_side_info(num_files=2, with_parallel_info=True), + y=_make_bundle_side_info(num_files=2, with_parallel_info=True), + ) + record: TensorComparisonRecord = _make_comparison_record( + diff=_make_diff(passed=True), + raw_bundle_info=bundle_info, + ) + result: str = format_comparison_rich(record, verbosity="verbose") + + assert result == ( + "[green]✅[/] [bold green]hidden_states[/] [dim cyan]── float32 [4, 8][/]\n" + " [green]rel_diff=1.00e-04[/] max_abs=5.00e-04 mean_abs=2.00e-04\n" + " [dim]Bundle[/]\n" + " baseline [cyan]2 files[/] float32\n" + " [0] [2, 4096] rank=0 tp=0/2\n" + " [1] [2, 4096] rank=1 tp=1/2\n" + " target [cyan]2 files[/] float32\n" + " [0] [2, 4096] rank=0 tp=0/2\n" + " [1] [2, 4096] rank=1 tp=1/2\n" + " [dim]Aligned[/]\n" + " [4, 8] vs [4, 8] torch.float32 vs torch.float32\n" + " [dim]Stats[/]\n" + " [blue]mean [/] 0.0000 vs 0.0000 Δ [dim]+0.00e+00[/]\n" + " [blue]abs_mean [/] 0.8000 vs 0.8000 Δ [dim]+0.00e+00[/]\n" + " [blue]std [/] 1.0000 vs 1.0000 Δ [dim]+0.00e+00[/]\n" + " [blue]min [/] -2.0000 vs -2.0000 Δ [dim]+0.00e+00[/]\n" + " [blue]max [/] 2.0000 vs 2.0000 Δ [dim]+0.00e+00[/]\n" + " [blue]p1 [/] -1.8000 vs -1.8000 Δ [dim]+0.00e+00[/]\n" + " [blue]p5 [/] -1.5000 vs -1.5000 Δ [dim]+0.00e+00[/]\n" + " [blue]p50 [/] 0.0000 vs 0.0000 Δ [dim]+0.00e+00[/]\n" + " [blue]p95 [/] 1.5000 vs 1.5000 Δ [dim]+0.00e+00[/]\n" + " [blue]p99 [/] 1.8000 vs 1.8000 Δ [dim]+0.00e+00[/]\n" + " [dim]Abs Diff Percentiles[/]\n" + " p1=1.00e-04 p5=1.00e-04 p50=2.00e-04 p95=4.00e-04 p99=5.00e-04" + ) + + def test_with_plan_and_traces(self) -> None: + plan: AlignerPlan = _make_simple_aligner_plan(with_unsharder=True) + record: TensorComparisonRecord = _make_comparison_record( + diff=_make_diff(passed=True), + traced_plan=_make_traced_plan( + plan, + target_input_shapes=[[2, 4096], [2, 4096]], + target_output_shapes=[[4, 4096]], + ), + ) + result: str = format_comparison_rich(record, verbosity="verbose") + + assert result == ( + "[green]✅[/] [bold 
green]hidden_states[/] [dim cyan]── float32 [4, 8][/]\n" + " [green]rel_diff=1.00e-04[/] max_abs=5.00e-04 mean_abs=2.00e-04\n" + " [dim]Plan[/]\n" + " baseline [dim](passthrough)[/]\n" + " target [magenta]unsharder(ParallelAxis.TP)[/] 2×[2, 4096] → 1×[4, 4096]\n" + " [dim]Aligned[/]\n" + " [4, 8] vs [4, 8] torch.float32 vs torch.float32\n" + " [dim]Stats[/]\n" + " [blue]mean [/] 0.0000 vs 0.0000 Δ [dim]+0.00e+00[/]\n" + " [blue]abs_mean [/] 0.8000 vs 0.8000 Δ [dim]+0.00e+00[/]\n" + " [blue]std [/] 1.0000 vs 1.0000 Δ [dim]+0.00e+00[/]\n" + " [blue]min [/] -2.0000 vs -2.0000 Δ [dim]+0.00e+00[/]\n" + " [blue]max [/] 2.0000 vs 2.0000 Δ [dim]+0.00e+00[/]\n" + " [blue]p1 [/] -1.8000 vs -1.8000 Δ [dim]+0.00e+00[/]\n" + " [blue]p5 [/] -1.5000 vs -1.5000 Δ [dim]+0.00e+00[/]\n" + " [blue]p50 [/] 0.0000 vs 0.0000 Δ [dim]+0.00e+00[/]\n" + " [blue]p95 [/] 1.5000 vs 1.5000 Δ [dim]+0.00e+00[/]\n" + " [blue]p99 [/] 1.8000 vs 1.8000 Δ [dim]+0.00e+00[/]\n" + " [dim]Abs Diff Percentiles[/]\n" + " p1=1.00e-04 p5=1.00e-04 p50=2.00e-04 p95=4.00e-04 p99=5.00e-04" + ) + + class TestFormatBundleSection: """_format_bundle_section() snapshot tests.""" @@ -597,6 +747,36 @@ def test_with_dims(self) -> None: ] +class TestFormatBundleSectionVerbose: + """_format_bundle_section(verbose=True) snapshot tests.""" + + def test_per_file_listing(self) -> None: + bundle: Pair[BundleSideInfo] = Pair( + x=_make_bundle_side_info(num_files=2, with_parallel_info=True), + y=_make_bundle_side_info(num_files=2, with_parallel_info=True), + ) + lines: list[str] = _format_bundle_section(bundle, verbose=True) + + assert lines == [ + " baseline [cyan]2 files[/] float32", + " [0] [2, 4096] rank=0 tp=0/2", + " [1] [2, 4096] rank=1 tp=1/2", + " target [cyan]2 files[/] float32", + " [0] [2, 4096] rank=0 tp=0/2", + " [1] [2, 4096] rank=1 tp=1/2", + ] + + def test_no_files(self) -> None: + empty: BundleSideInfo = BundleSideInfo(num_files=0, files=[]) + bundle: Pair[BundleSideInfo] = Pair(x=empty, y=empty) + lines: 
list[str] = _format_bundle_section(bundle, verbose=True) + + assert lines == [ + " baseline [dim](no files)[/]", + " target [dim](no files)[/]", + ] + + class TestFormatPlanSectionRich: """_format_plan_section_rich() snapshot tests.""" @@ -719,6 +899,45 @@ def test_small_delta(self) -> None: ] +class TestFormatStatsRichVerbose: + """_format_stats_rich(verbose=True) snapshot tests.""" + + def test_all_stats_with_percentiles(self) -> None: + baseline: TensorStats = _make_stats() + target: TensorStats = _make_stats() + lines: list[str] = _format_stats_rich( + baseline=baseline, target=target, verbose=True + ) + + assert lines == [ + " [blue]mean [/] 0.0000 vs 0.0000 Δ [dim]+0.00e+00[/]", + " [blue]abs_mean [/] 0.8000 vs 0.8000 Δ [dim]+0.00e+00[/]", + " [blue]std [/] 1.0000 vs 1.0000 Δ [dim]+0.00e+00[/]", + " [blue]min [/] -2.0000 vs -2.0000 Δ [dim]+0.00e+00[/]", + " [blue]max [/] 2.0000 vs 2.0000 Δ [dim]+0.00e+00[/]", + " [blue]p1 [/] -1.8000 vs -1.8000 Δ [dim]+0.00e+00[/]", + " [blue]p5 [/] -1.5000 vs -1.5000 Δ [dim]+0.00e+00[/]", + " [blue]p50 [/] 0.0000 vs 0.0000 Δ [dim]+0.00e+00[/]", + " [blue]p95 [/] 1.5000 vs 1.5000 Δ [dim]+0.00e+00[/]", + " [blue]p99 [/] 1.8000 vs 1.8000 Δ [dim]+0.00e+00[/]", + ] + + def test_no_percentiles(self) -> None: + baseline: TensorStats = _make_stats(percentiles={}) + target: TensorStats = _make_stats(percentiles={}) + lines: list[str] = _format_stats_rich( + baseline=baseline, target=target, verbose=True + ) + + assert lines == [ + " [blue]mean [/] 0.0000 vs 0.0000 Δ [dim]+0.00e+00[/]", + " [blue]abs_mean [/] 0.8000 vs 0.8000 Δ [dim]+0.00e+00[/]", + " [blue]std [/] 1.0000 vs 1.0000 Δ [dim]+0.00e+00[/]", + " [blue]min [/] -2.0000 vs -2.0000 Δ [dim]+0.00e+00[/]", + " [blue]max [/] 2.0000 vs 2.0000 Δ [dim]+0.00e+00[/]", + ] + + class TestFormatAbsDiffPercentilesRich: """_format_abs_diff_percentiles_rich() snapshot tests.""" diff --git a/test/registered/debug_utils/comparator/test_entrypoint.py 
b/test/registered/debug_utils/comparator/test_entrypoint.py index 964727d5e239..41e5d5ae849a 100644 --- a/test/registered/debug_utils/comparator/test_entrypoint.py +++ b/test/registered/debug_utils/comparator/test_entrypoint.py @@ -294,8 +294,8 @@ def test_text_output_smoke(self, tmp_path, capsys): run(parse_args(argv)) output = capsys.readouterr().out - assert "Config:" in output - assert "Summary:" in output + assert "Comparator Config" in output + assert "SUMMARY" in output def test_text_output_with_failure(self, tmp_path, capsys): """Text output with a failed comparison renders failure info.""" @@ -317,7 +317,7 @@ def test_text_output_with_failure(self, tmp_path, capsys): run(parse_args(argv)) output = capsys.readouterr().out - assert "Summary:" in output + assert "SUMMARY" in output assert "failed" in output.lower() def test_duplicate_dump_pairing(self, tmp_path, capsys): @@ -4267,102 +4267,6 @@ def test_streaming_flush(self, tmp_path, capsys): assert isinstance(parsed, ConfigRecord) -class TestEntrypointAutoDescend: - """Test auto-descend: --baseline-path / --target-path pointing to a parent - directory that contains a single subdirectory with .pt files.""" - - def test_auto_descend_single_engine(self, tmp_path: Path, capsys) -> None: - """Parent dir wrapping a single engine subdir is auto-descended and comparison succeeds.""" - baseline_exp, target_exp = _create_dumps(tmp_path, ["tensor_a"]) - - baseline_wrapper: Path = tmp_path / "baseline_wrap" - target_wrapper: Path = tmp_path / "target_wrap" - baseline_wrapper.mkdir() - target_wrapper.mkdir() - baseline_exp.rename(baseline_wrapper / "engine_0") - target_exp.rename(target_wrapper / "engine_0") - - argv = _make_argv(baseline_wrapper, target_wrapper, preset="raw") - records, exit_code = _run_and_parse(argv, capsys) - - assert exit_code == 0 - _assert_single_comparison_passed(records) - - def test_no_descend_when_pt_at_root(self, tmp_path: Path, capsys) -> None: - """Direct .pt files — no descend needed, 
comparison still works.""" - baseline_exp, target_exp = _create_dumps(tmp_path, ["tensor_a"]) - - argv = _make_argv(baseline_exp, target_exp, preset="raw") - records, exit_code = _run_and_parse(argv, capsys) - - assert exit_code == 0 - _assert_single_comparison_passed(records) - - def test_auto_descend_emits_log_record(self, tmp_path: Path, capsys) -> None: - """Auto-descend emits a LogRecord with the info message.""" - baseline_exp, target_exp = _create_dumps(tmp_path, ["tensor_a"]) - - wrapper: Path = tmp_path / "target_wrap" - wrapper.mkdir() - target_exp.rename(wrapper / "engine_0") - - argv = _make_argv(baseline_exp, wrapper, preset="raw") - records, _ = _run_and_parse(argv, capsys) - - log_records: list[LogRecord] = [r for r in records if isinstance(r, LogRecord)] - auto_descend_msgs: list[str] = [ - info.message - for lr in log_records - for info in lr.infos - if "auto-descend" in info.message - ] - assert any("target_path" in m for m in auto_descend_msgs) - - def test_auto_descend_single_nonempty_among_empty( - self, tmp_path: Path, capsys - ) -> None: - """Two subdirs but only one has .pt — auto-descend picks the non-empty one.""" - baseline_exp, target_exp = _create_dumps(tmp_path, ["tensor_a"]) - - wrapper: Path = tmp_path / "target_wrap" - wrapper.mkdir() - target_exp.rename(wrapper / "engine_0") - (wrapper / "empty_subdir").mkdir() - - argv = _make_argv(baseline_exp, wrapper, preset="raw") - records, exit_code = _run_and_parse(argv, capsys) - - assert exit_code == 0 - _assert_single_comparison_passed(records) - - def test_error_multiple_nonempty_subdirs(self, tmp_path: Path) -> None: - """Two subdirs both with .pt — raises ValueError with clear message.""" - baseline_exp, target_exp = _create_dumps(tmp_path, ["tensor_a"]) - - wrapper: Path = tmp_path / "target_wrap" - wrapper.mkdir() - target_exp.rename(wrapper / "engine_0") - engine_1: Path = wrapper / "engine_1" - engine_1.mkdir() - torch.save(torch.tensor([1.0]), engine_1 / "dummy.pt") - - argv: 
class TestEntrypointAutoDescend:
    """Test auto-descend: --baseline-path / --target-path pointing to a parent
    directory that contains a single subdirectory with .pt files."""

    def test_auto_descend_single_engine(self, tmp_path: Path, capsys) -> None:
        """Parent dir wrapping a single engine subdir is auto-descended and comparison succeeds."""
        baseline_exp, target_exp = _create_dumps(tmp_path, ["tensor_a"])

        # Move each dump directory one level down, under a fresh wrapper dir.
        baseline_wrapper: Path = tmp_path / "baseline_wrap"
        target_wrapper: Path = tmp_path / "target_wrap"
        baseline_wrapper.mkdir()
        target_wrapper.mkdir()
        baseline_exp.rename(baseline_wrapper / "engine_0")
        target_exp.rename(target_wrapper / "engine_0")

        argv: list[str] = _make_argv(baseline_wrapper, target_wrapper, preset="raw")
        records, exit_code = _run_and_parse(argv, capsys)

        assert exit_code == 0
        _assert_single_comparison_passed(records)

    def test_no_descend_when_pt_at_root(self, tmp_path: Path, capsys) -> None:
        """Direct .pt files — no descend needed, comparison still works."""
        baseline_exp, target_exp = _create_dumps(tmp_path, ["tensor_a"])

        argv: list[str] = _make_argv(baseline_exp, target_exp, preset="raw")
        records, exit_code = _run_and_parse(argv, capsys)

        assert exit_code == 0
        _assert_single_comparison_passed(records)

    def test_auto_descend_emits_log_record(self, tmp_path: Path, capsys) -> None:
        """Auto-descend emits a LogRecord with the info message."""
        baseline_exp, target_exp = _create_dumps(tmp_path, ["tensor_a"])

        # Only the target side is wrapped; the baseline path is passed as-is.
        wrapper: Path = tmp_path / "target_wrap"
        wrapper.mkdir()
        target_exp.rename(wrapper / "engine_0")

        argv: list[str] = _make_argv(baseline_exp, wrapper, preset="raw")
        records, _ = _run_and_parse(argv, capsys)

        log_records: list[LogRecord] = [r for r in records if isinstance(r, LogRecord)]
        auto_descend_msgs: list[str] = [
            info.message
            for lr in log_records
            for info in lr.infos
            if "auto-descend" in info.message
        ]
        assert any("target_path" in m for m in auto_descend_msgs)

    def test_auto_descend_single_nonempty_among_empty(
        self, tmp_path: Path, capsys
    ) -> None:
        """Two subdirs but only one has .pt — auto-descend picks the non-empty one."""
        baseline_exp, target_exp = _create_dumps(tmp_path, ["tensor_a"])

        wrapper: Path = tmp_path / "target_wrap"
        wrapper.mkdir()
        target_exp.rename(wrapper / "engine_0")
        # Sibling directory that is created but never populated.
        (wrapper / "empty_subdir").mkdir()

        argv: list[str] = _make_argv(baseline_exp, wrapper, preset="raw")
        records, exit_code = _run_and_parse(argv, capsys)

        assert exit_code == 0
        _assert_single_comparison_passed(records)

    def test_error_multiple_nonempty_subdirs(self, tmp_path: Path) -> None:
        """Two subdirs both with .pt — raises ValueError with clear message."""
        baseline_exp, target_exp = _create_dumps(tmp_path, ["tensor_a"])

        wrapper: Path = tmp_path / "target_wrap"
        wrapper.mkdir()
        target_exp.rename(wrapper / "engine_0")
        # A second subdirectory holding its own .pt file.
        engine_1: Path = wrapper / "engine_1"
        engine_1.mkdir()
        torch.save(torch.tensor([1.0]), engine_1 / "dummy.pt")

        argv: list[str] = _make_argv(baseline_exp, wrapper, preset="raw")
        with pytest.raises(ValueError, match="multiple subdirectories contain data"):
            run(parse_args(argv))

    def test_error_no_data_found(self, tmp_path: Path) -> None:
        """No .pt files anywhere — raises ValueError."""
        baseline_exp, _ = _create_dumps(tmp_path, ["tensor_a"])

        empty_dir: Path = tmp_path / "empty_target"
        empty_dir.mkdir()
        (empty_dir / "subdir").mkdir()

        argv: list[str] = _make_argv(baseline_exp, empty_dir, preset="raw")
        # `match` is a regular expression searched within the message, so the
        # dot is escaped to require a literal "." rather than any character.
        with pytest.raises(ValueError, match=r"no \.pt files found"):
            run(parse_args(argv))


if __name__ == "__main__":
    sys.exit(pytest.main([__file__]))