diff --git a/tests/conftest.py b/tests/conftest.py index 1a205f4bb79e..204daa776c18 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,6 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os import os.path import re @@ -379,6 +380,57 @@ def oidc_service(db_session): ) +@pytest.fixture +def dummy_oidc_jwt(): + # { + # "jti": "6e67b1cb-2b8d-4be5-91cb-757edb2ec970", + # "sub": "repo:foo/bar", + # "aud": "pypi", + # "ref": "fake", + # "sha": "fake", + # "repository": "foo/bar", + # "repository_owner": "foo", + # "repository_owner_id": "123", + # "run_id": "fake", + # "run_number": "fake", + # "run_attempt": "1", + # "repository_id": "fake", + # "actor_id": "fake", + # "actor": "foo", + # "workflow": "fake", + # "head_ref": "fake", + # "base_ref": "fake", + # "event_name": "fake", + # "ref_type": "fake", + # "environment": "fake", + # "job_workflow_ref": "foo/bar/.github/workflows/example.yml@fake", + # "iss": "https://token.actions.githubusercontent.com", + # "nbf": 1650663265, + # "exp": 1650664165, + # "iat": 1650663865 + # } + return ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJqdGkiOiI2ZTY3YjFjYi0yYjhkLTRiZ" + "TUtOTFjYi03NTdlZGIyZWM5NzAiLCJzdWIiOiJyZXBvOmZvby9iYXIiLCJhdWQiOiJweXB" + "pIiwicmVmIjoiZmFrZSIsInNoYSI6ImZha2UiLCJyZXBvc2l0b3J5IjoiZm9vL2JhciIsI" + "nJlcG9zaXRvcnlfb3duZXIiOiJmb28iLCJyZXBvc2l0b3J5X293bmVyX2lkIjoiMTIzIiw" + "icnVuX2lkIjoiZmFrZSIsInJ1bl9udW1iZXIiOiJmYWtlIiwicnVuX2F0dGVtcHQiOiIxI" + "iwicmVwb3NpdG9yeV9pZCI6ImZha2UiLCJhY3Rvcl9pZCI6ImZha2UiLCJhY3RvciI6ImZ" + "vbyIsIndvcmtmbG93IjoiZmFrZSIsImhlYWRfcmVmIjoiZmFrZSIsImJhc2VfcmVmIjoiZ" + "mFrZSIsImV2ZW50X25hbWUiOiJmYWtlIiwicmVmX3R5cGUiOiJmYWtlIiwiZW52aXJvbm1" + "lbnQiOiJmYWtlIiwiam9iX3dvcmtmbG93X3JlZiI6ImZvby9iYXIvLmdpdGh1Yi93b3JrZ" + "mxvd3MvZXhhbXBsZS55bWxAZmFrZSIsImlzcyI6Imh0dHBzOi8vdG9rZW4uYWN0aW9ucy5" + "naXRodWJ1c2VyY29udGVudC5jb20iLCJuYmYiOjE2NTA2NjMyNjUsImV4cCI6MTY1MDY2N" + "DE2NSwiaWF0IjoxNjUwNjYzODY1fQ.f-FMv5FF5sdxAWeUilYDt9NoE7Et0vbdNhK32c2o" + "C-E" + ) + + +@pytest.fixture +def dummy_oidc_payload(dummy_oidc_jwt): + return json.dumps({"token": dummy_oidc_jwt}) + + @pytest.fixture def macaroon_service(db_session): return macaroon_services.DatabaseMacaroonService(db_session) diff --git a/tests/unit/oidc/test_utils.py b/tests/unit/oidc/test_utils.py index 1cd6a86a460d..8b8bc8af3e52 100644 --- a/tests/unit/oidc/test_utils.py +++ b/tests/unit/oidc/test_utils.py @@ -115,3 +115,21 @@ def test_oidc_context_principals(): Authenticated, "oidc:17", ] + + +def test_oidc_maps_consistent(): + # Our various mappings should have equivalent cardinalities. + assert len(utils.OIDC_ISSUER_URLS) == len(utils.OIDC_ISSUER_SERVICE_NAMES) + assert len(utils.OIDC_ISSUER_URLS) == len(utils.OIDC_ISSUER_ADMIN_FLAGS) + assert len(utils.OIDC_ISSUER_URLS) == len(utils.OIDC_PUBLISHER_CLASSES) + + for iss in utils.OIDC_ISSUER_URLS: + # Each issuer should be present in each mapping. + assert iss in utils.OIDC_ISSUER_SERVICE_NAMES + assert iss in utils.OIDC_ISSUER_ADMIN_FLAGS + assert iss in utils.OIDC_PUBLISHER_CLASSES + + for class_map in utils.OIDC_PUBLISHER_CLASSES.values(): + # The class mapping for pending and non-pending publisher models + # should be distinct. + assert class_map[True] != class_map[False] diff --git a/tests/unit/oidc/test_views.py b/tests/unit/oidc/test_views.py index 66b26646c86d..578607df1efa 100644 --- a/tests/unit/oidc/test_views.py +++ b/tests/unit/oidc/test_views.py @@ -18,7 +18,7 @@ import pytest from tests.common.db.accounts import UserFactory -from tests.common.db.oidc import PendingGitHubPublisherFactory +from tests.common.db.oidc import GitHubPublisherFactory, PendingGitHubPublisherFactory from tests.common.db.packaging import ProjectFactory from warehouse.events.tags import EventTag from warehouse.macaroons import caveats @@ -70,8 +70,9 @@ def test_oidc_audience(): assert response == {"audience": "fakeaudience"} -def test_mint_token_from_oidc_not_enabled(): +def test_mint_token_from_oidc_not_enabled(dummy_oidc_payload): request = pretend.stub( + body=dummy_oidc_payload, response=pretend.stub(status=None), flags=pretend.stub(disallow_oidc=lambda *a: True), ) @@ -83,9 +84,7 @@ def test_mint_token_from_oidc_not_enabled(): "errors": [ { "code": "not-enabled", - "description": ( - "GitHub-based trusted publishing functionality not enabled" - ), + "description": "github trusted publishing functionality not enabled", } ], } @@ -132,18 +131,164 @@ def body(self): assert isinstance(err["description"], str) -def test_mint_token_from_trusted_publisher_verify_jwt_signature_fails(): +@pytest.mark.parametrize( + "body", + [ + {"token": "not-a-jwt"}, + { + # Well-formed JWT, but no `iss` claim + "token": ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwib" + "mFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fw" + "pMeJf36POk6yJV_adQssw5c" + ) + }, + ], +) +def test_mint_token_from_oidc_invalid_payload_malformed_jwt(body): + class Request: + def __init__(self): + self.response = pretend.stub(status=None) + self.flags = pretend.stub(disallow_oidc=lambda *a: False) + + @property + def body(self): + return json.dumps(body) + + def find_service(self, *a, **kw): + return pretend.stub(increment=pretend.call_recorder(lambda s: None)) + + req = Request() + resp = views.mint_token_from_oidc(req) + + assert req.response.status == 422 + assert resp["message"] == "Token request failed" + assert isinstance(resp["errors"], list) + for err in resp["errors"]: + assert isinstance(err, dict) + assert err["code"] == "invalid-payload" + assert err["description"] == "malformed JWT" + + +def test_mint_token_from_oidc_jwt_decode_leaky_exception( + monkeypatch, dummy_oidc_payload +): + class Request: + def __init__(self): + self.response = pretend.stub(status=None) + self.flags = pretend.stub(disallow_oidc=lambda *a: False) + + @property + def body(self): + return dummy_oidc_payload + + def find_service(self, *a, **kw): + return pretend.stub(increment=pretend.call_recorder(lambda s: None)) + + capture_message = pretend.call_recorder(lambda s: None) + monkeypatch.setattr(views.sentry_sdk, "capture_message", capture_message) + monkeypatch.setattr(views.jwt, "decode", pretend.raiser(ValueError("oops"))) + + req = Request() + resp = views.mint_token_from_oidc(req) + + assert capture_message.calls == [ + pretend.call("jwt.decode raised generic error: oops") + ] + + assert req.response.status == 422 + assert resp["message"] == "Token request failed" + assert isinstance(resp["errors"], list) + for err in resp["errors"]: + assert isinstance(err, dict) + assert err["code"] == "invalid-payload" + assert err["description"] == "malformed JWT" + + +def test_mint_token_from_oidc_unknown_issuer(): + class Request: + def __init__(self): + self.response = pretend.stub(status=None) + self.flags = pretend.stub(disallow_oidc=lambda *a: False) + + @property + def body(self): + return json.dumps( + { + "token": ( + # iss: nonexistent-issuer + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ" + "ub25leGlzdGVudC1pc3N1ZXIifQ.TYGmZaQXhjS3KA8o3POV" + "HeiD3FR5bz4X6UhRA4ykTFM" + ) + } + ) + + req = Request() + resp = views.mint_token_from_oidc(req) + + assert req.response.status == 422 + assert resp["message"] == "Token request failed" + assert isinstance(resp["errors"], list) + for err in resp["errors"]: + assert isinstance(err, dict) + assert err["code"] == "invalid-payload" + assert err["description"] == "unknown trusted publishing issuer" + + +@pytest.mark.parametrize( + ("token", "service_name"), + [ + ( + ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL3Rva2Vu" + "LmFjdGlvbnMuZ2l0aHVidXNlcmNvbnRlbnQuY29tIn0.saN7OFQBav8qXzgMCfERf" + "ZWPGfHu-0EEQMlVyO5UVdQ" + ), + "github", + ), + ( + ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2FjY291b" + "nRzLmdvb2dsZS5jb20ifQ.2RJ6Y52Rap0LEj61yBGDokUg8r92SYQq6l3cflSWBVI" + ), + "google", + ), + ], +) +def test_mint_token_from_oidc_creates_expected_service( + monkeypatch, token, service_name +): + mint_token = pretend.call_recorder(lambda *a: pretend.stub()) + monkeypatch.setattr(views, "mint_token", mint_token) + + oidc_service = pretend.stub() + request = pretend.stub( + response=pretend.stub(status=None), + find_service=pretend.call_recorder(lambda cls, **kw: oidc_service), + flags=pretend.stub(disallow_oidc=lambda *a: False), + body=json.dumps({"token": token}), + ) + + views.mint_token_from_oidc(request) + + assert request.find_service.calls == [ + pretend.call(IOIDCPublisherService, name=service_name) + ] + assert mint_token.calls == [pretend.call(oidc_service, token, request)] + + +def test_mint_token_from_trusted_publisher_verify_jwt_signature_fails(dummy_oidc_jwt): oidc_service = pretend.stub( verify_jwt_signature=pretend.call_recorder(lambda token: None), ) request = pretend.stub( response=pretend.stub(status=None), - body=json.dumps({"token": "faketoken"}), find_service=pretend.call_recorder(lambda cls, **kw: oidc_service), flags=pretend.stub(disallow_oidc=lambda *a: False), ) - response = views.mint_token_from_oidc(request) + response = views.mint_token(oidc_service, dummy_oidc_jwt, request) assert request.response.status == 422 assert response == { "message": "Token request failed", @@ -155,13 +300,10 @@ def test_mint_token_from_trusted_publisher_verify_jwt_signature_fails(): ], } - assert request.find_service.calls == [ - pretend.call(IOIDCPublisherService, name="github") - ] - assert oidc_service.verify_jwt_signature.calls == [pretend.call("faketoken")] + assert oidc_service.verify_jwt_signature.calls == [pretend.call(dummy_oidc_jwt)] -def test_mint_token_from_trusted_publisher_lookup_fails(): +def test_mint_token_trusted_publisher_lookup_fails(dummy_oidc_jwt): claims = pretend.stub() message = "some message" oidc_service = pretend.stub( @@ -172,12 +314,11 @@ def test_mint_token_from_trusted_publisher_lookup_fails(): ) request = pretend.stub( response=pretend.stub(status=None), - body=json.dumps({"token": "faketoken"}), find_service=pretend.call_recorder(lambda cls, **kw: oidc_service), flags=pretend.stub(disallow_oidc=lambda *a: False), ) - response = views.mint_token_from_oidc(request) + response = views.mint_token(oidc_service, dummy_oidc_jwt, request) assert request.response.status == 422 assert response == { "message": "Token request failed", @@ -191,22 +332,22 @@ def test_mint_token_from_trusted_publisher_lookup_fails(): ], } - assert request.find_service.calls == [ - pretend.call(IOIDCPublisherService, name="github"), - ] - assert oidc_service.verify_jwt_signature.calls == [pretend.call("faketoken")] + assert oidc_service.verify_jwt_signature.calls == [pretend.call(dummy_oidc_jwt)] assert oidc_service.find_publisher.calls == [ pretend.call(claims, pending=True), pretend.call(claims, pending=False), ] -def test_mint_token_from_oidc_pending_publisher_project_already_exists(db_request): +def test_mint_token_pending_publisher_project_already_exists( + db_request, dummy_oidc_jwt +): project = ProjectFactory.create() - pending_publisher = PendingGitHubPublisherFactory.create(project_name=project.name) + pending_publisher = PendingGitHubPublisherFactory.create( + project_name=project.name, + ) db_request.flags.disallow_oidc = lambda f=None: False - db_request.body = json.dumps({"token": "faketoken"}) claims = pretend.stub() oidc_service = pretend.stub( @@ -217,7 +358,7 @@ def test_mint_token_from_oidc_pending_publisher_project_already_exists(db_reques ) db_request.find_service = pretend.call_recorder(lambda *a, **kw: oidc_service) - resp = views.mint_token_from_oidc(db_request) + resp = views.mint_token(oidc_service, dummy_oidc_jwt, db_request) assert db_request.response.status_code == 422 assert resp == { "message": "Token request failed", @@ -229,16 +370,14 @@ def test_mint_token_from_oidc_pending_publisher_project_already_exists(db_reques ], } - assert oidc_service.verify_jwt_signature.calls == [pretend.call("faketoken")] + assert oidc_service.verify_jwt_signature.calls == [pretend.call(dummy_oidc_jwt)] assert oidc_service.find_publisher.calls == [pretend.call(claims, pending=True)] - assert db_request.find_service.calls == [ - pretend.call(IOIDCPublisherService, name="github") - ] def test_mint_token_from_oidc_pending_publisher_ok( monkeypatch, db_request, + dummy_oidc_payload, ): user = UserFactory.create() pending_publisher = PendingGitHubPublisherFactory.create( @@ -252,25 +391,7 @@ def test_mint_token_from_oidc_pending_publisher_ok( ) db_request.flags.disallow_oidc = lambda f=None: False - db_request.body = json.dumps( - { - "token": ( - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJqdGkiOiI2ZTY3YjFjYi0yYjhkLTRi" - "ZTUtOTFjYi03NTdlZGIyZWM5NzAiLCJzdWIiOiJyZXBvOmZvby9iYXIiLCJhdWQiOiJwe" - "XBpIiwicmVmIjoiZmFrZSIsInNoYSI6ImZha2UiLCJyZXBvc2l0b3J5IjoiZm9vL2Jhci" - "IsInJlcG9zaXRvcnlfb3duZXIiOiJmb28iLCJyZXBvc2l0b3J5X293bmVyX2lkIjoiMTI" - "zIiwicnVuX2lkIjoiZmFrZSIsInJ1bl9udW1iZXIiOiJmYWtlIiwicnVuX2F0dGVtcHQi" - "OiIxIiwicmVwb3NpdG9yeV9pZCI6ImZha2UiLCJhY3Rvcl9pZCI6ImZha2UiLCJhY3Rvc" - "iI6ImZvbyIsIndvcmtmbG93IjoiZmFrZSIsImhlYWRfcmVmIjoiZmFrZSIsImJhc2Vfcm" - "VmIjoiZmFrZSIsImV2ZW50X25hbWUiOiJmYWtlIiwicmVmX3R5cGUiOiJmYWtlIiwiZW5" - "2aXJvbm1lbnQiOiJmYWtlIiwiam9iX3dvcmtmbG93X3JlZiI6ImZvby9iYXIvLmdpdGh1" - "Yi93b3JrZmxvd3MvZXhhbXBsZS55bWxAZmFrZSIsImlzcyI6Imh0dHBzOi8vdG9rZW4uY" - "WN0aW9ucy5naXRodWJ1c2VyY29udGVudC5jb20iLCJuYmYiOjE2NTA2NjMyNjUsImV4cC" - "I6MTY1MDY2NDE2NSwiaWF0IjoxNjUwNjYzODY1fQ.f-FMv5FF5sdxAWeUilYDt9NoE7Et" - "0vbdNhK32c2oC-E" - ) - } - ) + db_request.body = dummy_oidc_payload db_request.remote_addr = "0.0.0.0" ratelimiter = pretend.stub(clear=pretend.call_recorder(lambda id: None)) @@ -291,7 +412,7 @@ def test_mint_token_from_oidc_pending_publisher_ok( def test_mint_token_from_pending_trusted_publisher_invalidates_others( - monkeypatch, db_request + monkeypatch, db_request, dummy_oidc_payload ): time = pretend.stub(time=pretend.call_recorder(lambda: 0)) monkeypatch.setattr(views, "time", time) @@ -329,25 +450,7 @@ def test_mint_token_from_pending_trusted_publisher_invalidates_others( ) db_request.flags.oidc_enabled = lambda f: False - db_request.body = json.dumps( - { - "token": ( - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJqdGkiOiI2ZTY3YjFjYi0yYjhkLTRi" - "ZTUtOTFjYi03NTdlZGIyZWM5NzAiLCJzdWIiOiJyZXBvOmZvby9iYXIiLCJhdWQiOiJwe" - "XBpIiwicmVmIjoiZmFrZSIsInNoYSI6ImZha2UiLCJyZXBvc2l0b3J5IjoiZm9vL2Jhci" - "IsInJlcG9zaXRvcnlfb3duZXIiOiJmb28iLCJyZXBvc2l0b3J5X293bmVyX2lkIjoiMTI" - "zIiwicnVuX2lkIjoiZmFrZSIsInJ1bl9udW1iZXIiOiJmYWtlIiwicnVuX2F0dGVtcHQi" - "OiIxIiwicmVwb3NpdG9yeV9pZCI6ImZha2UiLCJhY3Rvcl9pZCI6ImZha2UiLCJhY3Rvc" - "iI6ImZvbyIsIndvcmtmbG93IjoiZmFrZSIsImhlYWRfcmVmIjoiZmFrZSIsImJhc2Vfcm" - "VmIjoiZmFrZSIsImV2ZW50X25hbWUiOiJmYWtlIiwicmVmX3R5cGUiOiJmYWtlIiwiZW5" - "2aXJvbm1lbnQiOiJmYWtlIiwiam9iX3dvcmtmbG93X3JlZiI6ImZvby9iYXIvLmdpdGh1" - "Yi93b3JrZmxvd3MvZXhhbXBsZS55bWxAZmFrZSIsImlzcyI6Imh0dHBzOi8vdG9rZW4uY" - "WN0aW9ucy5naXRodWJ1c2VyY29udGVudC5jb20iLCJuYmYiOjE2NTA2NjMyNjUsImV4cC" - "I6MTY1MDY2NDE2NSwiaWF0IjoxNjUwNjYzODY1fQ.f-FMv5FF5sdxAWeUilYDt9NoE7Et" - "0vbdNhK32c2oC-E" - ) - } - ) + db_request.body = dummy_oidc_payload db_request.remote_addr = "0.0.0.0" ratelimiter = pretend.stub(clear=pretend.call_recorder(lambda id: None)) @@ -383,8 +486,8 @@ def test_mint_token_from_pending_trusted_publisher_invalidates_others( ({"sha": "somesha"}, {"ref": None, "sha": "somesha"}), ], ) -def test_mint_token_from_oidc_no_pending_publisher_ok( - monkeypatch, claims_in_token, claims_input +def test_mint_token_no_pending_publisher_ok( + monkeypatch, db_request, claims_in_token, claims_input, dummy_oidc_jwt ): time = pretend.stub(time=pretend.call_recorder(lambda: 0)) monkeypatch.setattr(views, "time", time) @@ -393,18 +496,16 @@ def test_mint_token_from_oidc_no_pending_publisher_ok( id="fakeprojectid", record_event=pretend.call_recorder(lambda **kw: None), ) - publisher = pretend.stub( - id="fakepublisherid", - projects=[project], - publisher_name="fakepublishername", - publisher_url=lambda x=None: "https://fake/url", - ) + + publisher = GitHubPublisherFactory() + monkeypatch.setattr(publisher.__class__, "projects", [project]) + publisher.publisher_url = pretend.call_recorder(lambda **kw: "https://fake/url") # NOTE: Can't set __str__ using pretend.stub() monkeypatch.setattr(publisher.__class__, "__str__", lambda s: "fakespecifier") def _find_publisher(claims, pending=False): if pending: - raise errors.InvalidPublisherError + return None else: return publisher @@ -427,48 +528,43 @@ def find_service(iface, **kw): return macaroon_service assert False, iface - request = pretend.stub( - response=pretend.stub(status=None), - body=json.dumps({"token": "faketoken"}), - find_service=find_service, - domain="fakedomain", - remote_addr="0.0.0.0", - flags=pretend.stub(disallow_oidc=lambda *a: False), - ) + monkeypatch.setattr(db_request, "find_service", find_service) + monkeypatch.setattr(db_request, "domain", "fakedomain") - response = views.mint_token_from_oidc(request) + response = views.mint_token(oidc_service, dummy_oidc_jwt, db_request) assert response == { "success": True, "token": "raw-macaroon", } - assert oidc_service.verify_jwt_signature.calls == [pretend.call("faketoken")] + assert oidc_service.verify_jwt_signature.calls == [pretend.call(dummy_oidc_jwt)] assert oidc_service.find_publisher.calls == [ pretend.call(claims_in_token, pending=True), pretend.call(claims_in_token, pending=False), ] + assert macaroon_service.create_macaroon.calls == [ pretend.call( "fakedomain", f"OpenID token: fakespecifier ({datetime.fromtimestamp(0).isoformat()})", [ caveats.OIDCPublisher( - oidc_publisher_id="fakepublisherid", + oidc_publisher_id=str(publisher.id), ), caveats.ProjectID(project_ids=["fakeprojectid"]), caveats.Expiration(expires_at=900, not_before=0), ], - oidc_publisher_id="fakepublisherid", + oidc_publisher_id=str(publisher.id), additional={"oidc": claims_input}, ) ] assert project.record_event.calls == [ pretend.call( tag=EventTag.Project.ShortLivedAPITokenAdded, - request=request, + request=db_request, additional={ "expires": 900, - "publisher_name": "fakepublishername", + "publisher_name": "GitHub", "publisher_url": "https://fake/url", }, ) diff --git a/warehouse/oidc/__init__.py b/warehouse/oidc/__init__.py index 15452ce8871c..8fddb01d1696 100644 --- a/warehouse/oidc/__init__.py +++ b/warehouse/oidc/__init__.py @@ -47,6 +47,7 @@ def includeme(config): auth = config.get_settings().get("auth.domain") config.add_route("oidc.audience", "/_/oidc/audience", domain=auth) + config.add_route("oidc.mint_token", "/_/oidc/mint-token", domain=auth) config.add_route("oidc.github.mint_token", "/_/oidc/github/mint-token", domain=auth) # Compute OIDC metrics periodically diff --git a/warehouse/oidc/models/_core.py b/warehouse/oidc/models/_core.py index 8330449a4f4f..63765ccaa72e 100644 --- a/warehouse/oidc/models/_core.py +++ b/warehouse/oidc/models/_core.py @@ -232,7 +232,9 @@ def publisher_name(self) -> str: # pragma: no cover # Only concrete subclasses are constructed. raise NotImplementedError - def publisher_url(self, claims=None) -> str | None: # pragma: no cover + def publisher_url( + self, claims: SignedClaims | None = None + ) -> str | None: # pragma: no cover """ NOTE: This is **NOT** a `@property` because we pass `claims` to it. When calling, make sure to use `publisher_url()` diff --git a/warehouse/oidc/services.py b/warehouse/oidc/services.py index c5635dc7c176..d3921d4fc30d 100644 --- a/warehouse/oidc/services.py +++ b/warehouse/oidc/services.py @@ -282,7 +282,8 @@ def verify_jwt_signature(self, unverified_token: str) -> SignedClaims | None: def find_publisher( self, signed_claims: SignedClaims, *, pending: bool = False - ) -> OIDCPublisher | PendingOIDCPublisher | None: + ) -> OIDCPublisher | PendingOIDCPublisher: + """Returns a publisher for the given claims, or raises an error.""" metrics_tags = [f"publisher:{self.publisher}"] self.metrics.increment( "warehouse.oidc.find_publisher.attempt", @@ -306,7 +307,7 @@ def find_publisher( ) raise e - def reify_pending_publisher(self, pending_publisher, project): + def reify_pending_publisher(self, pending_publisher, project) -> OIDCPublisher: new_publisher = pending_publisher.reify(self.db) project.oidc_publishers.append(new_publisher) return new_publisher diff --git a/warehouse/oidc/utils.py b/warehouse/oidc/utils.py index 8cdef771214a..b66a60a73988 100644 --- a/warehouse/oidc/utils.py +++ b/warehouse/oidc/utils.py @@ -16,6 +16,7 @@ from pyramid.authorization import Authenticated +from warehouse.admin.flags import AdminFlagValue from warehouse.oidc.errors import InvalidPublisherError from warehouse.oidc.interfaces import SignedClaims from warehouse.oidc.models import ( @@ -30,6 +31,16 @@ GITHUB_OIDC_ISSUER_URL = "https://token.actions.githubusercontent.com" GOOGLE_OIDC_ISSUER_URL = "https://accounts.google.com" +OIDC_ISSUER_SERVICE_NAMES = { + GITHUB_OIDC_ISSUER_URL: "github", + GOOGLE_OIDC_ISSUER_URL: "google", +} + +OIDC_ISSUER_ADMIN_FLAGS = { + GITHUB_OIDC_ISSUER_URL: AdminFlagValue.DISALLOW_GITHUB_OIDC, + GOOGLE_OIDC_ISSUER_URL: AdminFlagValue.DISALLOW_GOOGLE_OIDC, +} + OIDC_ISSUER_URLS = {GITHUB_OIDC_ISSUER_URL, GOOGLE_OIDC_ISSUER_URL} OIDC_PUBLISHER_CLASSES: dict[str, dict[bool, type[OIDCPublisherMixin]]] = { diff --git a/warehouse/oidc/views.py b/warehouse/oidc/views.py index 49df4cdc6828..f047868b5245 100644 --- a/warehouse/oidc/views.py +++ b/warehouse/oidc/views.py @@ -13,27 +13,48 @@ import time from datetime import datetime +from typing import TypedDict + +import jwt +import sentry_sdk from pydantic import BaseModel, StrictStr, ValidationError +from pyramid.request import Request from pyramid.response import Response from pyramid.view import view_config -from warehouse.admin.flags import AdminFlagValue from warehouse.events.tags import EventTag from warehouse.macaroons import caveats from warehouse.macaroons.interfaces import IMacaroonService +from warehouse.macaroons.services import DatabaseMacaroonService +from warehouse.metrics.interfaces import IMetricsService from warehouse.oidc.errors import InvalidPublisherError from warehouse.oidc.interfaces import IOIDCPublisherService +from warehouse.oidc.models import OIDCPublisher, PendingOIDCPublisher +from warehouse.oidc.services import OIDCPublisherService +from warehouse.oidc.utils import OIDC_ISSUER_ADMIN_FLAGS, OIDC_ISSUER_SERVICE_NAMES from warehouse.packaging.interfaces import IProjectService from warehouse.packaging.models import ProjectFactory from warehouse.rate_limiting.interfaces import IRateLimiter +class Error(TypedDict): + code: str + description: str + + +class JsonResponse(TypedDict, total=False): + message: str | None + errors: list[Error] | None + token: StrictStr | None + success: bool | None + + class TokenPayload(BaseModel): token: StrictStr -def _ratelimiters(request): +def _ratelimiters(request: Request) -> dict[str, IRateLimiter]: return { "user.oidc": request.find_service( IRateLimiter, name="user_oidc.publisher.register" @@ -44,6 +65,15 @@ def _ratelimiters(request): } +def _invalid(errors: list[Error], request: Request) -> JsonResponse: + request.response.status = 422 + + return { + "message": "Token request failed", + "errors": errors, + } + + @view_config( route_name="oidc.audience", require_methods=["GET"], @@ -51,13 +81,13 @@ def _ratelimiters(request): require_csrf=False, has_translations=False, ) -def oidc_audience(request): +def oidc_audience(request: Request): if request.flags.disallow_oidc(): return Response( status=403, json={"message": "Trusted publishing functionality not enabled"} ) - audience = request.registry.settings["warehouse.oidc.audience"] + audience: str = request.registry.settings["warehouse.oidc.audience"] return {"audience": audience} @@ -66,41 +96,89 @@ def oidc_audience(request): require_methods=["POST"], renderer="json", require_csrf=False, - has_translations=True, ) -def mint_token_from_oidc(request): - def _invalid(errors): - request.response.status = 422 - return {"message": "Token request failed", "errors": errors} +@view_config( + route_name="oidc.mint_token", + require_methods=["POST"], + renderer="json", + require_csrf=False, +) +def mint_token_from_oidc(request: Request): + try: + payload = TokenPayload.model_validate_json(request.body) + unverified_jwt = payload.token + except ValidationError as exc: + return _invalid( + errors=[{"code": "invalid-payload", "description": str(exc)}], + request=request, + ) + + # We currently have an **unverified** JWT. To verify it, we need to + # know which OIDC service's keyring to check it against. + # To do this, we gingerly peek into the unverified claims and + # use the `iss` to key into the right `OIDCPublisherService`. + try: + unverified_claims = jwt.decode( + unverified_jwt, options=dict(verify_signature=False) + ) + unverified_issuer: str = unverified_claims["iss"] + except Exception as e: + metrics = request.find_service(IMetricsService, context=None) + metrics.increment("warehouse.oidc.mint_token_from_oidc.malformed_jwt") + + # We expect only PyJWTError and KeyError; anything else indicates + # an abstraction leak in jwt that we'll log for upstream reporting. + if not isinstance(e, (jwt.PyJWTError, KeyError)): + with sentry_sdk.push_scope() as scope: + scope.fingerprint = e + sentry_sdk.capture_message(f"jwt.decode raised generic error: {e}") - if request.flags.disallow_oidc(AdminFlagValue.DISALLOW_GITHUB_OIDC): + return _invalid( + errors=[{"code": "invalid-payload", "description": "malformed JWT"}], + request=request, + ) + + # Associate the given issuer claim with Warehouse's OIDCPublisherService. + service_name = OIDC_ISSUER_SERVICE_NAMES.get(unverified_issuer) + if not service_name: + return _invalid( + errors=[ + { + "code": "invalid-payload", + "description": "unknown trusted publishing issuer", + } + ], + request=request, + ) + + if request.flags.disallow_oidc(OIDC_ISSUER_ADMIN_FLAGS[unverified_issuer]): return _invalid( errors=[ { "code": "not-enabled", - "description": ( - "GitHub-based trusted publishing functionality not enabled" - ), + "description": f"{service_name} trusted publishing functionality not enabled", # noqa } - ] + ], + request=request, ) - try: - payload = TokenPayload.model_validate_json(request.body) - unverified_jwt = payload.token - except ValidationError as exc: - return _invalid(errors=[{"code": "invalid-payload", "description": str(exc)}]) + oidc_service: OIDCPublisherService = request.find_service( + IOIDCPublisherService, name=service_name + ) - # For the time being, GitHub is our only OIDC publisher. - # In the future, this should locate the correct service based on an - # identifier in the request body. - oidc_service = request.find_service(IOIDCPublisherService, name="github") + return mint_token(oidc_service, unverified_jwt, request) + + +def mint_token( + oidc_service: OIDCPublisherService, unverified_jwt: str, request: Request +) -> JsonResponse: claims = oidc_service.verify_jwt_signature(unverified_jwt) if not claims: return _invalid( errors=[ {"code": "invalid-token", "description": "malformed or invalid token"} - ] + ], + request=request, ) # First, try to find a pending publisher. @@ -108,41 +186,39 @@ def _invalid(errors): pending_publisher = oidc_service.find_publisher(claims, pending=True) factory = ProjectFactory(request) - # If the project already exists, this pending publisher is no longer - # valid and needs to be removed. - # NOTE: This is mostly a sanity check, since we dispose of invalidated - # pending publishers below. - if pending_publisher.project_name in factory: - request.db.delete(pending_publisher) - return _invalid( - errors=[ - { - "code": "invalid-pending-publisher", - "description": "valid token, but project already exists", - } - ] - ) - - # Create the new project - project_service = request.find_service(IProjectService) - new_project = project_service.create_project( - pending_publisher.project_name, - pending_publisher.added_by, - request, - ratelimited=False, - ) + if isinstance(pending_publisher, PendingOIDCPublisher): + # If the project already exists, this pending publisher is no longer + # valid and needs to be removed. + # NOTE: This is mostly a sanity check, since we dispose of invalidated + # pending publishers below. + if pending_publisher.project_name in factory: + request.db.delete(pending_publisher) + return _invalid( + errors=[ + { + "code": "invalid-pending-publisher", + "description": "valid token, but project already exists", + } + ], + request=request, + ) - # Creating the project will remove all pending publishers, EXCEPT the - # pending publisher created by the uploader. If such a pending - # publisher exists, reify it against the newly created project. - oidc_service.reify_pending_publisher(pending_publisher, new_project) + # Create the new project, and reify the pending publisher against it. + project_service = request.find_service(IProjectService) + new_project = project_service.create_project( + pending_publisher.project_name, + pending_publisher.added_by, + request, + ratelimited=False, + ) - # Successfully converting a pending publisher into a normal publisher - # is a positive signal, so we reset the associated ratelimits. - ratelimiters = _ratelimiters(request) - ratelimiters["user.oidc"].clear(pending_publisher.added_by.id) - ratelimiters["ip.oidc"].clear(request.remote_addr) + oidc_service.reify_pending_publisher(pending_publisher, new_project) + # Successfully converting a pending publisher into a normal publisher + # is a positive signal, so we reset the associated ratelimits. + ratelimiters = _ratelimiters(request) + ratelimiters["user.oidc"].clear(pending_publisher.added_by.id) + ratelimiters["ip.oidc"].clear(request.remote_addr) except InvalidPublisherError: # If the claim set isn't valid for a pending publisher, it's OK, we # will try finding a regular publisher @@ -153,6 +229,8 @@ def _invalid(errors): # to actually do the macaroon minting with. try: publisher = oidc_service.find_publisher(claims, pending=False) + # NOTE: assert to persuade mypy of the correct type here. + assert isinstance(publisher, OIDCPublisher) except InvalidPublisherError as e: return _invalid( errors=[ @@ -160,14 +238,17 @@ def _invalid(errors): "code": "invalid-publisher", "description": f"valid token, but no corresponding publisher ({e})", } - ] + ], + request=request, ) # At this point, we've verified that the given JWT is valid for the given # project. All we need to do is mint a new token. # NOTE: For OIDC-minted API tokens, the Macaroon's description string # is purely an implementation detail and is not displayed to the user. - macaroon_service = request.find_service(IMacaroonService, context=None) + macaroon_service: DatabaseMacaroonService = request.find_service( + IMacaroonService, context=None + ) not_before = int(time.time()) expires_at = not_before + 900 serialized, dm = macaroon_service.create_macaroon( @@ -183,7 +264,7 @@ def _invalid(errors): caveats.ProjectID(project_ids=[str(p.id) for p in publisher.projects]), caveats.Expiration(expires_at=expires_at, not_before=not_before), ], - oidc_publisher_id=publisher.id, + oidc_publisher_id=str(publisher.id), additional={"oidc": {"ref": claims.get("ref"), "sha": claims.get("sha")}}, ) for project in publisher.projects: