Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 94 additions & 38 deletions tests/unit/oidc/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import pytest

from tests.common.db.accounts import UserFactory
from tests.common.db.oidc import PendingGitHubPublisherFactory
from tests.common.db.oidc import GitHubPublisherFactory, PendingGitHubPublisherFactory
from tests.common.db.packaging import ProjectFactory
from warehouse.events.tags import EventTag
from warehouse.macaroons import caveats
Expand Down Expand Up @@ -69,13 +69,13 @@ def test_oidc_audience():
assert response == {"audience": "fakeaudience"}


def test_mint_token_from_oidc_not_enabled():
def test_mint_token_from_github_oidc_not_enabled():
request = pretend.stub(
response=pretend.stub(status=None),
flags=pretend.stub(disallow_oidc=lambda *a: True),
)

response = views.mint_token_from_oidc(request)
response = views.mint_token_from_oidc_github(request)
assert request.response.status == 422
assert response == {
"message": "Token request failed",
Expand Down Expand Up @@ -109,7 +109,7 @@ def test_mint_token_from_oidc_not_enabled():
{"token": {}},
],
)
def test_mint_token_from_oidc_invalid_payload(body):
def test_mint_token_oidc_invalid_payload(body):
class Request:
def __init__(self):
self.response = pretend.stub(status=None)
Expand All @@ -120,7 +120,8 @@ def body(self):
return json.dumps(body)

req = Request()
resp = views.mint_token_from_oidc(req)
oidc_service = pretend.stub()
resp = views.mint_token(oidc_service, req)

assert req.response.status == 422
assert resp["message"] == "Token request failed"
Expand All @@ -142,7 +143,7 @@ def test_mint_token_from_trusted_publisher_verify_jwt_signature_fails():
flags=pretend.stub(disallow_oidc=lambda *a: False),
)

response = views.mint_token_from_oidc(request)
response = views.mint_token(oidc_service, request)
assert request.response.status == 422
assert response == {
"message": "Token request failed",
Expand All @@ -154,9 +155,6 @@ def test_mint_token_from_trusted_publisher_verify_jwt_signature_fails():
],
}

assert request.find_service.calls == [
pretend.call(IOIDCPublisherService, name="github")
]
assert oidc_service.verify_jwt_signature.calls == [pretend.call("faketoken")]


Expand All @@ -176,7 +174,7 @@ def test_mint_token_from_trusted_publisher_lookup_fails():
flags=pretend.stub(disallow_oidc=lambda *a: False),
)

response = views.mint_token_from_oidc(request)
response = views.mint_token_from_oidc_github(request)
assert request.response.status == 422
assert response == {
"message": "Token request failed",
Expand All @@ -202,7 +200,9 @@ def test_mint_token_from_trusted_publisher_lookup_fails():

def test_mint_token_from_oidc_pending_publisher_project_already_exists(db_request):
project = ProjectFactory.create()
pending_publisher = PendingGitHubPublisherFactory.create(project_name=project.name)
pending_publisher = PendingGitHubPublisherFactory.create(
project_name=project.name,
)

db_request.flags.disallow_oidc = lambda f=None: False
db_request.body = json.dumps({"token": "faketoken"})
Expand All @@ -216,7 +216,7 @@ def test_mint_token_from_oidc_pending_publisher_project_already_exists(db_reques
)
db_request.find_service = pretend.call_recorder(lambda *a, **kw: oidc_service)

resp = views.mint_token_from_oidc(db_request)
resp = views.mint_token(oidc_service, db_request)
assert db_request.response.status_code == 422
assert resp == {
"message": "Token request failed",
Expand All @@ -230,9 +230,6 @@ def test_mint_token_from_oidc_pending_publisher_project_already_exists(db_reques

assert oidc_service.verify_jwt_signature.calls == [pretend.call("faketoken")]
assert oidc_service.find_publisher.calls == [pretend.call(claims, pending=True)]
assert db_request.find_service.calls == [
pretend.call(IOIDCPublisherService, name="github")
]


def test_mint_token_from_oidc_pending_publisher_ok(
Expand Down Expand Up @@ -279,7 +276,7 @@ def test_mint_token_from_oidc_pending_publisher_ok(
}
monkeypatch.setattr(views, "_ratelimiters", lambda r: ratelimiters)

resp = views.mint_token_from_oidc(db_request)
resp = views.mint_token_from_oidc_github(db_request)
assert resp["success"]
assert resp["token"].startswith("pypi-")

Expand Down Expand Up @@ -356,7 +353,7 @@ def test_mint_token_from_pending_trusted_publisher_invalidates_others(
}
monkeypatch.setattr(views, "_ratelimiters", lambda r: ratelimiters)

resp = views.mint_token_from_oidc(db_request)
resp = views.mint_token_from_oidc_github(db_request)
assert resp["success"]
assert resp["token"].startswith("pypi-")

Expand All @@ -374,6 +371,71 @@ def test_mint_token_from_pending_trusted_publisher_invalidates_others(
]


def test_mint_token_from_oidc_only_pending_publisher_fail(monkeypatch, db_request):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test tests an impossible state to reach as laid out in this comment and I'm not sure what to do about it.

pending_publisher = PendingGitHubPublisherFactory()

def _find_publisher(claims, pending=False):
return pending_publisher

oidc_service = pretend.stub(
verify_jwt_signature=pretend.call_recorder(
lambda token: {"ref": "someref", "sha": "somesha"}
),
find_publisher=pretend.call_recorder(_find_publisher),
reify_pending_publisher=pretend.call_recorder(
lambda *a, **kw: pending_publisher
),
)

db_request.body = json.dumps(
{
"token": (
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJqdGkiOiI2ZTY3YjFjYi0yYjhkLTRi"
"ZTUtOTFjYi03NTdlZGIyZWM5NzAiLCJzdWIiOiJyZXBvOmZvby9iYXIiLCJhdWQiOiJwe"
"XBpIiwicmVmIjoiZmFrZSIsInNoYSI6ImZha2UiLCJyZXBvc2l0b3J5IjoiZm9vL2Jhci"
"IsInJlcG9zaXRvcnlfb3duZXIiOiJmb28iLCJyZXBvc2l0b3J5X293bmVyX2lkIjoiMTI"
"zIiwicnVuX2lkIjoiZmFrZSIsInJ1bl9udW1iZXIiOiJmYWtlIiwicnVuX2F0dGVtcHQi"
"OiIxIiwicmVwb3NpdG9yeV9pZCI6ImZha2UiLCJhY3Rvcl9pZCI6ImZha2UiLCJhY3Rvc"
"iI6ImZvbyIsIndvcmtmbG93IjoiZmFrZSIsImhlYWRfcmVmIjoiZmFrZSIsImJhc2Vfcm"
"VmIjoiZmFrZSIsImV2ZW50X25hbWUiOiJmYWtlIiwicmVmX3R5cGUiOiJmYWtlIiwiZW5"
"2aXJvbm1lbnQiOiJmYWtlIiwiam9iX3dvcmtmbG93X3JlZiI6ImZvby9iYXIvLmdpdGh1"
"Yi93b3JrZmxvd3MvZXhhbXBsZS55bWxAZmFrZSIsImlzcyI6Imh0dHBzOi8vdG9rZW4uY"
"WN0aW9ucy5naXRodWJ1c2VyY29udGVudC5jb20iLCJuYmYiOjE2NTA2NjMyNjUsImV4cC"
"I6MTY1MDY2NDE2NSwiaWF0IjoxNjUwNjYzODY1fQ.f-FMv5FF5sdxAWeUilYDt9NoE7Et"
"0vbdNhK32c2oC-E"
)
}
)

ratelimiter = pretend.stub(clear=pretend.call_recorder(lambda id: None))
ratelimiters = {
"user.oidc": ratelimiter,
"ip.oidc": ratelimiter,
}
monkeypatch.setattr(views, "_ratelimiters", lambda r: ratelimiters)

send_pending_trusted_publisher_invalidated_email = pretend.call_recorder(
lambda *a, **kw: None
)
monkeypatch.setattr(
views,
"send_pending_trusted_publisher_invalidated_email",
send_pending_trusted_publisher_invalidated_email,
)

response = views.mint_token(oidc_service, db_request)

assert response == {
"message": "Token request failed",
"errors": [
{
"code": "invalid-publisher",
"description": ("valid token, but no corresponding publisher"),
}
],
}


@pytest.mark.parametrize(
("claims_in_token", "claims_input"),
[
Expand All @@ -383,7 +445,7 @@ def test_mint_token_from_pending_trusted_publisher_invalidates_others(
],
)
def test_mint_token_from_oidc_no_pending_publisher_ok(
monkeypatch, claims_in_token, claims_input
monkeypatch, db_request, claims_in_token, claims_input
):
time = pretend.stub(time=pretend.call_recorder(lambda: 0))
monkeypatch.setattr(views, "time", time)
Expand All @@ -392,18 +454,16 @@ def test_mint_token_from_oidc_no_pending_publisher_ok(
id="fakeprojectid",
record_event=pretend.call_recorder(lambda **kw: None),
)
publisher = pretend.stub(
Copy link
Contributor Author

@th3coop th3coop Nov 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was changed to use the factory due to the type check here and here which was added because of mypy errors in the code after those type checks.

id="fakepublisherid",
projects=[project],
publisher_name="fakepublishername",
publisher_url=lambda x=None: "https://fake/url",
)

publisher = GitHubPublisherFactory()
monkeypatch.setattr(publisher.__class__, "projects", [project])
publisher.publisher_url = pretend.call_recorder(lambda **kw: "https://fake/url")
# NOTE: Can't set __str__ using pretend.stub()
monkeypatch.setattr(publisher.__class__, "__str__", lambda s: "fakespecifier")

def _find_publisher(claims, pending=False):
if pending:
raise errors.InvalidPublisherError
return None
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is wrong I'm pretty sure but I'm not sure what to do about it.

I made this change because the coverage report was telling me that this line was never getting hit when the tests run.

This is true and correct because if I call oidc_service.find_publisher(claims, pending=False) and then I will NEVER get back anything but an OIDCPublisher, OR it will raise an exception.

The static type checking doesn't know that though, it just knows that OIDCPublisher | PendingOIDCPublisher will be returned.

The first thing that comes to mind to me is that _find_publisher shouldn't return two different types based on one of the inputs to the method. There should be a method per Class.

There are obviously other options but that's the first one that comes to mind for me.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, this seems wrong, I think the line you're having trouble with is probably unnecessary, see the comment there.

else:
return publisher

Expand All @@ -426,16 +486,11 @@ def find_service(iface, **kw):
return macaroon_service
assert False, iface

request = pretend.stub(
response=pretend.stub(status=None),
body=json.dumps({"token": "faketoken"}),
find_service=find_service,
domain="fakedomain",
remote_addr="0.0.0.0",
flags=pretend.stub(disallow_oidc=lambda *a: False),
)
monkeypatch.setattr(db_request, "find_service", find_service)
monkeypatch.setattr(db_request, "body", json.dumps({"token": "faketoken"}))
monkeypatch.setattr(db_request, "domain", "fakedomain")

response = views.mint_token_from_oidc(request)
response = views.mint_token(oidc_service, db_request)
assert response == {
"success": True,
"token": "raw-macaroon",
Expand All @@ -446,28 +501,29 @@ def find_service(iface, **kw):
pretend.call(claims_in_token, pending=True),
pretend.call(claims_in_token, pending=False),
]

assert macaroon_service.create_macaroon.calls == [
pretend.call(
"fakedomain",
f"OpenID token: fakespecifier ({datetime.fromtimestamp(0).isoformat()})",
[
caveats.OIDCPublisher(
oidc_publisher_id="fakepublisherid",
oidc_publisher_id=str(publisher.id),
),
caveats.ProjectID(project_ids=["fakeprojectid"]),
caveats.Expiration(expires_at=900, not_before=0),
],
oidc_publisher_id="fakepublisherid",
oidc_publisher_id=str(publisher.id),
additional={"oidc": claims_input},
)
]
assert project.record_event.calls == [
pretend.call(
tag=EventTag.Project.ShortLivedAPITokenAdded,
request=request,
request=db_request,
additional={
"expires": 900,
"publisher_name": "fakepublishername",
"publisher_name": "GitHub",
"publisher_url": "https://fake/url",
},
)
Expand Down
4 changes: 3 additions & 1 deletion warehouse/oidc/models/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,9 @@ def publisher_name(self) -> str: # pragma: no cover
# Only concrete subclasses are constructed.
raise NotImplementedError

def publisher_url(self, claims=None) -> str | None: # pragma: no cover
def publisher_url(
self, claims: SignedClaims | None = None
) -> str | None: # pragma: no cover
"""
NOTE: This is **NOT** a `@property` because we pass `claims` to it.
When calling, make sure to use `publisher_url()`
Expand Down
5 changes: 3 additions & 2 deletions warehouse/oidc/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,8 @@ def verify_jwt_signature(self, unverified_token: str) -> SignedClaims | None:

def find_publisher(
self, signed_claims: SignedClaims, *, pending: bool = False
) -> OIDCPublisher | PendingOIDCPublisher | None:
) -> OIDCPublisher | PendingOIDCPublisher:
"""Returns a publisher for the given claims, or raises an error."""
metrics_tags = [f"publisher:{self.publisher}"]
self.metrics.increment(
"warehouse.oidc.find_publisher.attempt",
Expand All @@ -306,7 +307,7 @@ def find_publisher(
)
raise e

def reify_pending_publisher(self, pending_publisher, project):
def reify_pending_publisher(self, pending_publisher, project) -> OIDCPublisher:
new_publisher = pending_publisher.reify(self.db)
project.oidc_publishers.append(new_publisher)
return new_publisher
Expand Down
Loading