From e8777ca7df5dc0359eacd2b3806b98aed72769ee Mon Sep 17 00:00:00 2001 From: Will Duke Date: Sun, 18 Jan 2026 09:49:06 +0000 Subject: [PATCH 01/13] wip --- scripts/conformance.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/scripts/conformance.py b/scripts/conformance.py index dc8348bcc3095..320a97dd1f7fa 100644 --- a/scripts/conformance.py +++ b/scripts/conformance.py @@ -156,6 +156,10 @@ class Diagnostic: location: Location source: Source optional: bool + # {filename:tag} identifying an error that can occur on multiple lines + tag: str | None + # True if the error can occur on multiple lines or only once per tag + multi: bool def __post_init__(self, *args, **kwargs) -> None: # Remove check name prefix from description @@ -196,6 +200,8 @@ def from_gitlab_output( ), source=source, optional=False, + tag=None, + multi=False, ) @property @@ -327,6 +333,8 @@ def collect_expected_diagnostics(test_files: Sequence[Path]) -> list[Diagnostic] ), source=Source.EXPECTED, optional=error.group("optional") is not None, + tag=f"{file.stem}:{error.group('tag')}", + multi=error.group("multi") is not None, ) ) @@ -396,6 +404,13 @@ def group_diagnostics_by_key( return grouped +def process_tagged_diagnostics(grouped_diagnostics: list[GroupedDiagnostics]): + """For each group of diagnostics, group once again by tag, and track + for each ty diagnostic whether the associated test case already has an expected + error. If a diagnostic occurs on + """ + + def compute_stats( grouped_diagnostics: list[GroupedDiagnostics], source: Source, From 8efd65d178279e5b31183022362fccc881cc5ac4 Mon Sep 17 00:00:00 2001 From: Will Duke Date: Mon, 19 Jan 2026 21:31:39 +0000 Subject: [PATCH 02/13] [ty] Handle tagged errors in conformance --- scripts/conformance.py | 228 +++++++++++++++++++++++++++++++++-------- 1 file changed, 185 insertions(+), 43 deletions(-) diff --git a/scripts/conformance.py b/scripts/conformance.py index 320a97dd1f7fa..a0df16a02fe06 100644 --- a/scripts/conformance.py +++ b/scripts/conformance.py @@ -152,11 +152,10 @@ class Diagnostic: check_name: str description: str severity: str - fingerprint: str | None location: Location source: Source optional: bool - # {filename:tag} identifying an error that can occur on multiple lines + # tag identifying an error that can occur on multiple lines tag: str | None # True if the error can occur on multiple lines or only once per tag multi: bool @@ -184,7 +183,6 @@ def from_gitlab_output( check_name=dct["check_name"], description=dct["description"], severity=dct["severity"], - fingerprint=dct["fingerprint"], location=Location( path=Path(dct["location"]["path"]).resolve(), positions=Positions( @@ -221,9 +219,9 @@ def severity_for_display(self) -> str: class GroupedDiagnostics: key: str sources: Source - old: Diagnostic | None - new: Diagnostic | None - expected: Diagnostic | None + old: list[Diagnostic] | None + new: list[Diagnostic] | None + expected: list[Diagnostic] | None @property def classification(self) -> Classification: @@ -247,14 +245,41 @@ def change(self) -> Change: @property def optional(self) -> bool: - return self.expected is not None and self.expected.optional + return bool(self.expected) and all( + diagnostic.optional for diagnostic in self.expected + ) + + def diagnostics_by_source(self, source: Source) -> list[Diagnostic]: + match source: + case Source.NEW: + return self.new or [] + case Source.OLD: + return self.old or [] + case Source.EXPECTED: + return self.expected or [] + case _: + raise ValueError(f"Invalid source: {source}") + + def _render_row(self, diagnostics: list[Diagnostic]): + locs = [] + check_names = [] + descriptions = [] + + for diagnostic in diagnostics: + loc = ( + diagnostic.location.as_link() + if diagnostic.location + else f"`{diagnostic.tag}`" + ) + locs.append(loc) + check_names.append(diagnostic.check_name) + descriptions.append(diagnostic.description) - def _render_row(self, diagnostic: Diagnostic): - return f"| {diagnostic.location.as_link()} | {diagnostic.check_name} | {diagnostic.description} |" + return f"| {'
'.join(locs)} | {'
'.join(check_names)} | {'
'.join(descriptions)} |" - def _render_diff(self, diagnostic: Diagnostic, *, removed: bool = False): + def _render_diff(self, diagnostics: list[Diagnostic], *, removed: bool = False): sign = "-" if removed else "+" - return f"{sign} {diagnostic}" + return "\n".join(f"{sign} {diagnostic}" for diagnostic in diagnostics) def display(self, format: Literal["diff", "github"]) -> str: match self.classification: @@ -267,12 +292,17 @@ def display(self, format: Literal["diff", "github"]) -> str: ) case Classification.FALSE_NEGATIVE | Classification.TRUE_NEGATIVE: - diagnostic = self.old or self.expected - assert diagnostic is not None + diagnostics = list( + filter( + lambda d: d is not None, + (*(self.old or []), *(self.expected or [])), + ) + ) + return ( - self._render_diff(diagnostic, removed=True) + self._render_diff(diagnostics, removed=True) if format == "diff" - else self._render_row(diagnostic) + else self._render_row(diagnostics) ) case _: @@ -284,6 +314,7 @@ class Statistics: true_positives: int = 0 false_positives: int = 0 false_negatives: int = 0 + total_diagnostics: int = 0 @property def precision(self) -> float: @@ -298,10 +329,6 @@ def recall(self) -> float: else: return 0.0 - @property - def total(self) -> int: - return self.true_positives + self.false_positives - def collect_expected_diagnostics(test_files: Sequence[Path]) -> list[Diagnostic]: diagnostics: list[Diagnostic] = [] @@ -311,13 +338,8 @@ def collect_expected_diagnostics(test_files: Sequence[Path]) -> list[Diagnostic] diagnostics.append( Diagnostic( check_name="conformance", - description=( - error.group("description") - or error.group("tag") - or "Missing" - ), + description=(error.group("description") or "Missing"), severity="major", - fingerprint=None, location=Location( path=file, positions=Positions( @@ -333,7 +355,9 @@ def collect_expected_diagnostics(test_files: Sequence[Path]) -> list[Diagnostic] ), source=Source.EXPECTED, optional=error.group("optional") is not None, - tag=f"{file.stem}:{error.group('tag')}", + tag=f"{file.name}:{error.group('tag')}" + if error.group("tag") + else None, multi=error.group("multi") is not None, ) ) @@ -393,10 +417,10 @@ def group_diagnostics_by_key( GroupedDiagnostics( key=key, sources=sources, - old=next(filter(lambda diag: diag.source == Source.OLD, group), None), - new=next(filter(lambda diag: diag.source == Source.NEW, group), None), - expected=next( - filter(lambda diag: diag.source == Source.EXPECTED, group), None + old=list(filter(lambda diag: diag.source == Source.OLD, group)), + new=list(filter(lambda diag: diag.source == Source.NEW, group)), + expected=list( + filter(lambda diag: diag.source == Source.EXPECTED, group) ), ) ) @@ -404,11 +428,108 @@ def group_diagnostics_by_key( return grouped -def process_tagged_diagnostics(grouped_diagnostics: list[GroupedDiagnostics]): - """For each group of diagnostics, group once again by tag, and track - for each ty diagnostic whether the associated test case already has an expected - error. If a diagnostic occurs on - """ +def split_expected_by_tag( + expected: list[Diagnostic], +) -> tuple[list[Diagnostic], dict[str, list[Diagnostic]]]: + untagged: list[Diagnostic] = [] + tagged: dict[str, list[Diagnostic]] = {} + + for d in expected: + if d.tag is None: + untagged.append(d) + else: + tagged.setdefault(d.tag, []).append(d) + + return untagged, tagged + + +def index_observed_by_path( + diagnostics: list[Diagnostic], +) -> dict[Path, list[Diagnostic]]: + index: dict[Path, list[Diagnostic]] = {} + for d in diagnostics: + index.setdefault(d.location.path, []).append(d) + return index + + +def tagged_locations( + tagged_expected: dict[str, list[Diagnostic]], +) -> set[tuple[Path, int]]: + locations: set[tuple[Path, int]] = set() + for diags in tagged_expected.values(): + for d in diags: + locations.add((d.location.path, d.location.positions.begin.line)) + return locations + + +def filter_out_tagged( + diagnostics: list[Diagnostic], + *, + tagged_locs: set[tuple[Path, int]], +) -> list[Diagnostic]: + return [ + d + for d in diagnostics + if (d.location.path, d.location.positions.begin.line) not in tagged_locs + ] + + +def condense_tagged_groups( + *, + tagged_expected: dict[str, list[Diagnostic]], + old: list[Diagnostic], + new: list[Diagnostic], +) -> list[GroupedDiagnostics]: + old_by_path = index_observed_by_path(old) + new_by_path = index_observed_by_path(new) + + results: list[GroupedDiagnostics] = [] + + for tag, expected_diags in tagged_expected.items(): + exemplar = expected_diags[0] + path = exemplar.location.path + lines = {d.location.positions.begin.line for d in expected_diags} + + def count_hits(path: Path, lines: set[int], observed: list[Diagnostic]) -> int: + return sum( + 1 + for d in observed + if d.location.path == path and d.location.positions.begin.line in lines + ) + + old_hits = count_hits(path, lines, old_by_path.get(path, [])) + new_hits = count_hits(path, lines, new_by_path.get(path, [])) + + expected_max = len(expected_diags) if exemplar.multi else 1 + + sources = Source.EXPECTED + if 1 <= old_hits <= expected_max: + sources |= Source.OLD + if 1 <= new_hits <= expected_max: + sources |= Source.NEW + + old_diagnostics = [ + d + for d in old_by_path.get(path, []) + if d.location.positions.begin.line in lines + ] + new_diagnostics = [ + d + for d in new_by_path.get(path, []) + if d.location.positions.begin.line in lines + ] + + results.append( + GroupedDiagnostics( + key=f"{tag}", + sources=sources, + old=old_diagnostics, + new=new_diagnostics, + expected=expected_diags, + ) + ) + + return results def compute_stats( @@ -434,6 +555,8 @@ def increment(statistics: Statistics, grouped: GroupedDiagnostics) -> Statistics statistics.false_positives += 1 elif Source.EXPECTED in grouped.sources: statistics.false_negatives += 1 + + statistics.total_diagnostics += len(grouped.diagnostics_by_source(source)) return statistics grouped_diagnostics = [diag for diag in grouped_diagnostics if not diag.optional] @@ -552,7 +675,7 @@ def format_metric(diff: float, old: float, new: float): true_pos_change = new.true_positives - old.true_positives false_pos_change = new.false_positives - old.false_positives false_neg_change = new.false_negatives - old.false_negatives - total_change = new.total - old.total + total_change = new.total_diagnostics - old.total_diagnostics base_header = f"[Typing conformance results]({CONFORMANCE_DIR_WITH_README})" @@ -605,7 +728,7 @@ def format_metric(diff: float, old: float, new: float): | True Positives | {old.true_positives} | {new.true_positives} | {true_pos_change:+} | {true_pos_diff} | | False Positives | {old.false_positives} | {new.false_positives} | {false_pos_change:+} | {false_pos_diff} | | False Negatives | {old.false_negatives} | {new.false_negatives} | {false_neg_change:+} | {false_neg_diff} | - | Total Diagnostics | {old.total} | {new.total} | {total_change:+} | {total_diff} | + | Total Diagnostics | {old.total_diagnostics} | {new.total_diagnostics} | {total_change:+} | {total_diff} | | Precision | {old.precision:.2%} | {new.precision:.2%} | {precision_change:+.2%} | {precision_diff} | | Recall | {old.recall:.2%} | {new.recall:.2%} | {recall_change:+.2%} | {recall_diff} | @@ -646,6 +769,7 @@ def parse_args(): "--old-ty", nargs="+", help="Command to run old version of ty", + required=True, ) parser.add_argument( @@ -693,9 +817,6 @@ def parse_args(): args = parser.parse_args() - if args.old_ty is None: - raise ValueError("old_ty is required") - return args @@ -704,7 +825,10 @@ def main(): tests_dir = args.tests_path.resolve().absolute() test_groups = get_test_groups(tests_dir) test_files = get_test_cases(test_groups, tests_dir / "tests") - expected = collect_expected_diagnostics(test_files) + + expected_all = collect_expected_diagnostics(test_files) + expected_untagged, expected_tagged = split_expected_by_tag(expected_all) + tagged_locs = tagged_locations(expected_tagged) old = collect_ty_diagnostics( ty_path=args.old_ty, @@ -720,12 +844,30 @@ def main(): python_version=args.python_version, ) - grouped = group_diagnostics_by_key( + old_untagged = filter_out_tagged( + old, + tagged_locs=tagged_locs, + ) + + new_untagged = filter_out_tagged( + new, + tagged_locs=tagged_locs, + ) + + grouped_tagged = condense_tagged_groups( + tagged_expected=expected_tagged, old=old, new=new, - expected=expected, ) + grouped_untagged = group_diagnostics_by_key( + old=old_untagged, + new=new_untagged, + expected=expected_untagged, + ) + + grouped = [*grouped_untagged, *grouped_tagged] + rendered = "\n\n".join( [ render_summary(grouped, force_summary_table=args.force_summary_table), From 6e31fc892037b99b8526567bdf07030b6a3a4596 Mon Sep 17 00:00:00 2001 From: Will Duke Date: Wed, 21 Jan 2026 09:20:14 +0000 Subject: [PATCH 03/13] clean up and fix false positives bug --- scripts/conformance.py | 250 +++++++++++++---------------------------- 1 file changed, 80 insertions(+), 170 deletions(-) diff --git a/scripts/conformance.py b/scripts/conformance.py index a0df16a02fe06..d9558bd45830b 100644 --- a/scripts/conformance.py +++ b/scripts/conformance.py @@ -40,10 +40,10 @@ import tomllib from collections.abc import Sequence, Set as AbstractSet from dataclasses import dataclass -from enum import Flag, StrEnum, auto +from enum import StrEnum, auto from functools import reduce from itertools import chain, groupby -from operator import attrgetter, or_ +from operator import attrgetter from pathlib import Path from textwrap import dedent from typing import Any, Literal, Self, assert_never @@ -81,7 +81,7 @@ CONFORMANCE_URL = CONFORMANCE_DIR_WITH_README + "tests/{filename}#L{line}" -class Source(Flag): +class Source(StrEnum): OLD = auto() NEW = auto() EXPECTED = auto() @@ -205,7 +205,11 @@ def from_gitlab_output( @property def key(self) -> str: """Key to group diagnostics by path and beginning line.""" - return f"{self.location.path.as_posix()}:{self.location.positions.begin.line}" + return ( + f"{self.location.path.as_posix()}:{self.location.positions.begin.line}" + if self.tag is None + else f"{self.location.path.as_posix()}:{self.tag}" + ) @property def severity_for_display(self) -> str: @@ -218,22 +222,11 @@ def severity_for_display(self) -> str: @dataclass(kw_only=True, slots=True) class GroupedDiagnostics: key: str - sources: Source + sources: AbstractSet[Source] old: list[Diagnostic] | None new: list[Diagnostic] | None expected: list[Diagnostic] | None - @property - def classification(self) -> Classification: - if Source.NEW in self.sources and Source.EXPECTED in self.sources: - return Classification.TRUE_POSITIVE - elif Source.NEW in self.sources and Source.EXPECTED not in self.sources: - return Classification.FALSE_POSITIVE - elif Source.EXPECTED in self.sources: - return Classification.FALSE_NEGATIVE - else: - return Classification.TRUE_NEGATIVE - @property def change(self) -> Change: if Source.NEW in self.sources and Source.OLD not in self.sources: @@ -249,6 +242,12 @@ def optional(self) -> bool: diagnostic.optional for diagnostic in self.expected ) + @property + def multi(self) -> bool: + return bool(self.expected) and all( + diagnostic.multi for diagnostic in self.expected + ) + def diagnostics_by_source(self, source: Source) -> list[Diagnostic]: match source: case Source.NEW: @@ -260,6 +259,25 @@ def diagnostics_by_source(self, source: Source) -> list[Diagnostic]: case _: raise ValueError(f"Invalid source: {source}") + def classify(self, source: Source) -> Classification: + if source in self.sources and Source.EXPECTED in self.sources: + assert self.expected is not None + diagnostics = self.diagnostics_by_source(source) + expected_max = len(self.expected) if self.multi else 1 + if 1 <= len(diagnostics) <= expected_max: + return Classification.TRUE_POSITIVE + else: + return Classification.FALSE_POSITIVE + + elif source in self.sources and Source.EXPECTED not in self.sources: + return Classification.FALSE_POSITIVE + + elif Source.EXPECTED in self.sources: + return Classification.FALSE_NEGATIVE + + else: + return Classification.TRUE_NEGATIVE + def _render_row(self, diagnostics: list[Diagnostic]): locs = [] check_names = [] @@ -282,7 +300,7 @@ def _render_diff(self, diagnostics: list[Diagnostic], *, removed: bool = False): return "\n".join(f"{sign} {diagnostic}" for diagnostic in diagnostics) def display(self, format: Literal["diff", "github"]) -> str: - match self.classification: + match self.classify(Source.NEW): case Classification.TRUE_POSITIVE | Classification.FALSE_POSITIVE: assert self.new is not None return ( @@ -398,138 +416,55 @@ def collect_ty_diagnostics( ] -def group_diagnostics_by_key( - old: list[Diagnostic], new: list[Diagnostic], expected: list[Diagnostic] +def group_diagnostics_by_key_or_tag( + old: list[Diagnostic], + new: list[Diagnostic], + expected: list[Diagnostic], ) -> list[GroupedDiagnostics]: + # propagate tags from expected diagnostics to old and new diagnostics + tagged_lines = { + (d.location.path.name, d.location.positions.begin.line): d.tag + for d in expected + if d.tag is not None + } + + for diag in old: + diag.tag = tagged_lines.get( + (diag.location.path.name, diag.location.positions.begin.line), None + ) + + for diag in new: + diag.tag = tagged_lines.get( + (diag.location.path.name, diag.location.positions.begin.line), None + ) + diagnostics = [ *old, *new, *expected, ] - sorted_diagnostics = sorted(diagnostics, key=attrgetter("key")) - - grouped = [] - for key, group in groupby(sorted_diagnostics, key=attrgetter("key")): + # group diagnostics either by a path and a line or a path and a tag + diagnostics = sorted(diagnostics, key=attrgetter("key")) + grouped_diagnostics = [] + for key, group in groupby(diagnostics, key=attrgetter("key")): group = list(group) - sources: Source = reduce(or_, (diag.source for diag in group)) - grouped.append( - GroupedDiagnostics( - key=key, - sources=sources, - old=list(filter(lambda diag: diag.source == Source.OLD, group)), - new=list(filter(lambda diag: diag.source == Source.NEW, group)), - expected=list( - filter(lambda diag: diag.source == Source.EXPECTED, group) - ), - ) + old_group = list(filter(lambda diag: diag.source == Source.OLD, group)) + new_group = list(filter(lambda diag: diag.source == Source.NEW, group)) + expected_group = list( + filter(lambda diag: diag.source == Source.EXPECTED, group) ) - return grouped - - -def split_expected_by_tag( - expected: list[Diagnostic], -) -> tuple[list[Diagnostic], dict[str, list[Diagnostic]]]: - untagged: list[Diagnostic] = [] - tagged: dict[str, list[Diagnostic]] = {} - - for d in expected: - if d.tag is None: - untagged.append(d) - else: - tagged.setdefault(d.tag, []).append(d) - - return untagged, tagged - - -def index_observed_by_path( - diagnostics: list[Diagnostic], -) -> dict[Path, list[Diagnostic]]: - index: dict[Path, list[Diagnostic]] = {} - for d in diagnostics: - index.setdefault(d.location.path, []).append(d) - return index - - -def tagged_locations( - tagged_expected: dict[str, list[Diagnostic]], -) -> set[tuple[Path, int]]: - locations: set[tuple[Path, int]] = set() - for diags in tagged_expected.values(): - for d in diags: - locations.add((d.location.path, d.location.positions.begin.line)) - return locations - - -def filter_out_tagged( - diagnostics: list[Diagnostic], - *, - tagged_locs: set[tuple[Path, int]], -) -> list[Diagnostic]: - return [ - d - for d in diagnostics - if (d.location.path, d.location.positions.begin.line) not in tagged_locs - ] - - -def condense_tagged_groups( - *, - tagged_expected: dict[str, list[Diagnostic]], - old: list[Diagnostic], - new: list[Diagnostic], -) -> list[GroupedDiagnostics]: - old_by_path = index_observed_by_path(old) - new_by_path = index_observed_by_path(new) - - results: list[GroupedDiagnostics] = [] - - for tag, expected_diags in tagged_expected.items(): - exemplar = expected_diags[0] - path = exemplar.location.path - lines = {d.location.positions.begin.line for d in expected_diags} - - def count_hits(path: Path, lines: set[int], observed: list[Diagnostic]) -> int: - return sum( - 1 - for d in observed - if d.location.path == path and d.location.positions.begin.line in lines - ) - - old_hits = count_hits(path, lines, old_by_path.get(path, [])) - new_hits = count_hits(path, lines, new_by_path.get(path, [])) - - expected_max = len(expected_diags) if exemplar.multi else 1 - - sources = Source.EXPECTED - if 1 <= old_hits <= expected_max: - sources |= Source.OLD - if 1 <= new_hits <= expected_max: - sources |= Source.NEW - - old_diagnostics = [ - d - for d in old_by_path.get(path, []) - if d.location.positions.begin.line in lines - ] - new_diagnostics = [ - d - for d in new_by_path.get(path, []) - if d.location.positions.begin.line in lines - ] - - results.append( - GroupedDiagnostics( - key=f"{tag}", - sources=sources, - old=old_diagnostics, - new=new_diagnostics, - expected=expected_diags, - ) + grouped = GroupedDiagnostics( + key=key, + sources={d.source for d in group}, + old=old_group, + new=new_group, + expected=expected_group, ) + grouped_diagnostics.append(grouped) - return results + return grouped_diagnostics def compute_stats( @@ -537,23 +472,18 @@ def compute_stats( source: Source, ) -> Statistics: if source == source.EXPECTED: - # ty currently raises a false positive here due to incomplete enum.Flag support - # see https://github.com/astral-sh/ty/issues/876 - num_errors = sum( - 1 - for g in grouped_diagnostics - if source.EXPECTED in g.sources # ty:ignore[unsupported-operator] - ) + num_errors = sum(1 for g in grouped_diagnostics if source.EXPECTED in g.sources) return Statistics( true_positives=num_errors, false_positives=0, false_negatives=0 ) def increment(statistics: Statistics, grouped: GroupedDiagnostics) -> Statistics: - if (source in grouped.sources) and (Source.EXPECTED in grouped.sources): + classification = grouped.classify(source) + if classification == Classification.TRUE_POSITIVE: statistics.true_positives += 1 - elif source in grouped.sources: + elif classification == Classification.FALSE_POSITIVE: statistics.false_positives += 1 - elif Source.EXPECTED in grouped.sources: + elif classification == Classification.FALSE_NEGATIVE: statistics.false_negatives += 1 statistics.total_diagnostics += len(grouped.diagnostics_by_source(source)) @@ -826,9 +756,7 @@ def main(): test_groups = get_test_groups(tests_dir) test_files = get_test_cases(test_groups, tests_dir / "tests") - expected_all = collect_expected_diagnostics(test_files) - expected_untagged, expected_tagged = split_expected_by_tag(expected_all) - tagged_locs = tagged_locations(expected_tagged) + expected = collect_expected_diagnostics(test_files) old = collect_ty_diagnostics( ty_path=args.old_ty, @@ -844,30 +772,12 @@ def main(): python_version=args.python_version, ) - old_untagged = filter_out_tagged( - old, - tagged_locs=tagged_locs, - ) - - new_untagged = filter_out_tagged( - new, - tagged_locs=tagged_locs, - ) - - grouped_tagged = condense_tagged_groups( - tagged_expected=expected_tagged, + grouped = group_diagnostics_by_key_or_tag( old=old, new=new, + expected=expected, ) - grouped_untagged = group_diagnostics_by_key( - old=old_untagged, - new=new_untagged, - expected=expected_untagged, - ) - - grouped = [*grouped_untagged, *grouped_tagged] - rendered = "\n\n".join( [ render_summary(grouped, force_summary_table=args.force_summary_table), From 5933ec16b5858b806b84c3f29492652ebf4194a4 Mon Sep 17 00:00:00 2001 From: Will Duke Date: Wed, 21 Jan 2026 09:42:12 +0000 Subject: [PATCH 04/13] use distinct lines rather than the number of diagnostics --- scripts/conformance.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/scripts/conformance.py b/scripts/conformance.py index d9558bd45830b..6d2abb33f54f7 100644 --- a/scripts/conformance.py +++ b/scripts/conformance.py @@ -262,9 +262,15 @@ def diagnostics_by_source(self, source: Source) -> list[Diagnostic]: def classify(self, source: Source) -> Classification: if source in self.sources and Source.EXPECTED in self.sources: assert self.expected is not None - diagnostics = self.diagnostics_by_source(source) + distinct_lines = len( + { + diagnostic.location.positions.begin.line + for diagnostic in self.diagnostics_by_source(source) + } + ) expected_max = len(self.expected) if self.multi else 1 - if 1 <= len(diagnostics) <= expected_max: + + if 1 <= distinct_lines <= expected_max: return Classification.TRUE_POSITIVE else: return Classification.FALSE_POSITIVE From 8145bb302bbbb1b03d58cd87112850c754273d96 Mon Sep 17 00:00:00 2001 From: Will Duke Date: Wed, 21 Jan 2026 09:46:21 +0000 Subject: [PATCH 05/13] use new classification --- scripts/conformance.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/scripts/conformance.py b/scripts/conformance.py index 6d2abb33f54f7..0071ff5ee626d 100644 --- a/scripts/conformance.py +++ b/scripts/conformance.py @@ -306,7 +306,8 @@ def _render_diff(self, diagnostics: list[Diagnostic], *, removed: bool = False): return "\n".join(f"{sign} {diagnostic}" for diagnostic in diagnostics) def display(self, format: Literal["diff", "github"]) -> str: - match self.classify(Source.NEW): + classification = self.classify(Source.NEW) + match classification: case Classification.TRUE_POSITIVE | Classification.FALSE_POSITIVE: assert self.new is not None return ( @@ -330,7 +331,7 @@ def display(self, format: Literal["diff", "github"]) -> str: ) case _: - raise ValueError(f"Unexpected classification: {self.classification}") + raise ValueError(f"Unexpected classification: {classification}") @dataclass(kw_only=True, slots=True) @@ -512,7 +513,9 @@ def render_grouped_diagnostics( ] get_change = attrgetter("change") - get_classification = attrgetter("classification") + + def get_classification(diag) -> Classification: + return diag.classify(Source.NEW) optional_diagnostics = sorted( (diag for diag in grouped if diag.optional), From 7d8f9666f06ee32616d5a9a2e4e3f80c28ffa17b Mon Sep 17 00:00:00 2001 From: Will Duke <41601410+WillDuke@users.noreply.github.com> Date: Wed, 21 Jan 2026 17:19:53 +0000 Subject: [PATCH 06/13] format nested if-else expression Co-authored-by: Micha Reiser --- scripts/conformance.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/scripts/conformance.py b/scripts/conformance.py index 0071ff5ee626d..419d8132ed6bc 100644 --- a/scripts/conformance.py +++ b/scripts/conformance.py @@ -380,9 +380,11 @@ def collect_expected_diagnostics(test_files: Sequence[Path]) -> list[Diagnostic] ), source=Source.EXPECTED, optional=error.group("optional") is not None, - tag=f"{file.name}:{error.group('tag')}" - if error.group("tag") - else None, + tag=( + f"{file.name}:{error.group('tag')}" + if error.group("tag") + else None + ), multi=error.group("multi") is not None, ) ) From 36a59ead37fef696466a93ce5e195d08f7c9fb32 Mon Sep 17 00:00:00 2001 From: Will Duke Date: Wed, 21 Jan 2026 17:23:50 +0000 Subject: [PATCH 07/13] iterate through group once --- scripts/conformance.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/scripts/conformance.py b/scripts/conformance.py index 419d8132ed6bc..622ff2a599496 100644 --- a/scripts/conformance.py +++ b/scripts/conformance.py @@ -453,23 +453,31 @@ def group_diagnostics_by_key_or_tag( *expected, ] - # group diagnostics either by a path and a line or a path and a tag + # group diagnostics by a key which may be a path and a line or a path and a tag diagnostics = sorted(diagnostics, key=attrgetter("key")) grouped_diagnostics = [] for key, group in groupby(diagnostics, key=attrgetter("key")): - group = list(group) - old_group = list(filter(lambda diag: diag.source == Source.OLD, group)) - new_group = list(filter(lambda diag: diag.source == Source.NEW, group)) - expected_group = list( - filter(lambda diag: diag.source == Source.EXPECTED, group) - ) + old_diagnostics: list[Diagnostic] = [] + new_diagnostics: list[Diagnostic] = [] + expected_diagnostics: list[Diagnostic] = [] + sources: set[Source] = set() + + for diag in group: + sources.add(diag.source) + match diag.source: + case Source.OLD: + old_diagnostics.append(diag) + case Source.NEW: + new_diagnostics.append(diag) + case Source.EXPECTED: + expected_diagnostics.append(diag) grouped = GroupedDiagnostics( key=key, - sources={d.source for d in group}, - old=old_group, - new=new_group, - expected=expected_group, + sources=sources, + old=old_diagnostics, + new=new_diagnostics, + expected=expected_diagnostics, ) grouped_diagnostics.append(grouped) From 06b9da5db29ea9e0bc3cd51cee420af74cad2174 Mon Sep 17 00:00:00 2001 From: Will Duke Date: Wed, 21 Jan 2026 17:26:47 +0000 Subject: [PATCH 08/13] remove unused code in compute_stats --- scripts/conformance.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/scripts/conformance.py b/scripts/conformance.py index 622ff2a599496..8ccd8d08bf0a7 100644 --- a/scripts/conformance.py +++ b/scripts/conformance.py @@ -486,13 +486,9 @@ def group_diagnostics_by_key_or_tag( def compute_stats( grouped_diagnostics: list[GroupedDiagnostics], - source: Source, + ty_version: Literal["new", "old"], ) -> Statistics: - if source == source.EXPECTED: - num_errors = sum(1 for g in grouped_diagnostics if source.EXPECTED in g.sources) - return Statistics( - true_positives=num_errors, false_positives=0, false_negatives=0 - ) + source = Source.NEW if ty_version == "new" else Source.OLD def increment(statistics: Statistics, grouped: GroupedDiagnostics) -> Statistics: classification = grouped.classify(source) @@ -611,8 +607,8 @@ def format_metric(diff: float, old: float, new: float): return f"decreased from {old:.2%} to {new:.2%}" return f"held steady at {old:.2%}" - old = compute_stats(grouped_diagnostics, source=Source.OLD) - new = compute_stats(grouped_diagnostics, source=Source.NEW) + old = compute_stats(grouped_diagnostics, ty_version="old") + new = compute_stats(grouped_diagnostics, ty_version="new") assert new.true_positives > 0, ( "Expected ty to have at least one true positive.\n" From 5755034b390e4c37a7a227ed2768248fe3208952 Mon Sep 17 00:00:00 2001 From: Will Duke Date: Wed, 21 Jan 2026 17:35:11 +0000 Subject: [PATCH 09/13] remove None from diagnostic types --- scripts/conformance.py | 33 +++++++++++++++------------------ 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/scripts/conformance.py b/scripts/conformance.py index 8ccd8d08bf0a7..a9385dfa5ce47 100644 --- a/scripts/conformance.py +++ b/scripts/conformance.py @@ -223,9 +223,9 @@ def severity_for_display(self) -> str: class GroupedDiagnostics: key: str sources: AbstractSet[Source] - old: list[Diagnostic] | None - new: list[Diagnostic] | None - expected: list[Diagnostic] | None + old: list[Diagnostic] + new: list[Diagnostic] + expected: list[Diagnostic] @property def change(self) -> Change: @@ -251,11 +251,11 @@ def multi(self) -> bool: def diagnostics_by_source(self, source: Source) -> list[Diagnostic]: match source: case Source.NEW: - return self.new or [] + return self.new case Source.OLD: - return self.old or [] + return self.old case Source.EXPECTED: - return self.expected or [] + return self.expected case _: raise ValueError(f"Invalid source: {source}") @@ -317,12 +317,7 @@ def display(self, format: Literal["diff", "github"]) -> str: ) case Classification.FALSE_NEGATIVE | Classification.TRUE_NEGATIVE: - diagnostics = list( - filter( - lambda d: d is not None, - (*(self.old or []), *(self.expected or [])), - ) - ) + diagnostics = self.old if self.old else self.expected return ( self._render_diff(diagnostics, removed=True) @@ -381,10 +376,10 @@ def collect_expected_diagnostics(test_files: Sequence[Path]) -> list[Diagnostic] source=Source.EXPECTED, optional=error.group("optional") is not None, tag=( - f"{file.name}:{error.group('tag')}" - if error.group("tag") - else None - ), + f"{file.name}:{error.group('tag')}" + if error.group("tag") + else None + ), multi=error.group("multi") is not None, ) ) @@ -502,9 +497,11 @@ def increment(statistics: Statistics, grouped: GroupedDiagnostics) -> Statistics statistics.total_diagnostics += len(grouped.diagnostics_by_source(source)) return statistics - grouped_diagnostics = [diag for diag in grouped_diagnostics if not diag.optional] + non_optional_diagnostics = [ + diag for diag in grouped_diagnostics if not diag.optional + ] - return reduce(increment, grouped_diagnostics, Statistics()) + return reduce(increment, non_optional_diagnostics, Statistics()) def render_grouped_diagnostics( From 5fdaa6bc291db56f16dc900cf44f9083320a34b3 Mon Sep 17 00:00:00 2001 From: Will Duke Date: Wed, 21 Jan 2026 17:37:50 +0000 Subject: [PATCH 10/13] use a match in compute_stats --- scripts/conformance.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/scripts/conformance.py b/scripts/conformance.py index a9385dfa5ce47..8dc898d930fb4 100644 --- a/scripts/conformance.py +++ b/scripts/conformance.py @@ -486,13 +486,13 @@ def compute_stats( source = Source.NEW if ty_version == "new" else Source.OLD def increment(statistics: Statistics, grouped: GroupedDiagnostics) -> Statistics: - classification = grouped.classify(source) - if classification == Classification.TRUE_POSITIVE: - statistics.true_positives += 1 - elif classification == Classification.FALSE_POSITIVE: - statistics.false_positives += 1 - elif classification == Classification.FALSE_NEGATIVE: - statistics.false_negatives += 1 + match grouped.classify(source): + case Classification.TRUE_POSITIVE: + statistics.true_positives += 1 + case Classification.FALSE_POSITIVE: + statistics.false_positives += 1 + case Classification.FALSE_NEGATIVE: + statistics.false_negatives += 1 statistics.total_diagnostics += len(grouped.diagnostics_by_source(source)) return statistics From 1e9d33d739a7960ec41b0d435169c956306fa762 Mon Sep 17 00:00:00 2001 From: Will Duke Date: Wed, 21 Jan 2026 23:16:44 +0000 Subject: [PATCH 11/13] improve readability of classify --- scripts/conformance.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/scripts/conformance.py b/scripts/conformance.py index 8dc898d930fb4..6ef799bfc8438 100644 --- a/scripts/conformance.py +++ b/scripts/conformance.py @@ -260,22 +260,25 @@ def diagnostics_by_source(self, source: Source) -> list[Diagnostic]: raise ValueError(f"Invalid source: {source}") def classify(self, source: Source) -> Classification: - if source in self.sources and Source.EXPECTED in self.sources: - assert self.expected is not None - distinct_lines = len( - { - diagnostic.location.positions.begin.line - for diagnostic in self.diagnostics_by_source(source) - } - ) - expected_max = len(self.expected) if self.multi else 1 + if source in self.sources: + if Source.EXPECTED in self.sources: + assert self.expected is not None + distinct_lines = len( + { + diagnostic.location.positions.begin.line + for diagnostic in self.diagnostics_by_source(source) + } + ) + expected_max = len(self.expected) if self.multi else 1 - if 1 <= distinct_lines <= expected_max: - return Classification.TRUE_POSITIVE + if 1 <= distinct_lines <= expected_max: + return Classification.TRUE_POSITIVE + else: + return Classification.FALSE_POSITIVE else: return Classification.FALSE_POSITIVE - elif source in self.sources and Source.EXPECTED not in self.sources: + elif source in self.sources: return Classification.FALSE_POSITIVE elif Source.EXPECTED in self.sources: From b158965a0d29afdb5a02cc8b9dd3c980daf1945c Mon Sep 17 00:00:00 2001 From: Will Duke Date: Thu, 22 Jan 2026 21:36:33 +0000 Subject: [PATCH 12/13] count diagnostics properly --- scripts/conformance.py | 105 ++++++++++++++++++++++++++--------------- 1 file changed, 68 insertions(+), 37 deletions(-) diff --git a/scripts/conformance.py b/scripts/conformance.py index 6ef799bfc8438..25ed4ae2b04c2 100644 --- a/scripts/conformance.py +++ b/scripts/conformance.py @@ -55,8 +55,6 @@ # on a set of lines with a matching tag # 4. Tagged multi-errors (E[tag+]): The type checker should raise one or # more errors on a set of lines with a matching tag -# This regex pattern parses the error lines in the conformance tests, -# but the following implementation currently ignores error tags. CONFORMANCE_ERROR_PATTERN = re.compile( r""" \#\s*E # "# E" begins each error @@ -105,6 +103,15 @@ def into_title(self) -> str: return "True positives removed" +@dataclass(kw_only=True, slots=True) +class Evaluation: + classification: Classification + true_positives: int = 0 + false_positives: int = 0 + true_negatives: int = 0 + false_negatives: int = 0 + + class Change(StrEnum): ADDED = auto() REMOVED = auto() @@ -157,7 +164,7 @@ class Diagnostic: optional: bool # tag identifying an error that can occur on multiple lines tag: str | None - # True if the error can occur on multiple lines or only once per tag + # True if one or more errors can occur on lines with the same tag multi: bool def __post_init__(self, *args, **kwargs) -> None: @@ -204,7 +211,7 @@ def from_gitlab_output( @property def key(self) -> str: - """Key to group diagnostics by path and beginning line.""" + """Key to group diagnostics by path and beginning line or path and tag.""" return ( f"{self.location.path.as_posix()}:{self.location.positions.begin.line}" if self.tag is None @@ -256,36 +263,72 @@ def diagnostics_by_source(self, source: Source) -> list[Diagnostic]: return self.old case Source.EXPECTED: return self.expected - case _: - raise ValueError(f"Invalid source: {source}") - def classify(self, source: Source) -> Classification: + def classify(self, source: Source) -> Evaluation: + diagnostics = self.diagnostics_by_source(source) if source in self.sources: if Source.EXPECTED in self.sources: - assert self.expected is not None distinct_lines = len( { diagnostic.location.positions.begin.line - for diagnostic in self.diagnostics_by_source(source) + for diagnostic in diagnostics } ) expected_max = len(self.expected) if self.multi else 1 if 1 <= distinct_lines <= expected_max: - return Classification.TRUE_POSITIVE + return Evaluation( + classification=Classification.TRUE_POSITIVE, + true_positives=len(diagnostics), + false_positives=0, + true_negatives=0, + false_negatives=0, + ) else: - return Classification.FALSE_POSITIVE + # We select the line with the most diagnostics + # as our true positive, while the rest are false positives + max_line = max( + groupby( + diagnostics, key=lambda d: d.location.positions.begin.line + ), + key=lambda x: len(x[1]), + ) + remaining = len(diagnostics) - max_line + # We can never exceed the number of distinct lines + # if the diagnostic is multi, so we ignore that case + return Evaluation( + classification=Classification.FALSE_POSITIVE, + true_positives=max_line, + false_positives=remaining, + true_negatives=0, + false_negatives=0, + ) else: - return Classification.FALSE_POSITIVE - - elif source in self.sources: - return Classification.FALSE_POSITIVE + return Evaluation( + classification=Classification.FALSE_POSITIVE, + true_positives=0, + false_positives=len(diagnostics), + true_negatives=0, + false_negatives=0, + ) elif Source.EXPECTED in self.sources: - return Classification.FALSE_NEGATIVE + return Evaluation( + classification=Classification.FALSE_NEGATIVE, + true_positives=0, + false_positives=0, + true_negatives=0, + false_negatives=1, + ) else: - return Classification.TRUE_NEGATIVE + return Evaluation( + classification=Classification.TRUE_NEGATIVE, + true_positives=0, + false_positives=0, + true_negatives=1, + false_negatives=0, + ) def _render_row(self, diagnostics: list[Diagnostic]): locs = [] @@ -309,8 +352,8 @@ def _render_diff(self, diagnostics: list[Diagnostic], *, removed: bool = False): return "\n".join(f"{sign} {diagnostic}" for diagnostic in diagnostics) def display(self, format: Literal["diff", "github"]) -> str: - classification = self.classify(Source.NEW) - match classification: + eval = self.classify(Source.NEW) + match eval.classification: case Classification.TRUE_POSITIVE | Classification.FALSE_POSITIVE: assert self.new is not None return ( @@ -328,9 +371,6 @@ def display(self, format: Literal["diff", "github"]) -> str: else self._render_row(diagnostics) ) - case _: - raise ValueError(f"Unexpected classification: {classification}") - @dataclass(kw_only=True, slots=True) class Statistics: @@ -435,12 +475,7 @@ def group_diagnostics_by_key_or_tag( if d.tag is not None } - for diag in old: - diag.tag = tagged_lines.get( - (diag.location.path.name, diag.location.positions.begin.line), None - ) - - for diag in new: + for diag in chain(old, new): diag.tag = tagged_lines.get( (diag.location.path.name, diag.location.positions.begin.line), None ) @@ -489,14 +524,10 @@ def compute_stats( source = Source.NEW if ty_version == "new" else Source.OLD def increment(statistics: Statistics, grouped: GroupedDiagnostics) -> Statistics: - match grouped.classify(source): - case Classification.TRUE_POSITIVE: - statistics.true_positives += 1 - case Classification.FALSE_POSITIVE: - statistics.false_positives += 1 - case Classification.FALSE_NEGATIVE: - statistics.false_negatives += 1 - + eval = grouped.classify(source) + statistics.true_positives += eval.true_positives + statistics.false_positives += eval.false_positives + statistics.false_negatives += eval.false_negatives statistics.total_diagnostics += len(grouped.diagnostics_by_source(source)) return statistics @@ -521,7 +552,7 @@ def render_grouped_diagnostics( get_change = attrgetter("change") def get_classification(diag) -> Classification: - return diag.classify(Source.NEW) + return diag.classify(Source.NEW).classification optional_diagnostics = sorted( (diag for diag in grouped if diag.optional), From 6e88957d56b91befbc384f7e235813d300d512ff Mon Sep 17 00:00:00 2001 From: Will Duke Date: Thu, 22 Jan 2026 21:50:09 +0000 Subject: [PATCH 13/13] count optional diagnostics --- scripts/conformance.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/scripts/conformance.py b/scripts/conformance.py index 25ed4ae2b04c2..81da51fe68887 100644 --- a/scripts/conformance.py +++ b/scripts/conformance.py @@ -266,7 +266,17 @@ def diagnostics_by_source(self, source: Source) -> list[Diagnostic]: def classify(self, source: Source) -> Evaluation: diagnostics = self.diagnostics_by_source(source) + if source in self.sources: + if self.optional: + return Evaluation( + classification=Classification.TRUE_POSITIVE, + true_positives=len(diagnostics), + false_positives=0, + true_negatives=0, + false_negatives=0, + ) + if Source.EXPECTED in self.sources: distinct_lines = len( { @@ -313,6 +323,14 @@ def classify(self, source: Source) -> Evaluation: ) elif Source.EXPECTED in self.sources: + if self.optional: + return Evaluation( + classification=Classification.TRUE_NEGATIVE, + true_positives=0, + false_positives=0, + true_negatives=len(diagnostics), + false_negatives=0, + ) return Evaluation( classification=Classification.FALSE_NEGATIVE, true_positives=0, @@ -463,7 +481,7 @@ def collect_ty_diagnostics( ] -def group_diagnostics_by_key_or_tag( +def group_diagnostics_by_key( old: list[Diagnostic], new: list[Diagnostic], expected: list[Diagnostic], @@ -486,7 +504,6 @@ def group_diagnostics_by_key_or_tag( *expected, ] - # group diagnostics by a key which may be a path and a line or a path and a tag diagnostics = sorted(diagnostics, key=attrgetter("key")) grouped_diagnostics = [] for key, group in groupby(diagnostics, key=attrgetter("key")): @@ -531,11 +548,7 @@ def increment(statistics: Statistics, grouped: GroupedDiagnostics) -> Statistics statistics.total_diagnostics += len(grouped.diagnostics_by_source(source)) return statistics - non_optional_diagnostics = [ - diag for diag in grouped_diagnostics if not diag.optional - ] - - return reduce(increment, non_optional_diagnostics, Statistics()) + return reduce(increment, grouped_diagnostics, Statistics()) def render_grouped_diagnostics( @@ -818,7 +831,7 @@ def main(): python_version=args.python_version, ) - grouped = group_diagnostics_by_key_or_tag( + grouped = group_diagnostics_by_key( old=old, new=new, expected=expected,