-
Notifications
You must be signed in to change notification settings - Fork 18
Implement OAuth endpoints #5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,140 @@ | ||
| """ | ||
| This module contains the Starlette middleware that implemnts OAuth authorization. | ||
| """ | ||
|
|
||
| import httpx | ||
| import starlette.middleware | ||
| import starlette.middleware.base | ||
| import starlette.requests | ||
| import starlette.responses | ||
| import starlette.types | ||
|
|
||
|
|
||
| class Middleware(starlette.middleware.base.BaseHTTPMiddleware): | ||
| """ | ||
| This middleware implements the OAuth metadata and registration endpoints that MCP clients | ||
| will try to use when the server responds with a 401 status code. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| app: starlette.types.ASGIApp, | ||
| self_url: str, | ||
| oauth_url: str, | ||
| oauth_client: str, | ||
| ): | ||
| """ | ||
| Creates a new OAuth middleware. | ||
|
|
||
| Args: | ||
| app (starlette.types.ASGIApp): The starlette application. | ||
| self_url (str): Base URL of the service, as seen by clients. | ||
| oauth_url (str): Base URL of the authorization server. | ||
| oauth_client (str): The client identifier. | ||
| """ | ||
| super().__init__(app=app) | ||
| self._self_url = self_url | ||
| self._oauth_url = oauth_url | ||
| self._oauth_client = oauth_client | ||
|
|
||
| async def dispatch( | ||
| self, | ||
| request: starlette.requests.Request, | ||
| call_next: starlette.middleware.base.RequestResponseEndpoint, | ||
| ) -> starlette.responses.Response: | ||
| """ | ||
| Dispatches the request, calling the OAuth handlers or else the protected application. | ||
| """ | ||
| # The OAuth endpoints don't require authentication: | ||
| method = request.method | ||
| path = request.url.path | ||
| if method == "GET" and path == "/.well-known/oauth-protected-resource": | ||
| return await self._resource(request) | ||
| if method == "GET" and path == "/.well-known/oauth-authorization-server": | ||
| return await self._metadata(request) | ||
| if method == "POST" and path == "/oauth/register": | ||
| return await self._register(request) | ||
|
|
||
| # The rest of the endpoints do require authentication. Note that we are not validating the | ||
| # bearer token, just requiring the authorization header, so that the client will receive | ||
| # the 401 response code and trigger the OAuth flow. | ||
| auth = request.headers.get("authorization") | ||
| if auth is None: | ||
| resource_url = f"{self._self_url}/.well-known/oauth-protected-resource" | ||
| return starlette.responses.Response( | ||
| status_code=401, | ||
| headers={ | ||
| "WWW-Authenticate": f"Bearer resource_metadata=\"{resource_url}\"", | ||
| }, | ||
| ) | ||
|
|
||
| return await call_next(request) | ||
|
|
||
| async def _resource(self, request: starlette.requests.Request) -> starlette.responses.Response: | ||
| """ | ||
| This method implements the OAuth protected resource endpoint. | ||
| """ | ||
| return starlette.responses.JSONResponse( | ||
| content={ | ||
| "resource": self._self_url, | ||
| "authorization_servers": [ | ||
| self._self_url, | ||
| ], | ||
| "bearer_methods_supported": [ | ||
| "header", | ||
| ], | ||
| "scopes_supported": [ | ||
| "openid", | ||
| "api.ocm", | ||
| ], | ||
| } | ||
| ) | ||
|
|
||
| async def _metadata(self, request: starlette.requests.Request) -> starlette.responses.Response: | ||
| """ | ||
| This method implements the OAuth metadata endpoint. It gets the metadata from our real authorization | ||
| server, and replaces a few things that are needed to satisfy MCP clients. | ||
| """ | ||
| # Get the metadata from the real authorization service: | ||
| try: | ||
| async with httpx.AsyncClient() as client: | ||
| response = await client.get( | ||
| url=f"{self._oauth_url}/.well-known/oauth-authorization-server", | ||
| timeout=10, | ||
| ) | ||
| response.raise_for_status() | ||
| body = response.json() | ||
| except (httpx.RequestError, httpx.HTTPStatusError): | ||
| return starlette.responses.Response(status_code=503) | ||
|
|
||
| # The MCP clients will want to dynamically register the client, but we don't want that because our | ||
| # authorization server doesn't allow us to do it. So we replace the registration endpoint with our | ||
| # own, where we can return a fake response to make the MCP clients happy. | ||
| body["registration_endpoint"] = f"{self._self_url}/oauth/register" | ||
|
|
||
| # The MCP clients also try to request all the scopes listed in the metadata, but our authorization | ||
| # server returns a lot of scopes, and most of them will be rejected for our client. So we replace | ||
| # that large list with a much smaller list containing only the scopes that we need. | ||
| body["scopes_supported"] = [ | ||
| "openid", | ||
| "api.ocm", | ||
| ] | ||
|
|
||
| # Return the modified metadata: | ||
| return starlette.responses.JSONResponse( | ||
| content=body, | ||
| ) | ||
|
|
||
| async def _register(self, request: starlette.requests.Request) -> starlette.responses.Response: | ||
| """ | ||
| This method implements the OAuth dynamic client registration endpoint. It responds to all requests | ||
| with a fixed client identifier. | ||
| """ | ||
| body = await request.json() | ||
| redirect_uris = body.get("redirect_uris", []) | ||
| return starlette.responses.JSONResponse( | ||
| content={ | ||
| "client_id": self._oauth_client, | ||
| "redirect_uris": redirect_uris, | ||
| }, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,11 +1,16 @@ | ||
| from mcp.server.fastmcp import FastMCP | ||
| import json | ||
| import os | ||
|
|
||
| import fastmcp | ||
| import fastmcp.server.dependencies | ||
| import requests | ||
| import uvicorn | ||
|
|
||
| import oauth | ||
|
|
||
| from service_client import InventoryClient | ||
|
|
||
| mcp = FastMCP("AssistedService", host="0.0.0.0") | ||
| mcp = fastmcp.FastMCP("AssistedService") | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why? I think the default is 127.0.0.1
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The default may be 127.0.0.1, but it isn't relevant anymore because we don't run the MCP directly. Instead we create a Starlette application and start it with this: uvicorn.run(app, host="0.0.0.0", port=8000)See the |
||
|
|
||
| def get_offline_token() -> str: | ||
| """Retrieve the offline token from environment variables or request headers. | ||
|
|
@@ -26,7 +31,8 @@ def get_offline_token() -> str: | |
| if token: | ||
| return token | ||
|
|
||
| token = mcp.get_context().request_context.request.headers.get("OCM-Offline-Token") | ||
| headers = fastmcp.server.dependencies.get_http_headers() | ||
| token = headers.get("ocm-offline-token") | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why change this? Also I assume this header name is not case sensitive?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As part of the change we are importing Changing the imported package was necessary because the The For more information, this is what Gemini has to say about the differences between
|
||
| if token: | ||
| return token | ||
|
|
||
|
|
@@ -46,13 +52,12 @@ def get_access_token() -> str: | |
| RuntimeError: If it isn't possible to obtain or generate the access token. | ||
| """ | ||
| # First try to get the token from the authorization header: | ||
| request = mcp.get_context().request_context.request | ||
| if request is not None: | ||
| header = request.headers.get("Authorization") | ||
| if header is not None: | ||
| parts = header.split() | ||
| if len(parts) == 2 and parts[0].lower() == "bearer": | ||
| return parts[1] | ||
| headers = fastmcp.server.dependencies.get_http_headers() | ||
| header = headers.get("authorization") | ||
| if header is not None: | ||
| parts = header.split() | ||
| if len(parts) == 2 and parts[0].lower() == "bearer": | ||
| return parts[1] | ||
|
|
||
| # Now try to get the offline token, and generate a new access token from it: | ||
| params = { | ||
|
|
@@ -300,4 +305,30 @@ def set_host_role(host_id: str, infraenv_id: str, role: str) -> str: | |
| return InventoryClient(get_access_token()).update_host(host_id, infraenv_id, host_role=role).to_str() | ||
|
|
||
| if __name__ == "__main__": | ||
| mcp.run(transport="sse") | ||
| # We create a Starlette application so that we can add middleware: | ||
| app = mcp.http_app(transport="sse") | ||
|
|
||
| # Add the OAuth middleware if enabled: | ||
| oauth_enabled = os.getenv("OAUTH_ENABLED", "false").lower() == "true" | ||
| if oauth_enabled: | ||
| self_url = os.getenv( | ||
| "SELF_URL", | ||
| "http://localhost:8000", | ||
| ) | ||
| oauth_url = os.getenv( | ||
| "OAUTH_URL", | ||
| "https://sso.redhat.com/auth/realms/redhat-external", | ||
| ) | ||
| oauth_client = os.getenv( | ||
| "OAUTH_CLIENT", | ||
| "cloud-services", | ||
| ) | ||
| app.add_middleware( | ||
| oauth.Middleware, | ||
| self_url=self_url, | ||
| oauth_url=oauth_url, | ||
| oauth_client=oauth_client, | ||
| ) | ||
|
|
||
| # Start the application | ||
| uvicorn.run(app, host="0.0.0.0", port=8000) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Add error handling for JSON parsing.
The
await request.json()call can raise exceptions if the request body is malformed JSON or missing. This should be handled gracefully.📝 Committable suggestion
🤖 Prompt for AI Agents