diff --git a/services/comparison/__init__.py b/services/comparison/__init__.py index d8d4a7af9..bf43f0729 100644 --- a/services/comparison/__init__.py +++ b/services/comparison/__init__.py @@ -1,13 +1,14 @@ import asyncio import logging from dataclasses import dataclass -from typing import List, Optional +from typing import Dict, List, Optional from shared.reports.changes import get_changes_using_rust, run_comparison_using_rust from shared.reports.types import Change from shared.torngit.exceptions import ( TorngitClientGeneralError, ) +from shared.utils.sessions import SessionType from database.enums import CompareCommitState, TestResultsProcessingError from database.models import CompareCommit @@ -15,7 +16,7 @@ from services.archive import ArchiveService from services.comparison.changes import get_changes from services.comparison.overlays import get_overlay -from services.comparison.types import Comparison, FullCommit +from services.comparison.types import Comparison, FullCommit, ReportUploadedCount from services.repository import get_repo_provider_service log = logging.getLogger(__name__) @@ -46,6 +47,7 @@ class ComparisonProxy(object): Attributes: comparison (Comparison): The original comparison we want to wrap and proxy + context (ComparisonContext | None): Other information not coverage-related that may affect notifications """ def __init__( @@ -66,6 +68,7 @@ def __init__( self._archive_service = None self._overlays = {} self.context = context or ComparisonContext() + self._cached_reports_uploaded_per_flag: List[ReportUploadedCount] | None = None def get_archive_service(self): if self._archive_service is None: @@ -280,6 +283,63 @@ async def get_impacted_files(self): files_in_diff, ) + def get_reports_uploaded_count_per_flag(self) -> List[ReportUploadedCount]: + """This function counts how many reports (by flag) the BASE and HEAD commit have.""" + if self._cached_reports_uploaded_per_flag: + # Reports may have many sessions, so it's useful to memoize this function + return self._cached_reports_uploaded_per_flag + if not self.has_head_report() or not self.has_project_coverage_base_report(): + log.warning( + "Can't calculate diff in uploads. Missing some report", + extra=dict( + has_head_report=self.has_head_report(), + has_project_base_report=self.has_project_coverage_base_report(), + ), + ) + return [] + per_flag_dict: Dict[str, ReportUploadedCount] = dict() + base_report = self.comparison.project_coverage_base.report + head_report = self.comparison.head.report + ops = [(base_report, "base_count"), (head_report, "head_count")] + for curr_report, curr_counter in ops: + for session in curr_report.sessions: + # We ignore carryforward sessions + # Because not all commits would upload all flags (potentially) + # But they are still carried forward + if session.session_type != SessionType.carriedforward: + if session.flags == []: + session.flags = [""] + for flag in session.flags: + dict_value = per_flag_dict.get(flag) + if dict_value is None: + dict_value = ReportUploadedCount( + flag=flag, base_count=0, head_count=0 + ) + dict_value[curr_counter] += 1 + per_flag_dict[flag] = dict_value + self._cached_reports_uploaded_per_flag = list(per_flag_dict.values()) + return self._cached_reports_uploaded_per_flag + + def get_reports_uploaded_count_per_flag_diff(self) -> List[ReportUploadedCount]: + """ + Returns the difference, per flag, or reports uploaded in BASE and HEAD + + ❗️ For a difference to be considered there must be at least 1 "uploaded" upload in both + BASE and HEAD (that is, if all reports for a flag are "carryforward" it's not considered a diff) + """ + reports_per_flag = self.get_reports_uploaded_count_per_flag() + + def is_valid_diff(obj: ReportUploadedCount): + return ( + obj["base_count"] > 0 + and obj["head_count"] > 0 + and obj["base_count"] != obj["head_count"] + ) + + per_flag_diff = list(filter(is_valid_diff, reports_per_flag)) + self._cached_reports_uploaded_per_flag = per_flag_diff + return per_flag_diff + class FilteredComparison(object): def __init__(self, real_comparison: ComparisonProxy, *, flags, path_patterns): diff --git a/services/comparison/tests/unit/test_reports_uploaded_count_diff.py b/services/comparison/tests/unit/test_reports_uploaded_count_diff.py new file mode 100644 index 000000000..5e31b86df --- /dev/null +++ b/services/comparison/tests/unit/test_reports_uploaded_count_diff.py @@ -0,0 +1,117 @@ +from unittest.mock import MagicMock + +import pytest +from shared.reports.resources import Report +from shared.utils.sessions import Session, SessionType + +from services.comparison import ComparisonProxy +from services.comparison.types import Comparison, FullCommit, ReportUploadedCount + + +@pytest.mark.parametrize( + "head_sessions, base_sessions, expected_count, expected_diff", + [ + ( + [ + Session( + flags=["unit", "local"], session_type=SessionType.carriedforward + ), + Session(flags=["integration"], session_type=SessionType.uploaded), + Session(flags=["unit"], session_type=SessionType.uploaded), + Session(flags=["unit"], session_type=SessionType.uploaded), + Session(flags=["integration"], session_type=SessionType.uploaded), + Session(flags=[], session_type=SessionType.uploaded), + ], + [ + Session( + flags=["unit", "local"], session_type=SessionType.carriedforward + ), + Session(flags=["integration"], session_type=SessionType.carriedforward), + Session(flags=["unit"], session_type=SessionType.uploaded), + Session(flags=["unit"], session_type=SessionType.uploaded), + ], + [ + ReportUploadedCount(flag="unit", base_count=2, head_count=2), + ReportUploadedCount(flag="integration", base_count=0, head_count=2), + ReportUploadedCount(flag="", base_count=0, head_count=1), + ], + [], + ), + ( + [ + Session( + flags=["unit", "local"], session_type=SessionType.carriedforward + ), + Session(flags=["integration"], session_type=SessionType.uploaded), + Session(flags=["unit"], session_type=SessionType.uploaded), + Session(flags=["unit"], session_type=SessionType.uploaded), + Session(flags=["integration"], session_type=SessionType.uploaded), + Session(flags=[""], session_type=SessionType.uploaded), + ], + [ + Session(flags=["unit", "local"], session_type=SessionType.uploaded), + Session(flags=["integration"], session_type=SessionType.uploaded), + Session(flags=["unit"], session_type=SessionType.uploaded), + Session(flags=["unit"], session_type=SessionType.uploaded), + Session(flags=["obscure_flag"], session_type=SessionType.uploaded), + ], + [ + ReportUploadedCount(flag="unit", base_count=3, head_count=2), + ReportUploadedCount(flag="local", base_count=1, head_count=0), + ReportUploadedCount(flag="integration", base_count=1, head_count=2), + ReportUploadedCount(flag="obscure_flag", base_count=1, head_count=0), + ReportUploadedCount(flag="", base_count=0, head_count=1), + ], + [ + ReportUploadedCount(flag="unit", base_count=3, head_count=2), + ReportUploadedCount(flag="integration", base_count=1, head_count=2), + ], + ), + ], + ids=["flag_counts_no_diff", "flag_count_yes_diff"], +) +def test_get_reports_uploaded_count_per_flag( + head_sessions, base_sessions, expected_count, expected_diff +): + head_report = Report() + head_report.sessions = head_sessions + base_report = Report() + base_report.sessions = base_sessions + comparison_proxy = ComparisonProxy( + comparison=Comparison( + head=FullCommit(report=head_report, commit=None), + project_coverage_base=FullCommit(report=base_report, commit=None), + patch_coverage_base_commitid=None, + enriched_pull=None, + ) + ) + # Python Dicts preserve order, so we can actually test this equality + # See more https://stackoverflow.com/a/39537308 + assert comparison_proxy.get_reports_uploaded_count_per_flag() == expected_count + assert comparison_proxy.get_reports_uploaded_count_per_flag_diff() == expected_diff + + +def test_get_reports_uploaded_count_per_flag_cached(): + comparison_proxy = ComparisonProxy(comparison=MagicMock(name="fake_comparison")) + comparison_proxy._cached_reports_uploaded_per_flag = ( + "object_that_doesnt_have_this_shape" + ) + assert ( + comparison_proxy.get_reports_uploaded_count_per_flag() + == "object_that_doesnt_have_this_shape" + ) + + +def test_get_reports_uploaded_count_per_flag_diff_missing_report(): + head_report = None + base_report = Report() + base_report.sessions = None + comparison_proxy = ComparisonProxy( + comparison=Comparison( + head=FullCommit(report=head_report, commit=None), + project_coverage_base=FullCommit(report=base_report, commit=None), + patch_coverage_base_commitid=None, + enriched_pull=None, + ) + ) + assert comparison_proxy.get_reports_uploaded_count_per_flag_diff() == [] diff --git a/services/comparison/types.py b/services/comparison/types.py index 4e71d7676..603dab364 100644 --- a/services/comparison/types.py +++ b/services/comparison/types.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Optional +from typing import Optional, TypedDict from shared.reports.resources import Report from shared.yaml import UserYaml @@ -14,6 +14,12 @@ class FullCommit(object): report: Report +class ReportUploadedCount(TypedDict): + flag: str = "" + base_count: int = 0 + head_count: int = 0 + + @dataclass class Comparison(object): head: FullCommit