diff --git a/src/fastmcp/server/auth/oauth_proxy/proxy.py b/src/fastmcp/server/auth/oauth_proxy/proxy.py index 4938932f82..e1a24720f5 100644 --- a/src/fastmcp/server/auth/oauth_proxy/proxy.py +++ b/src/fastmcp/server/auth/oauth_proxy/proxy.py @@ -978,6 +978,20 @@ async def exchange_authorization_code( # Refresh Token Flow # ------------------------------------------------------------------------- + def _prepare_scopes_for_token_exchange(self, scopes: list[str]) -> list[str]: + """Prepare scopes for initial token exchange (auth code -> tokens). + + Override this method to provide scopes during the authorization + code exchange. Some providers (like Azure) require scopes to be sent. + + Args: + scopes: Scopes from the authorization request + + Returns: + List of scopes to send, or empty list to omit scope parameter + """ + return scopes + def _prepare_scopes_for_upstream_refresh(self, scopes: list[str]) -> list[str]: """Prepare scopes for upstream token refresh request. @@ -1532,6 +1546,13 @@ async def _handle_idp_callback( txn_id, ) + # Allow providers to specify scope for token exchange + exchange_scopes = self._prepare_scopes_for_token_exchange( + transaction.get("scopes") or [] + ) + if exchange_scopes: + token_params["scope"] = " ".join(exchange_scopes) + # Add any extra token parameters configured for this proxy if self._extra_token_params: token_params.update(self._extra_token_params) diff --git a/src/fastmcp/server/auth/providers/azure.py b/src/fastmcp/server/auth/providers/azure.py index e57b9c04a3..07ece928df 100644 --- a/src/fastmcp/server/auth/providers/azure.py +++ b/src/fastmcp/server/auth/providers/azure.py @@ -193,6 +193,8 @@ def __init__( token_endpoint = f"https://{base_authority}/{tenant_id}/oauth2/v2.0/token" # Initialize OAuth proxy with Azure endpoints + # Remember there's hooks called, such as _prepare_scopes_for_token_exchange + # and _prepare_scopes_for_upstream_refresh super().__init__( upstream_authorization_endpoint=authorization_endpoint, upstream_token_endpoint=token_endpoint, @@ -206,7 +208,6 @@ def __init__( client_storage=client_storage, jwt_signing_key=jwt_signing_key, require_authorization_consent=require_authorization_consent, - # Advertise full scopes including OIDC (even though we only validate non-OIDC) valid_scopes=parsed_required_scopes, ) @@ -318,16 +319,37 @@ def _build_upstream_authorize_url( # Let parent build the URL with prefixed scopes return super()._build_upstream_authorize_url(txn_id, modified_transaction) + def _prepare_scopes_for_token_exchange(self, scopes: list[str]) -> list[str]: + """Prepare scopes for Azure authorization code exchange. + + Azure requires scopes during token exchange (AADSTS28003 error if missing). + Azure only allows ONE resource per token request (AADSTS28000), so we only + include scopes for this API plus OIDC scopes. + + Args: + scopes: Scopes from the authorization request (unprefixed) + + Returns: + List of scopes for Azure token endpoint + """ + # Prefix scopes for this API + prefixed_scopes = self._prefix_scopes_for_azure(scopes or []) + + # Add OIDC scopes only (not other API scopes) to avoid AADSTS28000 + if self.additional_authorize_scopes: + prefixed_scopes.extend( + s for s in self.additional_authorize_scopes if s in OIDC_SCOPES + ) + + deduplicated = list(dict.fromkeys(prefixed_scopes)) + logger.debug("Token exchange scopes: %s", deduplicated) + return deduplicated + def _prepare_scopes_for_upstream_refresh(self, scopes: list[str]) -> list[str]: """Prepare scopes for Azure token refresh. - Azure requires: - 1. Fully-qualified custom scopes (e.g., "api://xxx/read" not "read") - 2. Microsoft Graph scopes (e.g., "User.Read", "openid") sent as-is - 3. Additional scopes from provider config (additional_authorize_scopes) - - This method transforms base client scopes for Azure while keeping them - unprefixed in storage to prevent accumulation. + Azure requires fully-qualified scopes and only allows ONE resource per + token request (AADSTS28000). We include scopes for this API plus OIDC scopes. Args: scopes: Base scopes from RefreshToken (unprefixed, e.g., ["read"]) @@ -338,22 +360,19 @@ def _prepare_scopes_for_upstream_refresh(self, scopes: list[str]) -> list[str]: logger.debug("Base scopes from storage: %s", scopes) # Filter out any additional_authorize_scopes that may have been stored - # (they shouldn't be in storage, but clean them up if they are) additional_scopes_set = set(self.additional_authorize_scopes or []) base_scopes = [s for s in scopes if s not in additional_scopes_set] - # Prefix base scopes with identifier_uri for Azure using shared helper + # Prefix base scopes with identifier_uri for Azure prefixed_scopes = self._prefix_scopes_for_azure(base_scopes) - # Add additional scopes (Graph + OIDC) for the Azure request - # These are NOT stored in RefreshToken, only sent to Azure + # Add OIDC scopes only (not other API scopes) to avoid AADSTS28000 if self.additional_authorize_scopes: - prefixed_scopes.extend(self.additional_authorize_scopes) + prefixed_scopes.extend( + s for s in self.additional_authorize_scopes if s in OIDC_SCOPES + ) - # Deduplicate while preserving order (in case older tokens have duplicates) - # Use dict.fromkeys() for O(n) deduplication with order preservation deduplicated_scopes = list(dict.fromkeys(prefixed_scopes)) - logger.debug("Scopes for Azure token endpoint: %s", deduplicated_scopes) return deduplicated_scopes diff --git a/tests/server/auth/providers/test_azure.py b/tests/server/auth/providers/test_azure.py index e3fb13931a..fbe52044cb 100644 --- a/tests/server/auth/providers/test_azure.py +++ b/tests/server/auth/providers/test_azure.py @@ -438,7 +438,11 @@ def test_prepare_scopes_for_upstream_refresh_already_prefixed(self): assert len(result) == 2 def test_prepare_scopes_for_upstream_refresh_with_additional_scopes(self): - """Test that additional_authorize_scopes are added during token refresh.""" + """Test that only OIDC scopes from additional_authorize_scopes are added. + + Azure only allows ONE resource per token request (AADSTS28000), so + non-OIDC scopes like User.Read are excluded from refresh requests. + """ provider = AzureProvider( client_id="test_client", client_secret="test_secret", @@ -447,7 +451,7 @@ def test_prepare_scopes_for_upstream_refresh_with_additional_scopes(self): identifier_uri="api://my-api", required_scopes=["read"], additional_authorize_scopes=[ - "User.Read", + "User.Read", # Not OIDC - excluded "openid", "profile", "offline_access", @@ -455,16 +459,16 @@ def test_prepare_scopes_for_upstream_refresh_with_additional_scopes(self): jwt_signing_key="test-secret", ) - # Base scopes should be prefixed, additional scopes appended + # Base scopes should be prefixed, only OIDC scopes appended result = provider._prepare_scopes_for_upstream_refresh(["read", "write"]) assert "api://my-api/read" in result assert "api://my-api/write" in result - assert "User.Read" in result + assert "User.Read" not in result # Not OIDC, excluded assert "openid" in result assert "profile" in result assert "offline_access" in result - assert len(result) == 6 + assert len(result) == 5 def test_prepare_scopes_for_upstream_refresh_filters_duplicate_additional_scopes( self, @@ -482,16 +486,17 @@ def test_prepare_scopes_for_upstream_refresh_filters_duplicate_additional_scopes ) # If additional scopes were accidentally stored, they should be filtered - # to prevent accumulation + # User.Read is not OIDC so won't be added result = provider._prepare_scopes_for_upstream_refresh( ["read", "User.Read", "openid"] ) - # Should have: api://my-api/read (prefixed) + User.Read + openid (added once) + # Should have: api://my-api/read (prefixed) + openid (OIDC, added once) + # User.Read is filtered from storage AND not added (not OIDC) assert "api://my-api/read" in result - assert result.count("User.Read") == 1 + assert "User.Read" not in result # Not OIDC assert result.count("openid") == 1 - assert len(result) == 3 + assert len(result) == 2 def test_prepare_scopes_for_upstream_refresh_mixed_scopes(self): """Test mixed scenario with both prefixed and unprefixed scopes.""" @@ -502,7 +507,7 @@ def test_prepare_scopes_for_upstream_refresh_mixed_scopes(self): base_url="https://myserver.com", identifier_uri="api://my-api", required_scopes=["read"], - additional_authorize_scopes=["User.Read"], + additional_authorize_scopes=["openid"], # OIDC scope jwt_signing_key="test-secret", ) @@ -514,7 +519,7 @@ def test_prepare_scopes_for_upstream_refresh_mixed_scopes(self): assert "api://my-api/read" in result assert "api://other-api/admin" in result # Already prefixed, unchanged assert "api://my-api/write" in result - assert "User.Read" in result + assert "openid" in result assert len(result) == 4 def test_prepare_scopes_for_upstream_refresh_scope_with_slash(self): @@ -552,12 +557,12 @@ def test_prepare_scopes_for_upstream_refresh_empty_scopes(self): jwt_signing_key="test-secret", ) - # Empty scopes should still add additional_authorize_scopes + # Empty scopes should still add OIDC scopes (not User.Read) result = provider._prepare_scopes_for_upstream_refresh([]) - assert "User.Read" in result + assert "User.Read" not in result # Not OIDC assert "openid" in result - assert len(result) == 2 + assert len(result) == 1 # Only openid (the only OIDC scope) def test_prepare_scopes_for_upstream_refresh_no_additional_scopes(self): """Test behavior when no additional_authorize_scopes are configured.""" @@ -587,21 +592,21 @@ def test_prepare_scopes_for_upstream_refresh_deduplicates_scopes(self): base_url="https://myserver.com", identifier_uri="api://my-api", required_scopes=["read"], - additional_authorize_scopes=["User.Read", "openid"], + additional_authorize_scopes=["openid", "profile"], # OIDC scopes only jwt_signing_key="test-secret", ) - # Test with duplicate base scopes and duplicate additional scopes + # Test with duplicate base scopes result = provider._prepare_scopes_for_upstream_refresh( - ["read", "write", "read", "User.Read", "openid"] + ["read", "write", "read", "openid"] ) - # Should have deduplicated results in order + # Should have deduplicated results in order (User.Read filtered, openid added once) assert result == [ "api://my-api/read", "api://my-api/write", - "User.Read", "openid", + "profile", ] assert len(result) == 4 @@ -809,242 +814,134 @@ def test_prepare_scopes_for_refresh_handles_oidc_scopes(self): assert "api://my-api/profile" not in result -class TestAzureExtractUpstreamClaims: - """Tests for Azure provider's _extract_upstream_claims method.""" - - @staticmethod - def create_test_jwt(claims: dict) -> str: - """Create a test JWT token with the given claims.""" - import base64 - import json - - header = base64.urlsafe_b64encode( - json.dumps({"alg": "RS256", "typ": "JWT"}).encode() - ).rstrip(b"=") - payload = base64.urlsafe_b64encode(json.dumps(claims).encode()).rstrip(b"=") - signature = base64.urlsafe_b64encode(b"fake-signature").rstrip(b"=") - return f"{header.decode()}.{payload.decode()}.{signature.decode()}" +class TestAzureTokenExchangeScopes: + """Tests for Azure provider's token exchange scope handling. - async def test_extract_claims_from_azure_jwt(self): - """Test that Azure identity claims are extracted from access token.""" - provider = AzureProvider( - client_id="test_client", - client_secret="test_secret", - tenant_id="test-tenant", - base_url="https://myserver.com", - required_scopes=["read"], - jwt_signing_key="test-secret", - ) - - azure_jwt = self.create_test_jwt( - { - "sub": "user-subject-id", - "oid": "user-object-id", - "tid": "tenant-id-123", - "azp": "client-app-id", - "name": "Test User", - "given_name": "Test", - "family_name": "User", - "preferred_username": "testuser@example.com", - "upn": "testuser@example.com", - "email": "test@example.com", - "roles": ["Admin", "Reader"], - "groups": ["group-1", "group-2"], - "exp": 9999999999, - "iat": 1234567890, - "iss": "https://login.microsoftonline.com/test-tenant/v2.0", - } - ) - - idp_tokens = { - "access_token": azure_jwt, - "token_type": "Bearer", - "expires_in": 3600, - } - - claims = await provider._extract_upstream_claims(idp_tokens) - - assert claims is not None - assert claims["sub"] == "user-subject-id" - assert claims["oid"] == "user-object-id" - assert claims["tid"] == "tenant-id-123" - assert claims["azp"] == "client-app-id" - assert claims["name"] == "Test User" - assert claims["given_name"] == "Test" - assert claims["family_name"] == "User" - assert claims["preferred_username"] == "testuser@example.com" - assert claims["upn"] == "testuser@example.com" - assert claims["email"] == "test@example.com" - assert claims["roles"] == ["Admin", "Reader"] - assert claims["groups"] == ["group-1", "group-2"] - - async def test_extract_claims_only_includes_identity_claims(self): - """Test that only identity claims are extracted, not all JWT claims.""" - provider = AzureProvider( - client_id="test_client", - client_secret="test_secret", - tenant_id="test-tenant", - base_url="https://myserver.com", - required_scopes=["read"], - jwt_signing_key="test-secret", - ) + Azure requires scopes to be sent during the authorization code exchange. + The provider overrides _prepare_scopes_for_token_exchange to return + properly prefixed scopes. + """ - azure_jwt = self.create_test_jwt( - { - "sub": "user-id", - "oid": "object-id", - "name": "Test User", - "exp": 9999999999, - "iat": 1234567890, - "iss": "https://issuer.example.com", - "aud": "test-audience", - "nbf": 1234567890, - "scp": "read write", - "azp": "some-client", - } - ) - - idp_tokens = {"access_token": azure_jwt} - - claims = await provider._extract_upstream_claims(idp_tokens) - - # Only identity claims should be present - assert claims is not None - assert "sub" in claims - assert "oid" in claims - assert "name" in claims - assert "azp" in claims # azp is an identity claim we extract - # Standard JWT claims should NOT be extracted - assert "exp" not in claims - assert "iat" not in claims - assert "iss" not in claims - assert "aud" not in claims - assert "nbf" not in claims - assert "scp" not in claims - - async def test_extract_claims_returns_none_for_missing_access_token(self): - """Test that None is returned when access_token is missing.""" + def test_prepare_scopes_returns_prefixed_scopes(self): + """Test that _prepare_scopes_for_token_exchange returns prefixed scopes.""" provider = AzureProvider( client_id="test_client", client_secret="test_secret", tenant_id="test-tenant", base_url="https://myserver.com", - required_scopes=["read"], + identifier_uri="api://my-api", + required_scopes=["read", "write"], jwt_signing_key="test-secret", ) - idp_tokens = {"token_type": "Bearer", "expires_in": 3600} - - claims = await provider._extract_upstream_claims(idp_tokens) - - assert claims is None + scopes = provider._prepare_scopes_for_token_exchange(["read", "write"]) + assert len(scopes) > 0 + assert "api://my-api/read" in scopes + assert "api://my-api/write" in scopes - async def test_extract_claims_returns_none_for_opaque_token(self): - """Test that None is returned for opaque (non-JWT) tokens.""" + def test_prepare_scopes_includes_additional_oidc_scopes(self): + """Test that _prepare_scopes_for_token_exchange includes OIDC scopes.""" provider = AzureProvider( client_id="test_client", client_secret="test_secret", tenant_id="test-tenant", base_url="https://myserver.com", + identifier_uri="api://my-api", required_scopes=["read"], + additional_authorize_scopes=["openid", "profile", "offline_access"], jwt_signing_key="test-secret", ) - idp_tokens = { - "access_token": "gho_opaque_token_not_a_jwt", # Not a JWT - "token_type": "Bearer", - } - - claims = await provider._extract_upstream_claims(idp_tokens) + scopes = provider._prepare_scopes_for_token_exchange(["read"]) + assert len(scopes) > 0 + assert "api://my-api/read" in scopes + assert "openid" in scopes + assert "profile" in scopes + assert "offline_access" in scopes - assert claims is None + def test_prepare_scopes_excludes_other_api_scopes(self): + """Test token exchange excludes other API scopes (Azure AADSTS28000). - async def test_extract_claims_returns_none_for_malformed_jwt(self): - """Test that None is returned for malformed JWT tokens.""" + Azure only allows ONE resource per token exchange. Other API scopes + are requested during authorization but excluded from token exchange. + """ provider = AzureProvider( - client_id="test_client", + client_id="00000000-1111-2222-3333-444444444444", client_secret="test_secret", tenant_id="test-tenant", base_url="https://myserver.com", - required_scopes=["read"], + required_scopes=["user_impersonation"], + additional_authorize_scopes=[ + "openid", + "profile", + "offline_access", + "api://aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/user_impersonation", + "api://11111111-2222-3333-4444-555555555555/user_impersonation", + ], jwt_signing_key="test-secret", ) - # Only two parts (missing signature) - idp_tokens = {"access_token": "header.payload"} - - claims = await provider._extract_upstream_claims(idp_tokens) + scopes = provider._prepare_scopes_for_token_exchange(["user_impersonation"]) + assert len(scopes) > 0 + # Primary API scope should be prefixed with the provider's identifier_uri + assert "api://00000000-1111-2222-3333-444444444444/user_impersonation" in scopes + # OIDC scopes should be included + assert "openid" in scopes + assert "profile" in scopes + assert "offline_access" in scopes + # Other API scopes should NOT be included (Azure multi-resource limitation) + assert not any("api://aaaaaaaa" in s for s in scopes) + assert not any("api://11111111" in s for s in scopes) - assert claims is None - - async def test_extract_claims_returns_none_for_invalid_base64(self): - """Test that None is returned for JWT with invalid base64.""" + def test_prepare_scopes_deduplicates_scopes(self): + """Test that duplicate scopes are deduplicated.""" provider = AzureProvider( client_id="test_client", client_secret="test_secret", tenant_id="test-tenant", base_url="https://myserver.com", + identifier_uri="api://my-api", required_scopes=["read"], + additional_authorize_scopes=["api://my-api/read", "openid"], jwt_signing_key="test-secret", ) - # Invalid base64 in payload - idp_tokens = {"access_token": "header.not-valid-base64!!!.signature"} + # Pass a scope that will be prefixed to match one in additional_authorize_scopes + scopes = provider._prepare_scopes_for_token_exchange(["read"]) + assert len(scopes) > 0 + # Should be deduplicated - api://my-api/read appears only once + assert scopes.count("api://my-api/read") == 1 + assert "openid" in scopes - claims = await provider._extract_upstream_claims(idp_tokens) + def test_extra_token_params_does_not_contain_scope(self): + """Test that extra_token_params doesn't contain scope to avoid TypeError. - assert claims is None + Previously, Azure provider set extra_token_params={"scope": ...} during init. + This caused a TypeError in exchange_refresh_token because it passes both + scope=... AND **self._extra_token_params, resulting in: + "got multiple values for keyword argument 'scope'" - async def test_extract_claims_returns_none_for_empty_identity_claims(self): - """Test that None is returned when no identity claims are present.""" + The fix uses the _prepare_scopes_for_token_exchange hook instead. + """ provider = AzureProvider( client_id="test_client", client_secret="test_secret", tenant_id="test-tenant", base_url="https://myserver.com", - required_scopes=["read"], + identifier_uri="api://my-api", + required_scopes=["read", "write"], + additional_authorize_scopes=["openid", "profile", "offline_access"], jwt_signing_key="test-secret", ) - # JWT with only standard claims, no identity claims - azure_jwt = self.create_test_jwt( - { - "exp": 9999999999, - "iat": 1234567890, - "iss": "https://issuer.example.com", - "aud": "test-audience", - } - ) - - idp_tokens = {"access_token": azure_jwt} + # extra_token_params should NOT contain "scope" to avoid TypeError during refresh + assert "scope" not in provider._extra_token_params - claims = await provider._extract_upstream_claims(idp_tokens) + # Instead, scopes should be provided via the hook methods + exchange_scopes = provider._prepare_scopes_for_token_exchange(["read", "write"]) + assert len(exchange_scopes) > 0 - assert claims is None - - async def test_extract_claims_partial_identity_claims(self): - """Test extraction when only some identity claims are present.""" - provider = AzureProvider( - client_id="test_client", - client_secret="test_secret", - tenant_id="test-tenant", - base_url="https://myserver.com", - required_scopes=["read"], - jwt_signing_key="test-secret", - ) - - # JWT with only sub and name - azure_jwt = self.create_test_jwt( - { - "sub": "user-id", - "name": "Test User", - "exp": 9999999999, - } + refresh_scopes = provider._prepare_scopes_for_upstream_refresh( + ["read", "write"] ) - - idp_tokens = {"access_token": azure_jwt} - - claims = await provider._extract_upstream_claims(idp_tokens) - - assert claims is not None - assert claims == {"sub": "user-id", "name": "Test User"} + assert len(refresh_scopes) > 0