Skip to content

Commit

Permalink
fix: use original PR base to compute patch coverage (#199)
Browse files Browse the repository at this point in the history
* fix: use original PR base to compute patch coverage

* gate to team plan + nullcheck/add comments

* fix clumsy test errors
  • Loading branch information
matt-codecov committed Dec 5, 2023
1 parent 82832af commit a404e8e
Show file tree
Hide file tree
Showing 17 changed files with 272 additions and 28 deletions.
56 changes: 43 additions & 13 deletions services/comparison/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class ComparisonProxy(object):
def __init__(self, comparison: Comparison):
self.comparison = comparison
self._repository_service = None
self._diff = None
self._adjusted_base_diff = None
self._original_base_diff = None
self._changes = None
self._existing_statuses = None
self._behind_by = None
Expand Down Expand Up @@ -100,18 +101,47 @@ def enriched_pull(self):
def pull(self):
return self.comparison.pull

async def get_diff(self):
async def get_diff(self, use_original_base=False):
async with self._diff_lock:
if self._diff is None:
head = self.comparison.head.commit
base = self.comparison.base.commit
if base is None:
head = self.comparison.head.commit
base = self.comparison.base.commit
original_base_commitid = self.comparison.original_base_commitid

# If the original and adjusted bases are the same commit, then if we
# already fetched the diff for one we can return it for the other.
bases_match = original_base_commitid == (base.commitid if base else "")

populate_original_base_diff = use_original_base and (
not self._original_base_diff
)
populate_adjusted_base_diff = (not use_original_base) and (
not self._adjusted_base_diff
)
if populate_original_base_diff:
if bases_match and self._adjusted_base_diff:
self._original_base_diff = self._adjusted_base_diff
elif original_base_commitid is not None:
pull_diff = await self.repository_service.get_compare(
original_base_commitid, head.commitid, with_commits=False
)
self._original_base_diff = pull_diff["diff"]
else:
return None
pull_diff = await self.repository_service.get_compare(
base.commitid, head.commitid, with_commits=False
)
self._diff = pull_diff["diff"]
return self._diff
elif populate_adjusted_base_diff:
if bases_match and self._original_base_diff:
self._adjusted_base_diff = self._original_base_diff
elif base is not None:
pull_diff = await self.repository_service.get_compare(
base.commitid, head.commitid, with_commits=False
)
self._adjusted_base_diff = pull_diff["diff"]
else:
return None

if use_original_base:
return self._original_base_diff
else:
return self._adjusted_base_diff

async def get_changes(self) -> Optional[List[Change]]:
# Just make sure to not cause a deadlock between this and get_diff
Expand Down Expand Up @@ -247,8 +277,8 @@ def __init__(self, real_comparison: ComparisonProxy, *, flags, path_patterns):
async def get_impacted_files(self):
return await self.real_comparison.get_impacted_files()

async def get_diff(self):
return await self.real_comparison.get_diff()
async def get_diff(self, use_original_base=False):
return await self.real_comparison.get_diff(use_original_base=use_original_base)

async def get_existing_statuses(self):
return await self.real_comparison.get_existing_statuses()
Expand Down
1 change: 1 addition & 0 deletions services/comparison/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def sample_comparison(dbsession, request, sample_report):
Comparison(
head=head_full_commit,
base=base_full_commit,
original_base_commitid=base_commit.commitid,
enriched_pull=EnrichedPull(
database_pull=pull,
provider_pull={
Expand Down
148 changes: 148 additions & 0 deletions services/comparison/tests/unit/test_comparison_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import pytest
from mock import call, patch

from database.tests.factories import CommitFactory, PullFactory, RepositoryFactory
from services.comparison import ComparisonProxy
from services.comparison.types import Comparison, FullCommit
from services.repository import EnrichedPull


def make_sample_comparison(adjusted_base=False):
repo = RepositoryFactory.create(owner__service="github")

head_commit = CommitFactory.create(repository=repo)
adjusted_base_commit = CommitFactory.create(repository=repo)

if adjusted_base:
# Just getting a random commitid, doesn't need to be in the db
original_base_commitid = CommitFactory.create(repository=repo).commitid
else:
original_base_commitid = adjusted_base_commit.commitid

pull = PullFactory.create(
repository=repo,
head=head_commit.commitid,
base=original_base_commitid,
compared_to=adjusted_base_commit.commitid,
)

base_full_commit = FullCommit(commit=adjusted_base_commit, report=None)
head_full_commit = FullCommit(commit=head_commit, report=None)
return ComparisonProxy(
Comparison(
head=head_full_commit,
base=base_full_commit,
original_base_commitid=original_base_commitid,
enriched_pull=EnrichedPull(
database_pull=pull,
provider_pull={},
),
),
)


class TestComparisonProxy(object):

compare_url = "https://api.github.com/repos/{}/compare/{}...{}"

@pytest.mark.asyncio
@patch("shared.torngit.github.Github.get_compare")
async def test_get_diff_adjusted_base(self, mock_get_compare):
comparison = make_sample_comparison(adjusted_base=True)
mock_get_compare.return_value = {"diff": "magic string"}
result = await comparison.get_diff(use_original_base=False)

assert result == "magic string"
assert comparison._adjusted_base_diff == "magic string"
assert not comparison._original_base_diff
assert (
comparison.comparison.original_base_commitid
!= comparison.base.commit.commitid
)

assert mock_get_compare.call_args_list == [
call(
comparison.base.commit.commitid,
comparison.head.commit.commitid,
with_commits=False,
),
]

@pytest.mark.asyncio
@patch("shared.torngit.github.Github.get_compare")
async def test_get_diff_original_base(self, mock_get_compare):
comparison = make_sample_comparison(adjusted_base=True)
mock_get_compare.return_value = {"diff": "magic string"}
result = await comparison.get_diff(use_original_base=True)

assert result == "magic string"
assert comparison._original_base_diff == "magic string"
assert not comparison._adjusted_base_diff
assert (
comparison.comparison.original_base_commitid
!= comparison.base.commit.commitid
)

assert mock_get_compare.call_args_list == [
call(
comparison.comparison.original_base_commitid,
comparison.head.commit.commitid,
with_commits=False,
),
]

@pytest.mark.asyncio
@patch("shared.torngit.github.Github.get_compare")
async def test_get_diff_bases_match_original_base(self, mock_get_compare):
comparison = make_sample_comparison(adjusted_base=False)
mock_get_compare.return_value = {"diff": "magic string"}
result = await comparison.get_diff(use_original_base=True)

assert result == "magic string"
assert comparison._original_base_diff == "magic string"
assert (
comparison.comparison.original_base_commitid
== comparison.base.commit.commitid
)

# In this test case, the adjusted and original base commits are the
# same. If we get one, we should set the cache for the other.
adjusted_base_result = await comparison.get_diff(use_original_base=False)
assert comparison._adjusted_base_diff == "magic string"

# Make sure we only called the Git provider API once
assert mock_get_compare.call_args_list == [
call(
comparison.comparison.original_base_commitid,
comparison.head.commit.commitid,
with_commits=False,
),
]

@pytest.mark.asyncio
@patch("shared.torngit.github.Github.get_compare")
async def test_get_diff_bases_match_adjusted_base(self, mock_get_compare):
comparison = make_sample_comparison(adjusted_base=False)
mock_get_compare.return_value = {"diff": "magic string"}
result = await comparison.get_diff(use_original_base=False)

assert result == "magic string"
assert comparison._adjusted_base_diff == "magic string"
assert (
comparison.comparison.original_base_commitid
== comparison.base.commit.commitid
)

# In this test case, the adjusted and original base commits are the
# same. If we get one, we should set the cache for the other.
adjusted_base_result = await comparison.get_diff(use_original_base=True)
assert comparison._adjusted_base_diff == "magic string"

# Make sure we only called the Git provider API once
assert mock_get_compare.call_args_list == [
call(
comparison.comparison.original_base_commitid,
comparison.head.commit.commitid,
with_commits=False,
),
]
12 changes: 12 additions & 0 deletions services/comparison/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,19 @@ class FullCommit(object):
@dataclass
class Comparison(object):
head: FullCommit

# To see how a patch changes project coverage, we compare the branch head's
# report against the base's report, or if the base isn't in our database,
# the next-oldest commit that is. Be aware that this base commit may not be
# the true base that, for example, a PR is based on.
base: FullCommit

# Computing patch coverage doesn't require an old report to compare against,
# so doing the "next-oldest" adjustment described above is unnecessary and
# makes the results less correct. All it requires is a head report and the
# patch diff, and the original base's commit SHA is enough to get that.
original_base_commitid: str

enriched_pull: EnrichedPull
current_yaml: Optional[UserYaml] = None

Expand Down
3 changes: 0 additions & 3 deletions services/notification/notifiers/checks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,6 @@ async def notify(self, comparison: Comparison):
data_sent=payload,
)

async def get_diff(self, comparison: Comparison):
return await comparison.get_diff()

def get_line_diff(self, file_diff):
"""
This method traverses a git file diff and returns the lines (as line numbers) that where chnaged
Expand Down
2 changes: 1 addition & 1 deletion services/notification/notifiers/checks/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ async def build_payload(self, comparison: Comparison):
"summary": "\n\n".join([codecov_link, message]),
},
}
diff = await self.get_diff(comparison)
diff = await comparison.get_diff(use_original_base=True)
# TODO: Look into why the apply diff in get_patch_status is not saving state at this point
comparison.head.report.apply_diff(diff)
annotations = self.create_annotations(comparison, diff)
Expand Down
2 changes: 1 addition & 1 deletion services/notification/notifiers/mixins/message/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ async def create_message(
"""
changes = await comparison.get_changes()
diff = await comparison.get_diff()
diff = await comparison.get_diff(use_original_base=True)
behind_by = await comparison.get_behind_by()
base_report = comparison.base.report
head_report = comparison.head.report
Expand Down
4 changes: 2 additions & 2 deletions services/notification/notifiers/mixins/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class StatusPatchMixin(object):
async def get_patch_status(self, comparison) -> Tuple[str, str]:
threshold = Decimal(self.notifier_yaml_settings.get("threshold") or "0.0")
diff = await self.get_diff(comparison)
diff = await comparison.get_diff(use_original_base=True)
totals = comparison.head.report.apply_diff(diff)
if self.notifier_yaml_settings.get("target") not in ("auto", None):
target_coverage = Decimal(
Expand Down Expand Up @@ -233,7 +233,7 @@ async def _apply_fully_covered_patch_behavior(
extra=dict(commit=comparison.head.commit.commitid),
)
return None
diff = await self.get_diff(comparison)
diff = await comparison.get_diff(use_original_base=True)
patch_totals = comparison.head.report.apply_diff(diff)
if patch_totals is None or patch_totals.lines == 0:
# Coverage was not changed by patch
Expand Down
3 changes: 0 additions & 3 deletions services/notification/notifiers/status/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,6 @@ def flag_coverage_was_uploaded(self, comparison) -> bool:
len(report_uploaded_flags.intersection(flags_included_in_status_check)) > 0
)

async def get_diff(self, comparison: Comparison):
return await comparison.get_diff()

@property
def repository_service(self):
if not self._repository_service:
Expand Down
11 changes: 11 additions & 0 deletions services/notification/notifiers/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ def _comparison(service="github", username="codecov-test"):
Comparison(
head=head_full_commit,
base=base_full_commit,
original_base_commitid=base_commit.commitid,
enriched_pull=EnrichedPull(database_pull=pull, provider_pull={}),
)
)
Expand Down Expand Up @@ -411,6 +412,7 @@ def sample_comparison(dbsession, request, sample_report, mocker):
Comparison(
head=head_full_commit,
base=base_full_commit,
original_base_commitid=base_commit.commitid,
enriched_pull=EnrichedPull(
database_pull=pull,
provider_pull={
Expand Down Expand Up @@ -468,6 +470,7 @@ async def sample_comparison_coverage_carriedforward(
Comparison(
head=head_full_commit,
base=base_full_commit,
original_base_commitid=base_commit.commitid,
enriched_pull=EnrichedPull(
database_pull=pull,
provider_pull={
Expand Down Expand Up @@ -521,6 +524,7 @@ def sample_comparison_negative_change(dbsession, request, sample_report, mocker)
Comparison(
head=head_full_commit,
base=base_full_commit,
original_base_commitid=base_commit.commitid,
enriched_pull=EnrichedPull(
database_pull=pull,
provider_pull={
Expand Down Expand Up @@ -574,6 +578,7 @@ def sample_comparison_no_change(dbsession, request, sample_report, mocker):
Comparison(
head=head_full_commit,
base=base_full_commit,
original_base_commitid=base_commit.commitid,
enriched_pull=EnrichedPull(
database_pull=pull,
provider_pull={
Expand Down Expand Up @@ -625,6 +630,7 @@ def sample_comparison_without_pull(dbsession, request, sample_report, mocker):
Comparison(
head=head_full_commit,
base=base_full_commit,
original_base_commitid=base_commit.commitid,
enriched_pull=EnrichedPull(database_pull=None, provider_pull=None),
)
)
Expand Down Expand Up @@ -665,6 +671,7 @@ def sample_comparison_database_pull_without_provider(
Comparison(
head=head_full_commit,
base=base_full_commit,
original_base_commitid=base_commit.commitid,
enriched_pull=EnrichedPull(database_pull=pull, provider_pull=None),
)
)
Expand Down Expand Up @@ -697,6 +704,7 @@ def generate_sample_comparison(username, dbsession, base_report, head_report):
Comparison(
head=head_full_commit,
base=base_full_commit,
original_base_commitid=base_commit.commitid,
enriched_pull=EnrichedPull(
database_pull=pull,
provider_pull={
Expand Down Expand Up @@ -749,6 +757,7 @@ def sample_comparison_without_base_report(dbsession, request, sample_report, moc
Comparison(
head=head_full_commit,
base=base_full_commit,
original_base_commitid=base_commit.commitid,
enriched_pull=EnrichedPull(
database_pull=pull,
provider_pull={
Expand Down Expand Up @@ -806,6 +815,7 @@ def sample_comparison_without_base_with_pull(dbsession, request, sample_report,
Comparison(
head=head_full_commit,
base=base_full_commit,
original_base_commitid="cdf9aa4bd2c6bcd8a662864097cb62a85a2fd55b",
enriched_pull=EnrichedPull(
database_pull=pull,
provider_pull={
Expand Down Expand Up @@ -872,6 +882,7 @@ def sample_comparison_head_and_pull_head_differ(
Comparison(
head=head_full_commit,
base=base_full_commit,
original_base_commitid=base_commit.commitid,
enriched_pull=EnrichedPull(
database_pull=pull,
provider_pull={
Expand Down
Loading

0 comments on commit a404e8e

Please sign in to comment.