diff --git a/scripts/conformance.py b/scripts/conformance.py
index dc8348bcc3095..81da51fe68887 100644
--- a/scripts/conformance.py
+++ b/scripts/conformance.py
@@ -40,10 +40,10 @@
 import tomllib
 from collections.abc import Sequence, Set as AbstractSet
 from dataclasses import dataclass
-from enum import Flag, StrEnum, auto
+from enum import StrEnum, auto
 from functools import reduce
 from itertools import chain, groupby
-from operator import attrgetter, or_
+from operator import attrgetter
 from pathlib import Path
 from textwrap import dedent
 from typing import Any, Literal, Self, assert_never
@@ -55,8 +55,6 @@
 #    on a set of lines with a matching tag
 # 4. Tagged multi-errors (E[tag+]): The type checker should raise one or
 #    more errors on a set of lines with a matching tag
-# This regex pattern parses the error lines in the conformance tests,
-# but the following implementation currently ignores error tags.
 CONFORMANCE_ERROR_PATTERN = re.compile(
     r"""
     \#\s*E                   # "# E" begins each error
@@ -81,7 +79,7 @@
 CONFORMANCE_URL = CONFORMANCE_DIR_WITH_README + "tests/{filename}#L{line}"
 
 
-class Source(Flag):
+class Source(StrEnum):
     OLD = auto()
     NEW = auto()
     EXPECTED = auto()
@@ -105,6 +103,15 @@ def into_title(self) -> str:
                 return "True positives removed"
 
 
+@dataclass(kw_only=True, slots=True)
+class Evaluation:
+    classification: Classification
+    true_positives: int = 0
+    false_positives: int = 0
+    true_negatives: int = 0
+    false_negatives: int = 0
+
+
 class Change(StrEnum):
     ADDED = auto()
     REMOVED = auto()
@@ -152,10 +159,13 @@ class Diagnostic:
     check_name: str
     description: str
     severity: str
-    fingerprint: str | None
     location: Location
     source: Source
     optional: bool
+    # Tag identifying an error that can occur on multiple lines
+    tag: str | None
+    # True if one or more errors can occur on lines with the same tag
+    multi: bool
 
     def __post_init__(self, *args, **kwargs) -> None:
         # Remove check name prefix from description
@@ -180,7 +190,6 @@ def from_gitlab_output(
             check_name=dct["check_name"],
             description=dct["description"],
             severity=dct["severity"],
-            fingerprint=dct["fingerprint"],
             location=Location(
                 path=Path(dct["location"]["path"]).resolve(),
                 positions=Positions(
@@ -196,12 +205,18 @@
             ),
             source=source,
             optional=False,
+            tag=None,
+            multi=False,
         )
 
     @property
     def key(self) -> str:
-        """Key to group diagnostics by path and beginning line."""
-        return f"{self.location.path.as_posix()}:{self.location.positions.begin.line}"
+        """Key to group diagnostics by path and beginning line, or by path and tag."""
+        return (
+            f"{self.location.path.as_posix()}:{self.location.positions.begin.line}"
+            if self.tag is None
+            else f"{self.location.path.as_posix()}:{self.tag}"
+        )
 
     @property
     def severity_for_display(self) -> str:
@@ -214,21 +229,10 @@
 @dataclass(kw_only=True, slots=True)
 class GroupedDiagnostics:
     key: str
-    sources: Source
-    old: Diagnostic | None
-    new: Diagnostic | None
-    expected: Diagnostic | None
-
-    @property
-    def classification(self) -> Classification:
-        if Source.NEW in self.sources and Source.EXPECTED in self.sources:
-            return Classification.TRUE_POSITIVE
-        elif Source.NEW in self.sources and Source.EXPECTED not in self.sources:
-            return Classification.FALSE_POSITIVE
-        elif Source.EXPECTED in self.sources:
-            return Classification.FALSE_NEGATIVE
-        else:
-            return Classification.TRUE_NEGATIVE
+    sources: AbstractSet[Source]
+    old: list[Diagnostic]
+    new: list[Diagnostic]
+    expected: list[Diagnostic]
 
     @property
     def change(self) -> Change:
@@ -241,17 +245,133 @@ def change(self) -> Change:
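+    # A group is scored by classify(): it weighs the diagnostics that one
+    # source (old ty or new ty) produced for this key against the expected
+    # diagnostics and returns an Evaluation with per-outcome counts.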
 
     @property
     def optional(self) -> bool:
-        return self.expected is not None and self.expected.optional
+        return bool(self.expected) and all(
+            diagnostic.optional for diagnostic in self.expected
+        )
+
+    @property
+    def multi(self) -> bool:
+        return bool(self.expected) and all(
+            diagnostic.multi for diagnostic in self.expected
+        )
+
+    def diagnostics_by_source(self, source: Source) -> list[Diagnostic]:
+        match source:
+            case Source.NEW:
+                return self.new
+            case Source.OLD:
+                return self.old
+            case Source.EXPECTED:
+                return self.expected
+
+    def classify(self, source: Source) -> Evaluation:
+        diagnostics = self.diagnostics_by_source(source)
+
+        if source in self.sources:
+            if self.optional:
+                return Evaluation(
+                    classification=Classification.TRUE_POSITIVE,
+                    true_positives=len(diagnostics),
+                    false_positives=0,
+                    true_negatives=0,
+                    false_negatives=0,
+                )
+
+            if Source.EXPECTED in self.sources:
+                distinct_lines = len(
+                    {
+                        diagnostic.location.positions.begin.line
+                        for diagnostic in diagnostics
+                    }
+                )
+                expected_max = len(self.expected) if self.multi else 1
+
+                if 1 <= distinct_lines <= expected_max:
+                    return Evaluation(
+                        classification=Classification.TRUE_POSITIVE,
+                        true_positives=len(diagnostics),
+                        false_positives=0,
+                        true_negatives=0,
+                        false_negatives=0,
+                    )
+                else:
+                    # Count the line with the most diagnostics as true
+                    # positives and the rest as false positives. A tagged
+                    # multi-error can never span more distinct lines than
+                    # expected, so only single-error groups reach this branch.
+                    lines = sorted(
+                        diagnostic.location.positions.begin.line
+                        for diagnostic in diagnostics
+                    )
+                    max_line_count = max(
+                        len(list(group)) for _, group in groupby(lines)
+                    )
+                    remaining = len(diagnostics) - max_line_count
+                    return Evaluation(
+                        classification=Classification.FALSE_POSITIVE,
+                        true_positives=max_line_count,
+                        false_positives=remaining,
+                        true_negatives=0,
+                        false_negatives=0,
+                    )
+            else:
+                return Evaluation(
+                    classification=Classification.FALSE_POSITIVE,
+                    true_positives=0,
+                    false_positives=len(diagnostics),
+                    true_negatives=0,
+                    false_negatives=0,
+                )
+
+        elif Source.EXPECTED in self.sources:
+            if self.optional:
+                return Evaluation(
+                    classification=Classification.TRUE_NEGATIVE,
+                    true_positives=0,
+                    false_positives=0,
+                    true_negatives=1,
+                    false_negatives=0,
+                )
+            return Evaluation(
+                classification=Classification.FALSE_NEGATIVE,
+                true_positives=0,
+                false_positives=0,
+                true_negatives=0,
+                false_negatives=1,
+            )
 
-    def _render_row(self, diagnostic: Diagnostic):
-        return f"| {diagnostic.location.as_link()} | {diagnostic.check_name} | {diagnostic.description} |"
+        else:
+            return Evaluation(
+                classification=Classification.TRUE_NEGATIVE,
+                true_positives=0,
+                false_positives=0,
+                true_negatives=1,
+                false_negatives=0,
+            )
+
+    def _render_row(self, diagnostics: list[Diagnostic]) -> str:
+        locs = []
+        check_names = []
+        descriptions = []
 
-    def _render_diff(self, diagnostic: Diagnostic, *, removed: bool = False):
+        for diagnostic in diagnostics:
+            loc = (
+                diagnostic.location.as_link()
+                if diagnostic.location
+                else f"`{diagnostic.tag}`"
+            )
+            locs.append(loc)
+            check_names.append(diagnostic.check_name)
+            descriptions.append(diagnostic.description)
+
+        # "<br>" renders a line break inside a single markdown table cell,
+        # letting one row list several diagnostics.
+        return f"| {'<br>'.join(locs)} | {'<br>'.join(check_names)} | {'<br>'.join(descriptions)} |"
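+
+    # Render one "+"- or "-"-prefixed line per diagnostic so that every
+    # entry in a grouped key appears in the diff-format output.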
+    def _render_diff(self, diagnostics: list[Diagnostic], *, removed: bool = False) -> str:
         sign = "-" if removed else "+"
-        return f"{sign} {diagnostic}"
+        return "\n".join(f"{sign} {diagnostic}" for diagnostic in diagnostics)
 
     def display(self, format: Literal["diff", "github"]) -> str:
-        match self.classification:
+        evaluation = self.classify(Source.NEW)
+        match evaluation.classification:
             case Classification.TRUE_POSITIVE | Classification.FALSE_POSITIVE:
-                assert self.new is not None
+                assert self.new
                 return (
@@ -261,23 +381,21 @@ def display(self, format: Literal["diff", "github"]) -> str:
                 )
 
             case Classification.FALSE_NEGATIVE | Classification.TRUE_NEGATIVE:
-                diagnostic = self.old or self.expected
-                assert diagnostic is not None
+                diagnostics = self.old if self.old else self.expected
+
                 return (
-                    self._render_diff(diagnostic, removed=True)
+                    self._render_diff(diagnostics, removed=True)
                     if format == "diff"
-                    else self._render_row(diagnostic)
+                    else self._render_row(diagnostics)
                 )
 
-            case _:
-                raise ValueError(f"Unexpected classification: {self.classification}")
-
 
 @dataclass(kw_only=True, slots=True)
 class Statistics:
     true_positives: int = 0
     false_positives: int = 0
     false_negatives: int = 0
+    total_diagnostics: int = 0
 
     @property
     def precision(self) -> float:
@@ -292,10 +410,6 @@ def recall(self) -> float:
         else:
             return 0.0
 
-    @property
-    def total(self) -> int:
-        return self.true_positives + self.false_positives
-
 
 def collect_expected_diagnostics(test_files: Sequence[Path]) -> list[Diagnostic]:
     diagnostics: list[Diagnostic] = []
@@ -305,13 +419,8 @@ def collect_expected_diagnostics(test_files: Sequence[Path]) -> list[Diagnostic]
             diagnostics.append(
                 Diagnostic(
                     check_name="conformance",
-                    description=(
-                        error.group("description")
-                        or error.group("tag")
-                        or "Missing"
-                    ),
+                    description=(error.group("description") or "Missing"),
                     severity="major",
-                    fingerprint=None,
                     location=Location(
                         path=file,
                         positions=Positions(
@@ -327,6 +436,12 @@
                     ),
                     source=Source.EXPECTED,
                     optional=error.group("optional") is not None,
+                    tag=(
+                        f"{file.name}:{error.group('tag')}"
+                        if error.group("tag")
+                        else None
+                    ),
+                    multi=error.group("multi") is not None,
                 )
             )
 
@@ -367,62 +482,72 @@ def collect_ty_diagnostics(
 
 
 def group_diagnostics_by_key(
-    old: list[Diagnostic], new: list[Diagnostic], expected: list[Diagnostic]
+    old: list[Diagnostic],
+    new: list[Diagnostic],
+    expected: list[Diagnostic],
 ) -> list[GroupedDiagnostics]:
+    # Propagate tags from expected diagnostics onto old and new diagnostics
+    tagged_lines = {
+        (d.location.path.name, d.location.positions.begin.line): d.tag
+        for d in expected
+        if d.tag is not None
+    }
+
+    for diag in chain(old, new):
+        diag.tag = tagged_lines.get(
+            (diag.location.path.name, diag.location.positions.begin.line)
+        )
+
     diagnostics = [
         *old,
         *new,
         *expected,
     ]
-    sorted_diagnostics = sorted(diagnostics, key=attrgetter("key"))
-
-    grouped = []
-    for key, group in groupby(sorted_diagnostics, key=attrgetter("key")):
-        group = list(group)
-        sources: Source = reduce(or_, (diag.source for diag in group))
-        grouped.append(
-            GroupedDiagnostics(
-                key=key,
-                sources=sources,
-                old=next(filter(lambda diag: diag.source == Source.OLD, group), None),
-                new=next(filter(lambda diag: diag.source == Source.NEW, group), None),
-                expected=next(
-                    filter(lambda diag: diag.source == Source.EXPECTED, group), None
-                ),
-            )
+    diagnostics = sorted(diagnostics, key=attrgetter("key"))
+    grouped_diagnostics = []
+    for key, group in groupby(diagnostics, key=attrgetter("key")):
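+        # Partition the group's diagnostics by the source that produced them.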
+        old_diagnostics: list[Diagnostic] = []
+        new_diagnostics: list[Diagnostic] = []
+        expected_diagnostics: list[Diagnostic] = []
+        sources: set[Source] = set()
+
+        for diag in group:
+            sources.add(diag.source)
+            match diag.source:
+                case Source.OLD:
+                    old_diagnostics.append(diag)
+                case Source.NEW:
+                    new_diagnostics.append(diag)
+                case Source.EXPECTED:
+                    expected_diagnostics.append(diag)
+
+        grouped = GroupedDiagnostics(
+            key=key,
+            sources=sources,
+            old=old_diagnostics,
+            new=new_diagnostics,
+            expected=expected_diagnostics,
         )
+        grouped_diagnostics.append(grouped)
 
-    return grouped
+    return grouped_diagnostics
 
 
 def compute_stats(
     grouped_diagnostics: list[GroupedDiagnostics],
-    source: Source,
+    ty_version: Literal["new", "old"],
 ) -> Statistics:
-    if source == source.EXPECTED:
-        # ty currently raises a false positive here due to incomplete enum.Flag support
-        # see https://github.com/astral-sh/ty/issues/876
-        num_errors = sum(
-            1
-            for g in grouped_diagnostics
-            if source.EXPECTED in g.sources  # ty:ignore[unsupported-operator]
-        )
-        return Statistics(
-            true_positives=num_errors, false_positives=0, false_negatives=0
-        )
+    source = Source.NEW if ty_version == "new" else Source.OLD
 
     def increment(statistics: Statistics, grouped: GroupedDiagnostics) -> Statistics:
-        if (source in grouped.sources) and (Source.EXPECTED in grouped.sources):
-            statistics.true_positives += 1
-        elif source in grouped.sources:
-            statistics.false_positives += 1
-        elif Source.EXPECTED in grouped.sources:
-            statistics.false_negatives += 1
+        evaluation = grouped.classify(source)
+        statistics.true_positives += evaluation.true_positives
+        statistics.false_positives += evaluation.false_positives
+        statistics.false_negatives += evaluation.false_negatives
+        statistics.total_diagnostics += len(grouped.diagnostics_by_source(source))
         return statistics
 
-    grouped_diagnostics = [diag for diag in grouped_diagnostics if not diag.optional]
-
     return reduce(increment, grouped_diagnostics, Statistics())
 
 
@@ -438,7 +563,9 @@ def render_grouped_diagnostics(
     ]
 
     get_change = attrgetter("change")
-    get_classification = attrgetter("classification")
+
+    def get_classification(diag: GroupedDiagnostics) -> Classification:
+        return diag.classify(Source.NEW).classification
 
     optional_diagnostics = sorted(
         (diag for diag in grouped if diag.optional),
@@ -524,8 +651,8 @@ def format_metric(diff: float, old: float, new: float):
             return f"decreased from {old:.2%} to {new:.2%}"
         return f"held steady at {old:.2%}"
 
-    old = compute_stats(grouped_diagnostics, source=Source.OLD)
-    new = compute_stats(grouped_diagnostics, source=Source.NEW)
+    old = compute_stats(grouped_diagnostics, ty_version="old")
+    new = compute_stats(grouped_diagnostics, ty_version="new")
 
     assert new.true_positives > 0, (
         "Expected ty to have at least one true positive.\n"
@@ -537,7 +664,7 @@ def format_metric(diff: float, old: float, new: float):
     true_pos_change = new.true_positives - old.true_positives
     false_pos_change = new.false_positives - old.false_positives
     false_neg_change = new.false_negatives - old.false_negatives
-    total_change = new.total - old.total
+    total_change = new.total_diagnostics - old.total_diagnostics
 
     base_header = f"[Typing conformance results]({CONFORMANCE_DIR_WITH_README})"
 
@@ -590,7 +717,7 @@
         | True Positives | {old.true_positives} | {new.true_positives} | {true_pos_change:+} | {true_pos_diff} |
         | False Positives | {old.false_positives} | {new.false_positives} | {false_pos_change:+} | {false_pos_diff} |
         | False Negatives | {old.false_negatives} | {new.false_negatives} | {false_neg_change:+} | {false_neg_diff} |
-        | Total Diagnostics | {old.total} | {new.total} | {total_change:+} | {total_diff} |
+        | Total Diagnostics | {old.total_diagnostics} | {new.total_diagnostics} | {total_change:+} | {total_diff} |
         | Precision | {old.precision:.2%} | {new.precision:.2%} | {precision_change:+.2%} | {precision_diff} |
         | Recall | {old.recall:.2%} | {new.recall:.2%} | {recall_change:+.2%} | {recall_diff} |
 
@@ -631,6 +758,7 @@
         "--old-ty",
         nargs="+",
         help="Command to run old version of ty",
+        required=True,
     )
 
     parser.add_argument(
@@ -678,9 +806,6 @@ def parse_args():
 
     args = parser.parse_args()
 
-    if args.old_ty is None:
-        raise ValueError("old_ty is required")
-
     return args
 
 
@@ -689,6 +814,7 @@ def main():
     tests_dir = args.tests_path.resolve().absolute()
     test_groups = get_test_groups(tests_dir)
    test_files = get_test_cases(test_groups, tests_dir / "tests")
+    expected = collect_expected_diagnostics(test_files)
     old = collect_ty_diagnostics(