Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion src/fastmcp/server/auth/oauth_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1273,6 +1273,23 @@ async def exchange_authorization_code(
# Refresh Token Flow
# -------------------------------------------------------------------------

def _prepare_scopes_for_upstream_refresh(self, scopes: list[str]) -> list[str]:
"""Prepare scopes for upstream token refresh request.

Override this method to transform scopes before sending to upstream provider.
For example, Azure needs to prefix scopes and add additional Graph scopes.

The scopes parameter represents what should be stored in the RefreshToken.
This method returns what should be sent to the upstream provider.

Args:
scopes: Base scopes that will be stored in RefreshToken

Returns:
Scopes to send to upstream provider (may be transformed/augmented)
"""
return scopes

async def load_refresh_token(
self,
client: OAuthClientInformationFull,
Expand Down Expand Up @@ -1333,12 +1350,17 @@ async def exchange_refresh_token(
timeout=HTTP_TIMEOUT_SECONDS,
)

# Allow child classes to transform scopes before sending to upstream
# This enables provider-specific scope formatting (e.g., Azure prefixing)
# while keeping original scopes in storage
upstream_scopes = self._prepare_scopes_for_upstream_refresh(scopes)

try:
logger.debug("Refreshing upstream token (jti=%s)", refresh_jti[:8])
token_response: dict[str, Any] = await oauth_client.refresh_token( # type: ignore[misc]
url=self._upstream_token_endpoint,
refresh_token=upstream_token_set.refresh_token,
scope=" ".join(scopes) if scopes else None,
scope=" ".join(upstream_scopes) if upstream_scopes else None,
**self._extra_token_params,
)
logger.debug("Successfully refreshed upstream token")
Expand Down
70 changes: 62 additions & 8 deletions src/fastmcp/server/auth/providers/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,28 @@ async def authorize(
separator = "&" if "?" in auth_url else "?"
return f"{auth_url}{separator}prompt=select_account"

def _prefix_scopes_for_azure(self, scopes: list[str]) -> list[str]:
"""Prefix unprefixed scopes with identifier_uri for Azure.

This helper centralizes the scope prefixing logic used in both
authorization and token refresh flows.

Args:
scopes: List of scopes, may be prefixed or unprefixed

Returns:
List of scopes with identifier_uri prefix applied where needed
"""
prefixed = []
for scope in scopes:
if "://" in scope or "/" in scope:
# Already fully-qualified (e.g., "api://xxx/read" or "User.Read")
prefixed.append(scope)
else:
# Unprefixed client scope - prefix with identifier_uri
prefixed.append(f"{self.identifier_uri}/{scope}")
return prefixed

def _build_upstream_authorize_url(
self, txn_id: str, transaction: dict[str, Any]
) -> str:
Expand All @@ -339,14 +361,7 @@ def _build_upstream_authorize_url(
unprefixed_scopes = transaction.get("scopes") or self.required_scopes or []

# Prefix scopes for Azure authorization request
prefixed_scopes = []
for scope in unprefixed_scopes:
if "://" in scope or "/" in scope:
# Already a full URI or path (e.g., "api://xxx/read" or "User.Read")
prefixed_scopes.append(scope)
else:
# Unprefixed scope name - prefix it with identifier_uri
prefixed_scopes.append(f"{self.identifier_uri}/{scope}")
prefixed_scopes = self._prefix_scopes_for_azure(unprefixed_scopes)

# Add Microsoft Graph scopes (not validated, not prefixed)
if self.additional_authorize_scopes:
Expand All @@ -358,3 +373,42 @@ def _build_upstream_authorize_url(

# Let parent build the URL with prefixed scopes
return super()._build_upstream_authorize_url(txn_id, modified_transaction)

def _prepare_scopes_for_upstream_refresh(self, scopes: list[str]) -> list[str]:
"""Prepare scopes for Azure token refresh.

Azure requires:
1. Fully-qualified custom scopes (e.g., "api://xxx/read" not "read")
2. Microsoft Graph scopes (e.g., "User.Read", "openid") sent as-is
3. Additional scopes from provider config (additional_authorize_scopes)

This method transforms base client scopes for Azure while keeping them
unprefixed in storage to prevent accumulation.

Args:
scopes: Base scopes from RefreshToken (unprefixed, e.g., ["read"])

Returns:
Deduplicated list of scopes formatted for Azure token endpoint
"""
logger.debug("Base scopes from storage: %s", scopes)

# Filter out any additional_authorize_scopes that may have been stored
# (they shouldn't be in storage, but clean them up if they are)
additional_scopes_set = set(self.additional_authorize_scopes or [])
base_scopes = [s for s in scopes if s not in additional_scopes_set]

# Prefix base scopes with identifier_uri for Azure using shared helper
prefixed_scopes = self._prefix_scopes_for_azure(base_scopes)

# Add additional scopes (Graph + OIDC) for the Azure request
# These are NOT stored in RefreshToken, only sent to Azure
if self.additional_authorize_scopes:
prefixed_scopes.extend(self.additional_authorize_scopes)

# Deduplicate while preserving order (in case older tokens have duplicates)
# Use dict.fromkeys() for O(n) deduplication with order preservation
deduplicated_scopes = list(dict.fromkeys(prefixed_scopes))

logger.debug("Scopes for Azure token endpoint: %s", deduplicated_scopes)
return deduplicated_scopes
222 changes: 222 additions & 0 deletions tests/server/auth/providers/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,3 +487,225 @@ def test_base_authority_with_special_tenant_values(self):
parsed = urlparse(provider._upstream_authorization_endpoint)
assert parsed.netloc == "login.microsoftonline.us"
assert "/organizations/" in parsed.path

def test_prepare_scopes_for_upstream_refresh_basic_prefixing(self):
"""Test that unprefixed scopes are correctly prefixed for Azure token refresh."""
provider = AzureProvider(
client_id="test_client",
client_secret="test_secret",
tenant_id="test-tenant",
identifier_uri="api://my-api",
required_scopes=["read", "write"],
jwt_signing_key="test-secret",
)

# Unprefixed scopes from storage should be prefixed
result = provider._prepare_scopes_for_upstream_refresh(["read", "write"])

assert "api://my-api/read" in result
assert "api://my-api/write" in result
assert len(result) == 2

def test_prepare_scopes_for_upstream_refresh_already_prefixed(self):
"""Test that already-prefixed scopes remain unchanged."""
provider = AzureProvider(
client_id="test_client",
client_secret="test_secret",
tenant_id="test-tenant",
identifier_uri="api://my-api",
required_scopes=["read"],
jwt_signing_key="test-secret",
)

# Already prefixed scopes should pass through unchanged
result = provider._prepare_scopes_for_upstream_refresh(
["api://my-api/read", "api://other-api/admin"]
)

assert "api://my-api/read" in result
assert "api://other-api/admin" in result
assert len(result) == 2

def test_prepare_scopes_for_upstream_refresh_with_additional_scopes(self):
"""Test that additional_authorize_scopes are added during token refresh."""
provider = AzureProvider(
client_id="test_client",
client_secret="test_secret",
tenant_id="test-tenant",
identifier_uri="api://my-api",
required_scopes=["read"],
additional_authorize_scopes=[
"User.Read",
"openid",
"profile",
"offline_access",
],
jwt_signing_key="test-secret",
)

# Base scopes should be prefixed, additional scopes appended
result = provider._prepare_scopes_for_upstream_refresh(["read", "write"])

assert "api://my-api/read" in result
assert "api://my-api/write" in result
assert "User.Read" in result
assert "openid" in result
assert "profile" in result
assert "offline_access" in result
assert len(result) == 6

def test_prepare_scopes_for_upstream_refresh_filters_duplicate_additional_scopes(
self,
):
"""Test that accidentally stored additional_authorize_scopes are filtered out."""
provider = AzureProvider(
client_id="test_client",
client_secret="test_secret",
tenant_id="test-tenant",
identifier_uri="api://my-api",
required_scopes=["read"],
additional_authorize_scopes=["User.Read", "openid"],
jwt_signing_key="test-secret",
)

# If additional scopes were accidentally stored, they should be filtered
# to prevent accumulation
result = provider._prepare_scopes_for_upstream_refresh(
["read", "User.Read", "openid"]
)

# Should have: api://my-api/read (prefixed) + User.Read + openid (added once)
assert "api://my-api/read" in result
assert result.count("User.Read") == 1
assert result.count("openid") == 1
assert len(result) == 3

def test_prepare_scopes_for_upstream_refresh_mixed_scopes(self):
"""Test mixed scenario with both prefixed and unprefixed scopes."""
provider = AzureProvider(
client_id="test_client",
client_secret="test_secret",
tenant_id="test-tenant",
identifier_uri="api://my-api",
required_scopes=["read"],
additional_authorize_scopes=["User.Read"],
jwt_signing_key="test-secret",
)

# Mix of prefixed and unprefixed scopes
result = provider._prepare_scopes_for_upstream_refresh(
["read", "api://other-api/admin", "write"]
)

assert "api://my-api/read" in result
assert "api://other-api/admin" in result # Already prefixed, unchanged
assert "api://my-api/write" in result
assert "User.Read" in result
assert len(result) == 4

def test_prepare_scopes_for_upstream_refresh_scope_with_slash(self):
"""Test that scopes containing '/' are not prefixed."""
provider = AzureProvider(
client_id="test_client",
client_secret="test_secret",
tenant_id="test-tenant",
identifier_uri="api://my-api",
required_scopes=["read"],
jwt_signing_key="test-secret",
)

# Scopes with "/" should not be prefixed (already fully qualified)
result = provider._prepare_scopes_for_upstream_refresh(
["read", "https://graph.microsoft.com/.default"]
)

assert "api://my-api/read" in result
assert (
"https://graph.microsoft.com/.default" in result
) # Not prefixed (contains ://)

def test_prepare_scopes_for_upstream_refresh_empty_scopes(self):
"""Test behavior with empty scopes list."""
provider = AzureProvider(
client_id="test_client",
client_secret="test_secret",
tenant_id="test-tenant",
identifier_uri="api://my-api",
required_scopes=["read"],
additional_authorize_scopes=["User.Read", "openid"],
jwt_signing_key="test-secret",
)

# Empty scopes should still add additional_authorize_scopes
result = provider._prepare_scopes_for_upstream_refresh([])

assert "User.Read" in result
assert "openid" in result
assert len(result) == 2

def test_prepare_scopes_for_upstream_refresh_no_additional_scopes(self):
"""Test behavior when no additional_authorize_scopes are configured."""
provider = AzureProvider(
client_id="test_client",
client_secret="test_secret",
tenant_id="test-tenant",
identifier_uri="api://my-api",
required_scopes=["read"],
jwt_signing_key="test-secret",
)

# Should only prefix base scopes, no additional scopes added
result = provider._prepare_scopes_for_upstream_refresh(["read", "write"])

assert "api://my-api/read" in result
assert "api://my-api/write" in result
assert len(result) == 2

def test_prepare_scopes_for_upstream_refresh_deduplicates_scopes(self):
"""Test that duplicate scopes are deduplicated while preserving order."""
provider = AzureProvider(
client_id="test_client",
client_secret="test_secret",
tenant_id="test-tenant",
identifier_uri="api://my-api",
required_scopes=["read"],
additional_authorize_scopes=["User.Read", "openid"],
jwt_signing_key="test-secret",
)

# Test with duplicate base scopes and duplicate additional scopes
result = provider._prepare_scopes_for_upstream_refresh(
["read", "write", "read", "User.Read", "openid"]
)

# Should have deduplicated results in order
assert result == [
"api://my-api/read",
"api://my-api/write",
"User.Read",
"openid",
]
assert len(result) == 4

def test_prepare_scopes_for_upstream_refresh_deduplicates_prefixed_variants(self):
"""Test that both prefixed and unprefixed variants are deduplicated."""
provider = AzureProvider(
client_id="test_client",
client_secret="test_secret",
tenant_id="test-tenant",
identifier_uri="api://my-api",
required_scopes=["read"],
jwt_signing_key="test-secret",
)

# Test with both prefixed and unprefixed variants of same scope
result = provider._prepare_scopes_for_upstream_refresh(
["read", "api://my-api/read", "write"]
)

# Should deduplicate - first occurrence wins (api://my-api/read from "read")
assert "api://my-api/read" in result
assert "api://my-api/write" in result
# Should only have 2 items (read processed twice, but deduplicated)
assert len(result) == 2
assert result.count("api://my-api/read") == 1
Loading