Skip to content

Commit

Permalink
dev: Optimizations / Dead code removal for comparison class (#725)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajay-sentry authored Aug 2, 2024
1 parent 47e6cee commit e245d32
Show file tree
Hide file tree
Showing 13 changed files with 35 additions and 65 deletions.
1 change: 0 additions & 1 deletion api/internal/tests/unit/views/test_compare_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def build_commits(client):
return repo, commit_base, commit_head


@patch("services.comparison.Comparison.has_unmerged_base_commits", lambda self: False)
@patch("services.archive.ArchiveService.read_chunks", lambda obj, sha: "")
@patch(
"api.shared.repo.repository_accessors.RepoAccessors.get_repo_permissions",
Expand Down
6 changes: 2 additions & 4 deletions api/internal/tests/views/test_compare_viewset.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,12 @@ async def get_authenticated(self):
return False, False


@patch("services.comparison.Comparison.has_unmerged_base_commits", lambda self: True)
@patch("services.comparison.Comparison.head_report", new_callable=PropertyMock)
@patch("services.comparison.Comparison.base_report", new_callable=PropertyMock)
@patch("services.repo_providers.RepoProviderService.get_adapter")
class TestCompareViewSetRetrieve(APITestCase):
"""
Tests for retrieving a comparison. Does not test data that will be depracated,
Tests for retrieving a comparison. Does not test data that will be deprecated,
eg base and head report fields. Tests for commits etc will be added as the
compare-api refactor progresses.
"""
Expand Down Expand Up @@ -223,7 +222,6 @@ def test_returns_200_and_expected_files_on_success(

assert response.status_code == status.HTTP_200_OK
assert response.data["files"] == self.expected_files
assert response.data["has_unmerged_base_commits"] is True

def test_returns_404_if_base_or_head_references_not_found(
self, adapter_mock, base_report_mock, head_report_mock
Expand Down Expand Up @@ -335,7 +333,7 @@ def test_diffs_larger_than_MAX_DIFF_SIZE_doesnt_include_lines(

comparison.MAX_DIFF_SIZE = previous_max

def test_file_returns_comparefile_with_diff_and_src_data(
def test_file_returns_compare_file_with_diff_and_src_data(
self, adapter_mock, base_report_mock, head_report_mock
):
base_report_mock.return_value = self.base_report
Expand Down
6 changes: 2 additions & 4 deletions api/public/v2/tests/test_api_compare_viewset.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,12 @@ def __init__(self):
]


@patch("services.comparison.Comparison.has_unmerged_base_commits", lambda self: True)
@patch("services.comparison.Comparison.head_report", new_callable=PropertyMock)
@patch("services.comparison.Comparison.base_report", new_callable=PropertyMock)
@patch("services.repo_providers.RepoProviderService.get_adapter")
class TestCompareViewSetRetrieve(APITestCase):
"""
Tests for retrieving a comparison. Does not test data that will be depracated,
Tests for retrieving a comparison. Does not test data that will be deprecated,
eg base and head report fields. Tests for commits etc will be added as the
compare-api refactor progresses.
"""
Expand Down Expand Up @@ -412,7 +411,6 @@ def test_returns_200_and_expected_files_on_success(

assert response.status_code == status.HTTP_200_OK
assert response.data["files"] == self.expected_files
assert response.data["has_unmerged_base_commits"] is True

def test_returns_404_if_base_or_head_references_not_found(
self, adapter_mock, base_report_mock, head_report_mock
Expand Down Expand Up @@ -535,7 +533,7 @@ def test_pullid_with_nonexistent_head_returns_404(

assert response.status_code == status.HTTP_404_NOT_FOUND

def test_file_returns_comparefile_with_diff_and_src_data(
def test_file_returns_compare_file_with_diff_and_src_data(
self, adapter_mock, base_report_mock, head_report_mock
):
base_report_mock.return_value = self.base_report
Expand Down
1 change: 0 additions & 1 deletion api/shared/compare/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ class ComparisonSerializer(serializers.Serializer):
diff = serializers.SerializerMethodField()
files = serializers.SerializerMethodField()
untracked = serializers.SerializerMethodField()
has_unmerged_base_commits = serializers.BooleanField()

def get_untracked(self, comparison) -> List[str]:
return [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_update_user_when_agreement_is_false(self):
assert self.current_user.terms_agreement_at == self.updated_at

self.current_user.refresh_from_db()
self.current_user.email == before_refresh_business_email
assert self.current_user.email == before_refresh_business_email

@freeze_time("2022-01-01T00:00:00")
def test_update_user_when_agreement_is_true(self):
Expand All @@ -56,7 +56,7 @@ def test_update_user_when_agreement_is_true(self):
assert self.current_user.terms_agreement_at == self.updated_at

self.current_user.refresh_from_db()
self.current_user.email == before_refresh_business_email
assert self.current_user.email == before_refresh_business_email

@freeze_time("2022-01-01T00:00:00")
def test_update_owner_and_user_when_email_is_not_empty(self):
Expand All @@ -73,7 +73,7 @@ def test_update_owner_and_user_when_email_is_not_empty(self):
assert self.current_user.terms_agreement_at == self.updated_at

self.current_user.refresh_from_db()
self.current_user.email == "[email protected]"
assert self.current_user.email == "[email protected]"

def test_validation_error_when_terms_is_none(self):
with pytest.raises(ValidationError):
Expand Down
2 changes: 1 addition & 1 deletion graphql_api/tests/mutation/test_set_yaml_on_owner.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ def test_mutation_dispatch_to_command(self, command_mock):
}
data = self.gql_request(query, owner=self.owner, variables={"input": input})
command_mock.assert_called_once_with(input["username"], input["yaml"])
data["setYamlOnOwner"]["owner"]["username"] == self.owner.username
assert data["setYamlOnOwner"]["owner"]["username"] == self.owner.username
2 changes: 1 addition & 1 deletion graphql_api/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def test_self_hosted_license_returns_null_if_invalid_license(self, license_mock)
)
assert data == {
"config": {
"selfHostedLicense": {"expirationDate": None},
"selfHostedLicense": None,
},
}

Expand Down
2 changes: 1 addition & 1 deletion graphql_api/types/comparison/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def resolve_has_different_number_of_head_and_base_reports(
comparison: ComparisonReport,
info: GraphQLResolveInfo,
**kwargs, # type: ignore
) -> False:
) -> bool:
# TODO: can we remove the need for `info.context["comparison"]` here?
if "comparison" not in info.context:
return False
Expand Down
2 changes: 1 addition & 1 deletion graphql_api/types/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def resolve_self_hosted_license(_, info):
license = self_hosted.get_current_license()

if not license.is_valid:
None
return None

return license

Expand Down
38 changes: 14 additions & 24 deletions services/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ def get_file_comparison(self, file_name, with_src=False, bypass_max_diff=False):

@property
def git_comparison(self):
return self._fetch_comparison_and_reverse_comparison[0]
return self._fetch_comparison[0]

@cached_property
def base_report(self):
Expand All @@ -713,7 +713,11 @@ def head_report(self):
else:
raise e

report.apply_diff(self.git_comparison["diff"])
# Return the old report if the github API call fails for any reason
try:
report.apply_diff(self.git_comparison["diff"])
except Exception:
pass
return report

@cached_property
Expand Down Expand Up @@ -762,10 +766,9 @@ def upload_commits(self):
return commits_queryset

@cached_property
def _fetch_comparison_and_reverse_comparison(self):
def _fetch_comparison(self):
"""
Fetches comparison and reverse comparison concurrently, then
caches the result. Returns (comparison, reverse_comparison).
Fetches comparison, and caches the result.
"""
adapter = RepoProviderService().get_adapter(
self.user, self.base_commit.repository
Expand All @@ -774,12 +777,8 @@ def _fetch_comparison_and_reverse_comparison(self):
self.base_commit.commitid, self.head_commit.commitid
)

reverse_comparison_coro = adapter.get_compare(
self.head_commit.commitid, self.base_commit.commitid
)

async def runnable():
return await asyncio.gather(comparison_coro, reverse_comparison_coro)
return await asyncio.gather(comparison_coro)

return async_to_sync(runnable)()

Expand All @@ -791,18 +790,6 @@ def non_carried_forward_flags(self):
flags_dict = self.head_report.flags
return [flag for flag, vals in flags_dict.items() if not vals.carriedforward]

@cached_property
def has_unmerged_base_commits(self):
"""
We use reverse comparison to detect if any commits exist in the
base reference but not in the head reference. We use this information
to show a message in the UI urging the user to integrate the changes
in the base reference in order to see accurate coverage information.
We compare with 1 because torngit injects the base commit into the commits
array because reasons.
"""
return len(self._fetch_comparison_and_reverse_comparison[1]["commits"]) > 1


class FlagComparison(object):
def __init__(self, comparison, flag_name):
Expand Down Expand Up @@ -869,7 +856,7 @@ def has_diff(self) -> bool:
"""
Returns `True` if the file has any additions or removals in the diff
"""
return (
return bool(
self.added_diff_coverage
and len(self.added_diff_coverage) > 0
or self.removed_diff_coverage
Expand Down Expand Up @@ -953,7 +940,10 @@ def change_coverage(self) -> Optional[float]:
and self.head_coverage
and self.head_coverage.coverage
):
return float(self.head_coverage.coverage - self.base_coverage.coverage)
return float(
float(self.head_coverage.coverage or 0)
- float(self.base_coverage.coverage or 0)
)

@cached_property
def file_name(self) -> Optional[str]:
Expand Down
6 changes: 4 additions & 2 deletions services/tests/test_billing.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,9 +893,11 @@ def test_get_proration_params(self):
plan=PlanName.CODECOV_PRO_MONTHLY.value, plan_user_count=20
)
desired_plan = {"value": PlanName.SENTRY_MONTHLY.value, "quantity": 19}
self.stripe._get_proration_params(owner, desired_plan) == "none"
assert self.stripe._get_proration_params(owner, desired_plan) == "none"
desired_plan = {"value": PlanName.SENTRY_MONTHLY.value, "quantity": 20}
self.stripe._get_proration_params(owner, desired_plan) == "always_invoice"
assert (
self.stripe._get_proration_params(owner, desired_plan) == "always_invoice"
)
desired_plan = {"value": PlanName.SENTRY_MONTHLY.value, "quantity": 21}
assert (
self.stripe._get_proration_params(owner, desired_plan) == "always_invoice"
Expand Down
26 changes: 5 additions & 21 deletions services/tests/test_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def test_can_traverse_diff_with_line_numbers_greater_than_file_eof(self):
manager.apply([visitor])
assert visitor.line_numbers == [(1, 1), (2, 2), (3, None), (None, 3)]

def test_can_traverse_diff_with_difflike_lines(self):
def test_can_traverse_diff_with_diff_like_lines(self):
src = [
"- line 1", # not part of diff
"- line 2", # not part of diff
Expand Down Expand Up @@ -371,7 +371,7 @@ def test_number_shows_none_for_base_if_plus_not_part_of_diff(self):
lc = LineComparison(None, [0, "", [], 0, 0], base_ln, head_ln, "+", False)
assert lc.number == {"base": base_ln, "head": head_ln}

def test_number_shows_none_for_base_if_minux_not_part_of_diff(self):
def test_number_shows_none_for_base_if_minus_not_part_of_diff(self):
base_ln = 3
head_ln = 4
lc = LineComparison(None, [0, "", [], 0, 0], base_ln, head_ln, "-", False)
Expand Down Expand Up @@ -542,7 +542,7 @@ def test_stats_returns_diff_stats_if_diff_data(self):
self.file_comparison.diff_data = {"stats": expected_stats}
assert self.file_comparison.stats == expected_stats

def test_lines_returns_emptylist_if_no_diff_or_src(self):
def test_lines_returns_empty_list_if_no_diff_or_src(self):
assert self.file_comparison.lines == []

# essentially a smoke/integration test
Expand Down Expand Up @@ -1154,11 +1154,11 @@ def test_allow_coverage_offsets(self, get_config_mock):
with self.subTest("returns app settings value if exists, True if not"):
get_config_mock.return_value = True
comparison = PullRequestComparison(owner, pull)
comparison.allow_coverage_offsets is True
assert comparison.allow_coverage_offsets is True

get_config_mock.return_value = False
comparison = PullRequestComparison(owner, pull)
comparison.allow_coverage_offsets is False
assert comparison.allow_coverage_offsets is False

@patch("services.repo_providers.RepoProviderService.get_adapter")
def test_pseudo_diff_returns_diff_between_base_and_compared_to(
Expand Down Expand Up @@ -1318,22 +1318,6 @@ def setUp(self):
self.comparison = Comparison(user=owner, base_commit=base, head_commit=head)
asyncio.set_event_loop(asyncio.new_event_loop())

def test_returns_true_if_reverse_comparison_has_commits(self, get_adapter_mock):
commits = ["a", "b"]
get_adapter_mock.return_value = (
ComparisonHasUnmergedBaseCommitsTests.MockFetchDiffCoro(commits)
)
assert self.comparison.has_unmerged_base_commits is True

def test_returns_false_if_reverse_comparison_has_one_commit_or_less(
self, get_adapter_mock
):
commits = ["a"]
get_adapter_mock.return_value = (
ComparisonHasUnmergedBaseCommitsTests.MockFetchDiffCoro(commits)
)
assert self.comparison.has_unmerged_base_commits is False


class SegmentTests(TestCase):
def _report_lines(self, hits):
Expand Down
2 changes: 1 addition & 1 deletion timeseries/tests/test_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_perform_backfill(self, backfill_dataset):
)
assert res.status_code == 302

backfill_dataset.call_count == 2
assert backfill_dataset.call_count == 2
backfill_dataset.assert_any_call(
self.dataset1,
start_date=timezone.datetime(2000, 1, 1, tzinfo=timezone.utc),
Expand Down

0 comments on commit e245d32

Please sign in to comment.