diff --git a/tests/unit/oidc/test_views.py b/tests/unit/oidc/test_views.py index b32c53ec3034..e58c9bd61e10 100644 --- a/tests/unit/oidc/test_views.py +++ b/tests/unit/oidc/test_views.py @@ -18,7 +18,7 @@ import pytest from tests.common.db.accounts import UserFactory -from tests.common.db.oidc import PendingGitHubPublisherFactory +from tests.common.db.oidc import GitHubPublisherFactory, PendingGitHubPublisherFactory from tests.common.db.packaging import ProjectFactory from warehouse.events.tags import EventTag from warehouse.macaroons import caveats @@ -69,13 +69,13 @@ def test_oidc_audience(): assert response == {"audience": "fakeaudience"} -def test_mint_token_from_oidc_not_enabled(): +def test_mint_token_from_github_oidc_not_enabled(): request = pretend.stub( response=pretend.stub(status=None), flags=pretend.stub(disallow_oidc=lambda *a: True), ) - response = views.mint_token_from_oidc(request) + response = views.mint_token_from_oidc_github(request) assert request.response.status == 422 assert response == { "message": "Token request failed", @@ -109,7 +109,7 @@ def test_mint_token_from_oidc_not_enabled(): {"token": {}}, ], ) -def test_mint_token_from_oidc_invalid_payload(body): +def test_mint_token_oidc_invalid_payload(body): class Request: def __init__(self): self.response = pretend.stub(status=None) @@ -120,7 +120,8 @@ def body(self): return json.dumps(body) req = Request() - resp = views.mint_token_from_oidc(req) + oidc_service = pretend.stub() + resp = views.mint_token(oidc_service, req) assert req.response.status == 422 assert resp["message"] == "Token request failed" @@ -142,7 +143,7 @@ def test_mint_token_from_trusted_publisher_verify_jwt_signature_fails(): flags=pretend.stub(disallow_oidc=lambda *a: False), ) - response = views.mint_token_from_oidc(request) + response = views.mint_token(oidc_service, request) assert request.response.status == 422 assert response == { "message": "Token request failed", @@ -154,9 +155,6 @@ def test_mint_token_from_trusted_publisher_verify_jwt_signature_fails(): ], } - assert request.find_service.calls == [ - pretend.call(IOIDCPublisherService, name="github") - ] assert oidc_service.verify_jwt_signature.calls == [pretend.call("faketoken")] @@ -176,7 +174,7 @@ def test_mint_token_from_trusted_publisher_lookup_fails(): flags=pretend.stub(disallow_oidc=lambda *a: False), ) - response = views.mint_token_from_oidc(request) + response = views.mint_token_from_oidc_github(request) assert request.response.status == 422 assert response == { "message": "Token request failed", @@ -202,7 +200,9 @@ def test_mint_token_from_trusted_publisher_lookup_fails(): def test_mint_token_from_oidc_pending_publisher_project_already_exists(db_request): project = ProjectFactory.create() - pending_publisher = PendingGitHubPublisherFactory.create(project_name=project.name) + pending_publisher = PendingGitHubPublisherFactory.create( + project_name=project.name, + ) db_request.flags.disallow_oidc = lambda f=None: False db_request.body = json.dumps({"token": "faketoken"}) @@ -216,7 +216,7 @@ def test_mint_token_from_oidc_pending_publisher_project_already_exists(db_reques ) db_request.find_service = pretend.call_recorder(lambda *a, **kw: oidc_service) - resp = views.mint_token_from_oidc(db_request) + resp = views.mint_token(oidc_service, db_request) assert db_request.response.status_code == 422 assert resp == { "message": "Token request failed", @@ -230,9 +230,6 @@ def test_mint_token_from_oidc_pending_publisher_project_already_exists(db_reques assert oidc_service.verify_jwt_signature.calls == [pretend.call("faketoken")] assert oidc_service.find_publisher.calls == [pretend.call(claims, pending=True)] - assert db_request.find_service.calls == [ - pretend.call(IOIDCPublisherService, name="github") - ] def test_mint_token_from_oidc_pending_publisher_ok( @@ -279,7 +276,7 @@ def test_mint_token_from_oidc_pending_publisher_ok( } monkeypatch.setattr(views, "_ratelimiters", lambda r: ratelimiters) - resp = views.mint_token_from_oidc(db_request) + resp = views.mint_token_from_oidc_github(db_request) assert resp["success"] assert resp["token"].startswith("pypi-") @@ -356,7 +353,7 @@ def test_mint_token_from_pending_trusted_publisher_invalidates_others( } monkeypatch.setattr(views, "_ratelimiters", lambda r: ratelimiters) - resp = views.mint_token_from_oidc(db_request) + resp = views.mint_token_from_oidc_github(db_request) assert resp["success"] assert resp["token"].startswith("pypi-") @@ -374,6 +371,71 @@ def test_mint_token_from_pending_trusted_publisher_invalidates_others( ] +def test_mint_token_from_oidc_only_pending_publisher_fail(monkeypatch, db_request): + pending_publisher = PendingGitHubPublisherFactory() + + def _find_publisher(claims, pending=False): + return pending_publisher + + oidc_service = pretend.stub( + verify_jwt_signature=pretend.call_recorder( + lambda token: {"ref": "someref", "sha": "somesha"} + ), + find_publisher=pretend.call_recorder(_find_publisher), + reify_pending_publisher=pretend.call_recorder( + lambda *a, **kw: pending_publisher + ), + ) + + db_request.body = json.dumps( + { + "token": ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJqdGkiOiI2ZTY3YjFjYi0yYjhkLTRi" + "ZTUtOTFjYi03NTdlZGIyZWM5NzAiLCJzdWIiOiJyZXBvOmZvby9iYXIiLCJhdWQiOiJwe" + "XBpIiwicmVmIjoiZmFrZSIsInNoYSI6ImZha2UiLCJyZXBvc2l0b3J5IjoiZm9vL2Jhci" + "IsInJlcG9zaXRvcnlfb3duZXIiOiJmb28iLCJyZXBvc2l0b3J5X293bmVyX2lkIjoiMTI" + "zIiwicnVuX2lkIjoiZmFrZSIsInJ1bl9udW1iZXIiOiJmYWtlIiwicnVuX2F0dGVtcHQi" + "OiIxIiwicmVwb3NpdG9yeV9pZCI6ImZha2UiLCJhY3Rvcl9pZCI6ImZha2UiLCJhY3Rvc" + "iI6ImZvbyIsIndvcmtmbG93IjoiZmFrZSIsImhlYWRfcmVmIjoiZmFrZSIsImJhc2Vfcm" + "VmIjoiZmFrZSIsImV2ZW50X25hbWUiOiJmYWtlIiwicmVmX3R5cGUiOiJmYWtlIiwiZW5" + "2aXJvbm1lbnQiOiJmYWtlIiwiam9iX3dvcmtmbG93X3JlZiI6ImZvby9iYXIvLmdpdGh1" + "Yi93b3JrZmxvd3MvZXhhbXBsZS55bWxAZmFrZSIsImlzcyI6Imh0dHBzOi8vdG9rZW4uY" + "WN0aW9ucy5naXRodWJ1c2VyY29udGVudC5jb20iLCJuYmYiOjE2NTA2NjMyNjUsImV4cC" + "I6MTY1MDY2NDE2NSwiaWF0IjoxNjUwNjYzODY1fQ.f-FMv5FF5sdxAWeUilYDt9NoE7Et" + "0vbdNhK32c2oC-E" + ) + } + ) + + ratelimiter = pretend.stub(clear=pretend.call_recorder(lambda id: None)) + ratelimiters = { + "user.oidc": ratelimiter, + "ip.oidc": ratelimiter, + } + monkeypatch.setattr(views, "_ratelimiters", lambda r: ratelimiters) + + send_pending_trusted_publisher_invalidated_email = pretend.call_recorder( + lambda *a, **kw: None + ) + monkeypatch.setattr( + views, + "send_pending_trusted_publisher_invalidated_email", + send_pending_trusted_publisher_invalidated_email, + ) + + response = views.mint_token(oidc_service, db_request) + + assert response == { + "message": "Token request failed", + "errors": [ + { + "code": "invalid-publisher", + "description": ("valid token, but no corresponding publisher"), + } + ], + } + + @pytest.mark.parametrize( ("claims_in_token", "claims_input"), [ @@ -383,7 +445,7 @@ def test_mint_token_from_pending_trusted_publisher_invalidates_others( ], ) def test_mint_token_from_oidc_no_pending_publisher_ok( - monkeypatch, claims_in_token, claims_input + monkeypatch, db_request, claims_in_token, claims_input ): time = pretend.stub(time=pretend.call_recorder(lambda: 0)) monkeypatch.setattr(views, "time", time) @@ -392,18 +454,16 @@ def test_mint_token_from_oidc_no_pending_publisher_ok( id="fakeprojectid", record_event=pretend.call_recorder(lambda **kw: None), ) - publisher = pretend.stub( - id="fakepublisherid", - projects=[project], - publisher_name="fakepublishername", - publisher_url=lambda x=None: "https://fake/url", - ) + + publisher = GitHubPublisherFactory() + monkeypatch.setattr(publisher.__class__, "projects", [project]) + publisher.publisher_url = pretend.call_recorder(lambda **kw: "https://fake/url") # NOTE: Can't set __str__ using pretend.stub() monkeypatch.setattr(publisher.__class__, "__str__", lambda s: "fakespecifier") def _find_publisher(claims, pending=False): if pending: - raise errors.InvalidPublisherError + return None else: return publisher @@ -426,16 +486,11 @@ def find_service(iface, **kw): return macaroon_service assert False, iface - request = pretend.stub( - response=pretend.stub(status=None), - body=json.dumps({"token": "faketoken"}), - find_service=find_service, - domain="fakedomain", - remote_addr="0.0.0.0", - flags=pretend.stub(disallow_oidc=lambda *a: False), - ) + monkeypatch.setattr(db_request, "find_service", find_service) + monkeypatch.setattr(db_request, "body", json.dumps({"token": "faketoken"})) + monkeypatch.setattr(db_request, "domain", "fakedomain") - response = views.mint_token_from_oidc(request) + response = views.mint_token(oidc_service, db_request) assert response == { "success": True, "token": "raw-macaroon", @@ -446,28 +501,29 @@ def find_service(iface, **kw): pretend.call(claims_in_token, pending=True), pretend.call(claims_in_token, pending=False), ] + assert macaroon_service.create_macaroon.calls == [ pretend.call( "fakedomain", f"OpenID token: fakespecifier ({datetime.fromtimestamp(0).isoformat()})", [ caveats.OIDCPublisher( - oidc_publisher_id="fakepublisherid", + oidc_publisher_id=str(publisher.id), ), caveats.ProjectID(project_ids=["fakeprojectid"]), caveats.Expiration(expires_at=900, not_before=0), ], - oidc_publisher_id="fakepublisherid", + oidc_publisher_id=str(publisher.id), additional={"oidc": claims_input}, ) ] assert project.record_event.calls == [ pretend.call( tag=EventTag.Project.ShortLivedAPITokenAdded, - request=request, + request=db_request, additional={ "expires": 900, - "publisher_name": "fakepublishername", + "publisher_name": "GitHub", "publisher_url": "https://fake/url", }, ) diff --git a/warehouse/oidc/models/_core.py b/warehouse/oidc/models/_core.py index 8330449a4f4f..63765ccaa72e 100644 --- a/warehouse/oidc/models/_core.py +++ b/warehouse/oidc/models/_core.py @@ -232,7 +232,9 @@ def publisher_name(self) -> str: # pragma: no cover # Only concrete subclasses are constructed. raise NotImplementedError - def publisher_url(self, claims=None) -> str | None: # pragma: no cover + def publisher_url( + self, claims: SignedClaims | None = None + ) -> str | None: # pragma: no cover """ NOTE: This is **NOT** a `@property` because we pass `claims` to it. When calling, make sure to use `publisher_url()` diff --git a/warehouse/oidc/services.py b/warehouse/oidc/services.py index c5635dc7c176..d3921d4fc30d 100644 --- a/warehouse/oidc/services.py +++ b/warehouse/oidc/services.py @@ -282,7 +282,8 @@ def verify_jwt_signature(self, unverified_token: str) -> SignedClaims | None: def find_publisher( self, signed_claims: SignedClaims, *, pending: bool = False - ) -> OIDCPublisher | PendingOIDCPublisher | None: + ) -> OIDCPublisher | PendingOIDCPublisher: + """Returns a publisher for the given claims, or raises an error.""" metrics_tags = [f"publisher:{self.publisher}"] self.metrics.increment( "warehouse.oidc.find_publisher.attempt", @@ -306,7 +307,7 @@ def find_publisher( ) raise e - def reify_pending_publisher(self, pending_publisher, project): + def reify_pending_publisher(self, pending_publisher, project) -> OIDCPublisher: new_publisher = pending_publisher.reify(self.db) project.oidc_publishers.append(new_publisher) return new_publisher diff --git a/warehouse/oidc/views.py b/warehouse/oidc/views.py index 251e13677c1c..3ca3861fa9ab 100644 --- a/warehouse/oidc/views.py +++ b/warehouse/oidc/views.py @@ -13,8 +13,10 @@ import time from datetime import datetime +from typing import TypedDict from pydantic import BaseModel, StrictStr, ValidationError +from pyramid.request import Request from pyramid.response import Response from pyramid.view import view_config from sqlalchemy import func @@ -24,19 +26,33 @@ from warehouse.events.tags import EventTag from warehouse.macaroons import caveats from warehouse.macaroons.interfaces import IMacaroonService +from warehouse.macaroons.services import DatabaseMacaroonService from warehouse.oidc.errors import InvalidPublisherError from warehouse.oidc.interfaces import IOIDCPublisherService -from warehouse.oidc.models import PendingOIDCPublisher +from warehouse.oidc.models import OIDCPublisher, PendingOIDCPublisher +from warehouse.oidc.services import OIDCPublisherService from warehouse.packaging.interfaces import IProjectService from warehouse.packaging.models import ProjectFactory from warehouse.rate_limiting.interfaces import IRateLimiter +class Error(TypedDict): + code: str + description: str + + +class JsonRespone(TypedDict, total=False): + message: str | None + errors: list[Error] | None + token: StrictStr | None + success: bool | None + + class TokenPayload(BaseModel): token: StrictStr -def _ratelimiters(request): +def _ratelimiters(request: Request) -> dict[str, IRateLimiter]: return { "user.oidc": request.find_service( IRateLimiter, name="user_oidc.publisher.register" @@ -47,6 +63,15 @@ def _ratelimiters(request): } +def _invalid(errors: list[Error], request: Request) -> JsonRespone: + request.response.status = 422 + + return { + "message": "Token request failed", + "errors": errors, + } + + @view_config( route_name="oidc.audience", require_methods=["GET"], @@ -54,13 +79,13 @@ def _ratelimiters(request): require_csrf=False, has_translations=False, ) -def oidc_audience(request): +def oidc_audience(request: Request): if request.flags.disallow_oidc(): return Response( status=403, json={"message": "Trusted publishing functionality not enabled"} ) - audience = request.registry.settings["warehouse.oidc.audience"] + audience: str = request.registry.settings["warehouse.oidc.audience"] return {"audience": audience} @@ -71,39 +96,46 @@ def oidc_audience(request): require_csrf=False, has_translations=True, ) -def mint_token_from_oidc(request): - def _invalid(errors): - request.response.status = 422 - return {"message": "Token request failed", "errors": errors} - +def mint_token_from_oidc_github(request: Request): if request.flags.disallow_oidc(AdminFlagValue.DISALLOW_GITHUB_OIDC): return _invalid( errors=[ { "code": "not-enabled", - "description": ( - "GitHub-based trusted publishing functionality not enabled" - ), + "description": "GitHub-based trusted publishing functionality not enabled", # noqa } - ] + ], + request=request, ) + # For the time being, GitHub is our only OIDC publisher. + # In the future, this should locate the correct service based on an + # identifier in the request body. + oidc_service: OIDCPublisherService = request.find_service( + IOIDCPublisherService, name="github" + ) + + return mint_token(oidc_service, request) + + +def mint_token(oidc_service: OIDCPublisherService, request: Request) -> JsonRespone: + unverified_jwt: str try: payload = TokenPayload.model_validate_json(request.body) unverified_jwt = payload.token except ValidationError as exc: - return _invalid(errors=[{"code": "invalid-payload", "description": str(exc)}]) + return _invalid( + errors=[{"code": "invalid-payload", "description": str(exc)}], + request=request, + ) - # For the time being, GitHub is our only OIDC publisher. - # In the future, this should locate the correct service based on an - # identifier in the request body. - oidc_service = request.find_service(IOIDCPublisherService, name="github") claims = oidc_service.verify_jwt_signature(unverified_jwt) if not claims: return _invalid( errors=[ {"code": "invalid-token", "description": "malformed or invalid token"} - ] + ], + request=request, ) # First, try to find a pending publisher. @@ -111,56 +143,59 @@ def _invalid(errors): pending_publisher = oidc_service.find_publisher(claims, pending=True) factory = ProjectFactory(request) - # If the project already exists, this pending publisher is no longer - # valid and needs to be removed. - # NOTE: This is mostly a sanity check, since we dispose of invalidated - # pending publishers below. - if pending_publisher.project_name in factory: - request.db.delete(pending_publisher) - return _invalid( - errors=[ - { - "code": "invalid-pending-publisher", - "description": "valid token, but project already exists", - } - ] - ) + if isinstance(pending_publisher, PendingOIDCPublisher): + # If the project already exists, this pending publisher is no longer + # valid and needs to be removed. + # NOTE: This is mostly a sanity check, since we dispose of invalidated + # pending publishers below. + if pending_publisher.project_name in factory: + request.db.delete(pending_publisher) + return _invalid( + errors=[ + { + "code": "invalid-pending-publisher", + "description": "valid token, but project already exists", + } + ], + request=request, + ) - # Create the new project, and reify the pending publisher against it. - project_service = request.find_service(IProjectService) - new_project = project_service.create_project( - pending_publisher.project_name, - pending_publisher.added_by, - request, - ratelimited=False, - ) - oidc_service.reify_pending_publisher(pending_publisher, new_project) - - # Successfully converting a pending publisher into a normal publisher - # is a positive signal, so we reset the associated ratelimits. - ratelimiters = _ratelimiters(request) - ratelimiters["user.oidc"].clear(pending_publisher.added_by.id) - ratelimiters["ip.oidc"].clear(request.remote_addr) - - # There might be other pending publishers for the same project name, - # which we've now invalidated by creating the project. These would - # be disposed of on use, but we explicitly dispose of them here while - # also sending emails to their owners. - stale_pending_publishers = ( - request.db.query(PendingOIDCPublisher) - .filter( - func.normalize_pep426_name(PendingOIDCPublisher.project_name) - == func.normalize_pep426_name(pending_publisher.project_name) - ) - .all() - ) - for stale_publisher in stale_pending_publishers: - send_pending_trusted_publisher_invalidated_email( + # Create the new project, and reify the pending publisher against it. + project_service = request.find_service(IProjectService) + new_project = project_service.create_project( + pending_publisher.project_name, + pending_publisher.added_by, request, - stale_publisher.added_by, - project_name=stale_publisher.project_name, + ratelimited=False, ) - request.db.delete(stale_publisher) + + oidc_service.reify_pending_publisher(pending_publisher, new_project) + + # Successfully converting a pending publisher into a normal publisher + # is a positive signal, so we reset the associated ratelimits. + ratelimiters = _ratelimiters(request) + ratelimiters["user.oidc"].clear(pending_publisher.added_by.id) + ratelimiters["ip.oidc"].clear(request.remote_addr) + + # There might be other pending publishers for the same project name, + # which we've now invalidated by creating the project. These would + # be disposed of on use, but we explicitly dispose of them here while + # also sending emails to their owners. + stale_pending_publishers = ( + request.db.query(PendingOIDCPublisher) + .filter( + func.normalize_pep426_name(PendingOIDCPublisher.project_name) + == func.normalize_pep426_name(pending_publisher.project_name) + ) + .all() + ) + for stale_publisher in stale_pending_publishers: + send_pending_trusted_publisher_invalidated_email( + request, + stale_publisher.added_by, + project_name=stale_publisher.project_name, + ) + request.db.delete(stale_publisher) except InvalidPublisherError: # If the claim set isn't valid for a pending publisher, it's OK, we # will try finding a regular publisher @@ -178,14 +213,30 @@ def _invalid(errors): "code": "invalid-publisher", "description": f"valid token, but no corresponding publisher ({e})", } - ] + ], + request=request, ) + if not isinstance(publisher, OIDCPublisher): + # This should be impossible, but we have to perform this type check to + # appease mypy otherwise we get type errors in the code after this + # point. + return _invalid( + errors=[ + { + "code": "invalid-publisher", + "description": "valid token, but no corresponding publisher", + } + ], + request=request, + ) # At this point, we've verified that the given JWT is valid for the given # project. All we need to do is mint a new token. # NOTE: For OIDC-minted API tokens, the Macaroon's description string # is purely an implementation detail and is not displayed to the user. - macaroon_service = request.find_service(IMacaroonService, context=None) + macaroon_service: DatabaseMacaroonService = request.find_service( + IMacaroonService, context=None + ) not_before = int(time.time()) expires_at = not_before + 900 serialized, dm = macaroon_service.create_macaroon( @@ -201,7 +252,7 @@ def _invalid(errors): caveats.ProjectID(project_ids=[str(p.id) for p in publisher.projects]), caveats.Expiration(expires_at=expires_at, not_before=not_before), ], - oidc_publisher_id=publisher.id, + oidc_publisher_id=str(publisher.id), additional={"oidc": {"ref": claims.get("ref"), "sha": claims.get("sha")}}, ) for project in publisher.projects: