diff --git a/scripts/conformance.py b/scripts/conformance.py
index dc8348bcc3095..81da51fe68887 100644
--- a/scripts/conformance.py
+++ b/scripts/conformance.py
@@ -40,10 +40,10 @@
import tomllib
from collections.abc import Sequence, Set as AbstractSet
from dataclasses import dataclass
-from enum import Flag, StrEnum, auto
+from enum import StrEnum, auto
from functools import reduce
from itertools import chain, groupby
-from operator import attrgetter, or_
+from operator import attrgetter
from pathlib import Path
from textwrap import dedent
from typing import Any, Literal, Self, assert_never
@@ -55,8 +55,6 @@
# on a set of lines with a matching tag
# 4. Tagged multi-errors (E[tag+]): The type checker should raise one or
# more errors on a set of lines with a matching tag
-# This regex pattern parses the error lines in the conformance tests,
-# but the following implementation currently ignores error tags.
CONFORMANCE_ERROR_PATTERN = re.compile(
r"""
\#\s*E # "# E" begins each error
@@ -81,7 +79,7 @@
CONFORMANCE_URL = CONFORMANCE_DIR_WITH_README + "tests/{filename}#L{line}"
-class Source(Flag):
+class Source(StrEnum):
OLD = auto()
NEW = auto()
EXPECTED = auto()
@@ -105,6 +103,15 @@ def into_title(self) -> str:
return "True positives removed"
+@dataclass(kw_only=True, slots=True)
+class Evaluation:
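+    """Classification of a diagnostic group plus its confusion-matrix counts."""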
+ classification: Classification
+ true_positives: int = 0
+ false_positives: int = 0
+ true_negatives: int = 0
+ false_negatives: int = 0
+
+
class Change(StrEnum):
ADDED = auto()
REMOVED = auto()
@@ -152,10 +159,13 @@ class Diagnostic:
check_name: str
description: str
severity: str
- fingerprint: str | None
location: Location
source: Source
optional: bool
+ # tag identifying an error that can occur on multiple lines
+ tag: str | None
+ # True if one or more errors can occur on lines with the same tag
+ multi: bool
def __post_init__(self, *args, **kwargs) -> None:
# Remove check name prefix from description
@@ -180,7 +190,6 @@ def from_gitlab_output(
check_name=dct["check_name"],
description=dct["description"],
severity=dct["severity"],
- fingerprint=dct["fingerprint"],
location=Location(
path=Path(dct["location"]["path"]).resolve(),
positions=Positions(
@@ -196,12 +205,18 @@ def from_gitlab_output(
),
source=source,
optional=False,
+ tag=None,
+ multi=False,
)
@property
def key(self) -> str:
- """Key to group diagnostics by path and beginning line."""
- return f"{self.location.path.as_posix()}:{self.location.positions.begin.line}"
+        """Key grouping diagnostics by path and beginning line, or by path and tag."""
+ return (
+ f"{self.location.path.as_posix()}:{self.location.positions.begin.line}"
+ if self.tag is None
+ else f"{self.location.path.as_posix()}:{self.tag}"
+ )
@property
def severity_for_display(self) -> str:
@@ -214,21 +229,10 @@ def severity_for_display(self) -> str:
@dataclass(kw_only=True, slots=True)
class GroupedDiagnostics:
key: str
- sources: Source
- old: Diagnostic | None
- new: Diagnostic | None
- expected: Diagnostic | None
-
- @property
- def classification(self) -> Classification:
- if Source.NEW in self.sources and Source.EXPECTED in self.sources:
- return Classification.TRUE_POSITIVE
- elif Source.NEW in self.sources and Source.EXPECTED not in self.sources:
- return Classification.FALSE_POSITIVE
- elif Source.EXPECTED in self.sources:
- return Classification.FALSE_NEGATIVE
- else:
- return Classification.TRUE_NEGATIVE
+ sources: AbstractSet[Source]
+ old: list[Diagnostic]
+ new: list[Diagnostic]
+ expected: list[Diagnostic]
@property
def change(self) -> Change:
@@ -241,17 +245,133 @@ def change(self) -> Change:
@property
def optional(self) -> bool:
- return self.expected is not None and self.expected.optional
+ return bool(self.expected) and all(
+ diagnostic.optional for diagnostic in self.expected
+ )
+
+ @property
+ def multi(self) -> bool:
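+        """True if all expected diagnostics in this group allow multiple errors."""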
+ return bool(self.expected) and all(
+ diagnostic.multi for diagnostic in self.expected
+ )
+
+ def diagnostics_by_source(self, source: Source) -> list[Diagnostic]:
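+        """Return the diagnostics in this group that came from `source`."""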
+ match source:
+ case Source.NEW:
+ return self.new
+ case Source.OLD:
+ return self.old
+ case Source.EXPECTED:
+ return self.expected
+
+ def classify(self, source: Source) -> Evaluation:
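+        """Evaluate `source`'s diagnostics against the expected diagnostics."""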
+ diagnostics = self.diagnostics_by_source(source)
+
+ if source in self.sources:
+ if self.optional:
+ return Evaluation(
+ classification=Classification.TRUE_POSITIVE,
+ true_positives=len(diagnostics),
+ false_positives=0,
+ true_negatives=0,
+ false_negatives=0,
+ )
+
+ if Source.EXPECTED in self.sources:
+ distinct_lines = len(
+ {
+ diagnostic.location.positions.begin.line
+ for diagnostic in diagnostics
+ }
+ )
+ expected_max = len(self.expected) if self.multi else 1
+
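+                # a multi-error tag (E[tag+]) may legitimately fire on up to
+                # len(self.expected) distinct lines; a plain expectation on
+                # exactly one line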
+ if 1 <= distinct_lines <= expected_max:
+ return Evaluation(
+ classification=Classification.TRUE_POSITIVE,
+ true_positives=len(diagnostics),
+ false_positives=0,
+ true_negatives=0,
+ false_negatives=0,
+ )
+ else:
+ # We select the line with the most diagnostics
+ # as our true positive, while the rest are false positives
+                by_line = sorted(
+                    diagnostics, key=lambda d: d.location.positions.begin.line
+                )
+                max_line_count = max(
+                    len(list(group))
+                    for _, group in groupby(
+                        by_line, key=lambda d: d.location.positions.begin.line
+                    )
+                )
+                remaining = len(diagnostics) - max_line_count
+                # distinct_lines can never exceed expected_max for a
+                # multi-error tag, so this branch only handles plain tags
+                # that matched more lines than allowed
+ return Evaluation(
+ classification=Classification.FALSE_POSITIVE,
+                    true_positives=max_line_count,
+ false_positives=remaining,
+ true_negatives=0,
+ false_negatives=0,
+ )
+ else:
+ return Evaluation(
+ classification=Classification.FALSE_POSITIVE,
+ true_positives=0,
+ false_positives=len(diagnostics),
+ true_negatives=0,
+ false_negatives=0,
+ )
+
+ elif Source.EXPECTED in self.sources:
+ if self.optional:
+ return Evaluation(
+ classification=Classification.TRUE_NEGATIVE,
+ true_positives=0,
+ false_positives=0,
+                    # `source` produced no diagnostics here, so the group
+                    # counts as a single true negative
+                    true_negatives=1,
+ false_negatives=0,
+ )
+ return Evaluation(
+ classification=Classification.FALSE_NEGATIVE,
+ true_positives=0,
+ false_positives=0,
+ true_negatives=0,
+ false_negatives=1,
+ )
- def _render_row(self, diagnostic: Diagnostic):
- return f"| {diagnostic.location.as_link()} | {diagnostic.check_name} | {diagnostic.description} |"
+ else:
+ return Evaluation(
+ classification=Classification.TRUE_NEGATIVE,
+ true_positives=0,
+ false_positives=0,
+ true_negatives=1,
+ false_negatives=0,
+ )
+
+ def _render_row(self, diagnostics: list[Diagnostic]):
+ locs = []
+ check_names = []
+ descriptions = []
- def _render_diff(self, diagnostic: Diagnostic, *, removed: bool = False):
+ for diagnostic in diagnostics:
+ loc = (
+ diagnostic.location.as_link()
+ if diagnostic.location
+ else f"`{diagnostic.tag}`"
+ )
+ locs.append(loc)
+ check_names.append(diagnostic.check_name)
+ descriptions.append(diagnostic.description)
+
+        return f"| {'<br>'.join(locs)} | {'<br>'.join(check_names)} | {'<br>'.join(descriptions)} |"
+
+ def _render_diff(self, diagnostics: list[Diagnostic], *, removed: bool = False):
sign = "-" if removed else "+"
- return f"{sign} {diagnostic}"
+ return "\n".join(f"{sign} {diagnostic}" for diagnostic in diagnostics)
def display(self, format: Literal["diff", "github"]) -> str:
- match self.classification:
+        evaluation = self.classify(Source.NEW)
+        match evaluation.classification:
case Classification.TRUE_POSITIVE | Classification.FALSE_POSITIVE:
assert self.new is not None
return (
@@ -261,23 +381,21 @@ def display(self, format: Literal["diff", "github"]) -> str:
)
case Classification.FALSE_NEGATIVE | Classification.TRUE_NEGATIVE:
- diagnostic = self.old or self.expected
- assert diagnostic is not None
+            diagnostics = self.old or self.expected
+
return (
- self._render_diff(diagnostic, removed=True)
+ self._render_diff(diagnostics, removed=True)
if format == "diff"
- else self._render_row(diagnostic)
+ else self._render_row(diagnostics)
)
- case _:
- raise ValueError(f"Unexpected classification: {self.classification}")
-
@dataclass(kw_only=True, slots=True)
class Statistics:
true_positives: int = 0
false_positives: int = 0
false_negatives: int = 0
+ total_diagnostics: int = 0
@property
def precision(self) -> float:
@@ -292,10 +410,6 @@ def recall(self) -> float:
else:
return 0.0
- @property
- def total(self) -> int:
- return self.true_positives + self.false_positives
-
def collect_expected_diagnostics(test_files: Sequence[Path]) -> list[Diagnostic]:
diagnostics: list[Diagnostic] = []
@@ -305,13 +419,8 @@ def collect_expected_diagnostics(test_files: Sequence[Path]) -> list[Diagnostic]
diagnostics.append(
Diagnostic(
check_name="conformance",
- description=(
- error.group("description")
- or error.group("tag")
- or "Missing"
- ),
+ description=(error.group("description") or "Missing"),
severity="major",
- fingerprint=None,
location=Location(
path=file,
positions=Positions(
@@ -327,6 +436,12 @@ def collect_expected_diagnostics(test_files: Sequence[Path]) -> list[Diagnostic]
),
source=Source.EXPECTED,
optional=error.group("optional") is not None,
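+            # scope the tag to its file so identical tags in
+            # different test files stay distinct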
+ tag=(
+ f"{file.name}:{error.group('tag')}"
+ if error.group("tag")
+ else None
+ ),
+ multi=error.group("multi") is not None,
)
)
@@ -367,62 +482,72 @@ def collect_ty_diagnostics(
def group_diagnostics_by_key(
- old: list[Diagnostic], new: list[Diagnostic], expected: list[Diagnostic]
+ old: list[Diagnostic],
+ new: list[Diagnostic],
+ expected: list[Diagnostic],
) -> list[GroupedDiagnostics]:
+ # propagate tags from expected diagnostics to old and new diagnostics
+ tagged_lines = {
+ (d.location.path.name, d.location.positions.begin.line): d.tag
+ for d in expected
+ if d.tag is not None
+ }
+
+ for diag in chain(old, new):
+ diag.tag = tagged_lines.get(
+ (diag.location.path.name, diag.location.positions.begin.line), None
+ )
+
diagnostics = [
*old,
*new,
*expected,
]
- sorted_diagnostics = sorted(diagnostics, key=attrgetter("key"))
-
- grouped = []
- for key, group in groupby(sorted_diagnostics, key=attrgetter("key")):
- group = list(group)
- sources: Source = reduce(or_, (diag.source for diag in group))
- grouped.append(
- GroupedDiagnostics(
- key=key,
- sources=sources,
- old=next(filter(lambda diag: diag.source == Source.OLD, group), None),
- new=next(filter(lambda diag: diag.source == Source.NEW, group), None),
- expected=next(
- filter(lambda diag: diag.source == Source.EXPECTED, group), None
- ),
- )
+ diagnostics = sorted(diagnostics, key=attrgetter("key"))
+ grouped_diagnostics = []
+ for key, group in groupby(diagnostics, key=attrgetter("key")):
+ old_diagnostics: list[Diagnostic] = []
+ new_diagnostics: list[Diagnostic] = []
+ expected_diagnostics: list[Diagnostic] = []
+ sources: set[Source] = set()
+
+ for diag in group:
+ sources.add(diag.source)
+ match diag.source:
+ case Source.OLD:
+ old_diagnostics.append(diag)
+ case Source.NEW:
+ new_diagnostics.append(diag)
+ case Source.EXPECTED:
+ expected_diagnostics.append(diag)
+
+ grouped = GroupedDiagnostics(
+ key=key,
+ sources=sources,
+ old=old_diagnostics,
+ new=new_diagnostics,
+ expected=expected_diagnostics,
)
+ grouped_diagnostics.append(grouped)
- return grouped
+ return grouped_diagnostics
def compute_stats(
grouped_diagnostics: list[GroupedDiagnostics],
- source: Source,
+ ty_version: Literal["new", "old"],
) -> Statistics:
- if source == source.EXPECTED:
- # ty currently raises a false positive here due to incomplete enum.Flag support
- # see https://github.com/astral-sh/ty/issues/876
- num_errors = sum(
- 1
- for g in grouped_diagnostics
- if source.EXPECTED in g.sources # ty:ignore[unsupported-operator]
- )
- return Statistics(
- true_positives=num_errors, false_positives=0, false_negatives=0
- )
+ source = Source.NEW if ty_version == "new" else Source.OLD
def increment(statistics: Statistics, grouped: GroupedDiagnostics) -> Statistics:
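+        """Accumulate confusion-matrix counts for one diagnostic group."""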
- if (source in grouped.sources) and (Source.EXPECTED in grouped.sources):
- statistics.true_positives += 1
- elif source in grouped.sources:
- statistics.false_positives += 1
- elif Source.EXPECTED in grouped.sources:
- statistics.false_negatives += 1
+ eval = grouped.classify(source)
+ statistics.true_positives += eval.true_positives
+ statistics.false_positives += eval.false_positives
+ statistics.false_negatives += eval.false_negatives
+ statistics.total_diagnostics += len(grouped.diagnostics_by_source(source))
return statistics
- grouped_diagnostics = [diag for diag in grouped_diagnostics if not diag.optional]
-
return reduce(increment, grouped_diagnostics, Statistics())
@@ -438,7 +563,9 @@ def render_grouped_diagnostics(
]
get_change = attrgetter("change")
- get_classification = attrgetter("classification")
+
+    def get_classification(diag: GroupedDiagnostics) -> Classification:
+ return diag.classify(Source.NEW).classification
optional_diagnostics = sorted(
(diag for diag in grouped if diag.optional),
@@ -524,8 +651,8 @@ def format_metric(diff: float, old: float, new: float):
return f"decreased from {old:.2%} to {new:.2%}"
return f"held steady at {old:.2%}"
- old = compute_stats(grouped_diagnostics, source=Source.OLD)
- new = compute_stats(grouped_diagnostics, source=Source.NEW)
+ old = compute_stats(grouped_diagnostics, ty_version="old")
+ new = compute_stats(grouped_diagnostics, ty_version="new")
assert new.true_positives > 0, (
"Expected ty to have at least one true positive.\n"
@@ -537,7 +664,7 @@ def format_metric(diff: float, old: float, new: float):
true_pos_change = new.true_positives - old.true_positives
false_pos_change = new.false_positives - old.false_positives
false_neg_change = new.false_negatives - old.false_negatives
- total_change = new.total - old.total
+ total_change = new.total_diagnostics - old.total_diagnostics
base_header = f"[Typing conformance results]({CONFORMANCE_DIR_WITH_README})"
@@ -590,7 +717,7 @@ def format_metric(diff: float, old: float, new: float):
| True Positives | {old.true_positives} | {new.true_positives} | {true_pos_change:+} | {true_pos_diff} |
| False Positives | {old.false_positives} | {new.false_positives} | {false_pos_change:+} | {false_pos_diff} |
| False Negatives | {old.false_negatives} | {new.false_negatives} | {false_neg_change:+} | {false_neg_diff} |
- | Total Diagnostics | {old.total} | {new.total} | {total_change:+} | {total_diff} |
+ | Total Diagnostics | {old.total_diagnostics} | {new.total_diagnostics} | {total_change:+} | {total_diff} |
| Precision | {old.precision:.2%} | {new.precision:.2%} | {precision_change:+.2%} | {precision_diff} |
| Recall | {old.recall:.2%} | {new.recall:.2%} | {recall_change:+.2%} | {recall_diff} |
@@ -631,6 +758,7 @@ def parse_args():
"--old-ty",
nargs="+",
help="Command to run old version of ty",
+ required=True,
)
parser.add_argument(
@@ -678,9 +806,6 @@ def parse_args():
args = parser.parse_args()
- if args.old_ty is None:
- raise ValueError("old_ty is required")
-
return args
@@ -689,6 +814,7 @@ def main():
tests_dir = args.tests_path.resolve().absolute()
test_groups = get_test_groups(tests_dir)
test_files = get_test_cases(test_groups, tests_dir / "tests")
+
expected = collect_expected_diagnostics(test_files)
old = collect_ty_diagnostics(