diff --git a/packages/mcp-server/src/http.test.ts b/packages/mcp-server/src/http.test.ts index 61df3dc5..d4c2b335 100644 --- a/packages/mcp-server/src/http.test.ts +++ b/packages/mcp-server/src/http.test.ts @@ -152,13 +152,12 @@ describe("MCP HTTP transport", () => { issuer: "https://codemem.example.test/", registration_endpoint: "https://codemem.example.test/register", code_challenge_methods_supported: ["S256"], - token_endpoint_auth_methods_supported: ["none"], + token_endpoint_auth_methods_supported: ["client_secret_post", "none"], }); expect(protectedResourceMetadata.status).toBe(200); - expect(await protectedResourceMetadata.json()).toEqual({ + expect(await protectedResourceMetadata.json()).toMatchObject({ resource: "https://codemem.example.test/mcp", authorization_servers: ["https://codemem.example.test/"], - bearer_methods_supported: ["header"], resource_name: "codemem MCP", }); }); @@ -264,6 +263,23 @@ describe("MCP HTTP transport", () => { expect(token.status).toBe(403); }); + it("rejects public OAuth requests with non-public Host headers", async () => { + const server = await startCodememMcpHttpServer({ + dbPath: tempDbPath(), + port: 0, + publicUrl: "https://codemem.example.test/mcp", + }); + servers.push(server); + + const response = await postWithHost(server.url.replace("/mcp", "/register"), "evil.test"); + const trailingSlash = await postWithHost(server.url.replace("/mcp", "/register/"), "evil.test"); + const mcpTrailingSlash = await postWithHost(`${server.url}/`, "evil.test"); + + expect(response.statusCode).toBe(403); + expect(trailingSlash.statusCode).toBe(403); + expect(mcpTrailingSlash.statusCode).toBe(403); + }); + it("rejects unsupported OAuth redirect URIs", async () => { const server = await startCodememMcpHttpServer({ dbPath: tempDbPath(), port: 0 }); servers.push(server); @@ -279,10 +295,12 @@ describe("MCP HTTP transport", () => { }); it("fails closed at authorize when OIDC is not configured", async () => { + const { emit, events } = captureAuditEmitter(); const server = await startCodememMcpHttpServer({ dbPath: tempDbPath(), port: 0, publicUrl: "https://codemem.example.test/mcp", + auditEmitter: emit, }); servers.push(server); const baseUrl = server.url.replace("/mcp", ""); @@ -304,10 +322,17 @@ describe("MCP HTTP transport", () => { authorizeUrl.searchParams.set("code_challenge", pkceS256("d".repeat(43))); authorizeUrl.searchParams.set("state", "state-123"); - const authorize = await fetch(authorizeUrl); + const authorize = await fetch(authorizeUrl, { redirect: "manual" }); - expect(authorize.status).toBe(400); - expect(await authorize.json()).toMatchObject({ error: "temporarily_unavailable" }); + expect(authorize.status).toBe(302); + expect(authorize.headers.get("location")).toContain("error=temporarily_unavailable"); + expect(events).toContainEqual( + expect.objectContaining({ + kind: "authorize", + outcome: "denied", + reason: "temporarily_unavailable", + }), + ); }); it("requires valid bearer tokens for public MCP requests", async () => { @@ -390,10 +415,14 @@ describe("MCP HTTP transport", () => { }), }); expect(register.status).toBe(201); - await fetch(`${baseUrl}/oauth/revoke`, { + const registeredForRevocation = await register.json(); + await fetch(`${baseUrl}/revoke`, { method: "POST", headers: { "content-type": "application/x-www-form-urlencoded" }, - body: new URLSearchParams({ token: "ignored-since-unknown" }), + body: new URLSearchParams({ + client_id: registeredForRevocation.client_id, + token: "ignored-since-unknown", + }), }); const missingBearer = await fetch(server.url, { method: "POST", @@ -429,17 +458,29 @@ describe("MCP HTTP transport", () => { it("rejects expired and revoked bearer tokens", async () => { const tokenStore = createInMemoryOAuthAccessTokenStore(); + const clientId = "client-revoked"; + const revocable = tokenStore.issueToken(clientId); const expired = tokenStore.issueToken("client-expired", Date.now() - 60 * 60 * 1000 - 1); - const revocable = tokenStore.issueToken("client-revoked"); if (!expired || !revocable) throw new Error("expected access tokens"); + const { emit, events } = captureAuditEmitter(); const server = await startCodememMcpHttpServer({ dbPath: tempDbPath(), port: 0, publicUrl: "https://codemem.example.test/mcp", oauthAccessTokenStore: tokenStore, + auditEmitter: emit, }); servers.push(server); const baseUrl = server.url.replace("/mcp", ""); + const registration = await fetch(`${baseUrl}/register`, { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + redirect_uris: ["https://claude.ai/api/mcp/auth_callback"], + token_endpoint_auth_method: "none", + }), + }); + const client = await registration.json(); const expiredResponse = await fetch(server.url, { method: "POST", @@ -450,10 +491,10 @@ describe("MCP HTTP transport", () => { }, body: initializeBody(1), }); - const revoke = await fetch(`${baseUrl}/oauth/revoke`, { + const revoke = await fetch(`${baseUrl}/revoke`, { method: "POST", headers: { "content-type": "application/x-www-form-urlencoded" }, - body: new URLSearchParams({ token: revocable.token }), + body: new URLSearchParams({ client_id: client.client_id, token: revocable.token }), }); const revokedResponse = await fetch(server.url, { method: "POST", @@ -469,6 +510,12 @@ describe("MCP HTTP transport", () => { expect(revoke.status).toBe(200); expect(await revoke.json()).toEqual({}); expect(revokedResponse.status).toBe(401); + expect(events).toContainEqual( + expect.objectContaining({ kind: "bearer", outcome: "denied", reason: "expired_token" }), + ); + expect(events).toContainEqual( + expect.objectContaining({ kind: "bearer", outcome: "denied", reason: "revoked_token" }), + ); }); it("handles repeated MCP initialize requests over POST", async () => { diff --git a/packages/mcp-server/src/http.ts b/packages/mcp-server/src/http.ts index d5f221c7..59f30420 100644 --- a/packages/mcp-server/src/http.ts +++ b/packages/mcp-server/src/http.ts @@ -8,12 +8,18 @@ * OAuth bearer tokens. */ -import { createServer, type IncomingMessage, type Server, type ServerResponse } from "node:http"; +import { createServer, type IncomingMessage, type Server } from "node:http"; import { isIP } from "node:net"; import { pathToFileURL } from "node:url"; import { MemoryStore, resolveDbPath } from "@codemem/core"; +import { requireBearerAuth } from "@modelcontextprotocol/sdk/server/auth/middleware/bearerAuth.js"; +import { + getOAuthProtectedResourceMetadataUrl, + mcpAuthRouter, +} from "@modelcontextprotocol/sdk/server/auth/router.js"; import type { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; +import express, { type NextFunction, type Request, type Response } from "express"; import { type BearerDenyReason, buildOAuthAuditEvent, @@ -22,25 +28,20 @@ import { wrapAuditEmitterBestEffort, } from "./audit.js"; import { - authorizeMcpOAuthClient, createInMemoryOAuthAccessTokenStore, createInMemoryOAuthAuthorizationCodeStore, createInMemoryOAuthClientsStore, - createMcpOAuthMetadata, - createMcpProtectedResourceMetadata, - exchangeMcpOAuthAuthorizationCode, MCP_OAUTH_PUBLIC_URL_ENV, + MCP_OAUTH_RESOURCE_NAME, normalizeMcpPublicUrl, type OAuthAccessTokenStore, - registerMcpOAuthClient, - revokeMcpOAuthAccessToken, } from "./oauth.js"; import { - beginOidcAuthorization, completeOidcAuthorization, createInMemoryOidcPendingAuthorizationStore, resolveOidcConfig, } from "./oidc.js"; +import { MemoryOAuthServerProvider } from "./provider.js"; import { createCodememMcpServer } from "./server.js"; export const DEFAULT_MCP_HTTP_HOST = "127.0.0.1"; @@ -204,263 +205,9 @@ export async function startCodememMcpHttpServer( const activeRequests = new Set(); let closePromise: Promise | null = null; - const server = createServer(async (req, res) => { - try { - const pathname = getRequestPathname(req); - const publicMcpUrl = configuredPublicMcpUrl?.href ?? getServerUrl(server, host); - const boundPort = getBoundPort(server); - const remoteAddress = req.socket.remoteAddress ?? undefined; - - if (pathname === "/.well-known/oauth-authorization-server") { - if (!prepareHttpRoute(req, res, ["GET", "OPTIONS"])) return; - writeJson(res, 200, createMcpOAuthMetadata({ mcpUrl: publicMcpUrl, clientsStore })); - return; - } - - if (pathname === "/.well-known/oauth-protected-resource/mcp") { - if (!prepareHttpRoute(req, res, ["GET", "OPTIONS"])) return; - writeJson(res, 200, createMcpProtectedResourceMetadata(publicMcpUrl)); - return; - } - - if (pathname === "/register") { - if (!prepareOAuthRoute(req, res, ["POST", "OPTIONS"], boundPort, configuredPublicMcpUrl)) - return; - const requestBody = await readJsonBody(req).catch(() => ({ - __codememInvalidJson: "Invalid JSON request body", - })); - if (isInvalidJsonBody(requestBody)) { - auditEmit( - buildOAuthAuditEvent("registration", { - outcome: "denied", - reason: "invalid_json_body", - remoteAddress, - }), - ); - writeJson(res, 400, { - error: "invalid_client_metadata", - error_description: requestBody.__codememInvalidJson, - }); - return; - } - const result = registerMcpOAuthClient(requestBody, clientsStore); - if (result.status === 201) { - auditEmit( - buildOAuthAuditEvent("registration", { - outcome: "success", - clientId: result.body.client_id, - remoteAddress, - }), - ); - } else { - auditEmit( - buildOAuthAuditEvent("registration", { - outcome: "denied", - reason: result.body.error, - remoteAddress, - }), - ); - } - writeJson(res, result.status, result.body); - return; - } - - if (pathname === "/authorize") { - if (!prepareOAuthRoute(req, res, ["GET", "OPTIONS"], boundPort, configuredPublicMcpUrl)) - return; - const url = new URL(req.url ?? "/authorize", "http://codemem.local"); - const clientIdParam = url.searchParams.get("client_id") ?? undefined; - const result = await beginOidcAuthorization( - url.searchParams, - clientsStore, - oidcPendingStore, - oidcConfig, - publicMcpUrl, - ); - if (result.status === 302) { - auditEmit( - buildOAuthAuditEvent("authorize", { - outcome: "success", - reason: oidcConfig ? "oidc_redirect_issued" : "code_issued", - clientId: clientIdParam, - remoteAddress, - }), - ); - res.statusCode = result.status; - res.setHeader("location", result.location); - res.end(); - return; - } - auditEmit( - buildOAuthAuditEvent("authorize", { - outcome: "denied", - reason: result.body.error, - clientId: clientIdParam, - remoteAddress, - }), - ); - writeJson(res, result.status, result.body); - return; - } - - if (pathname === "/oauth/callback") { - if (!prepareOAuthRoute(req, res, ["GET", "OPTIONS"], boundPort, configuredPublicMcpUrl)) - return; - if (!oidcConfig) { - auditEmit( - buildOAuthAuditEvent("oidc_callback", { - outcome: "denied", - reason: "oidc_not_configured", - remoteAddress, - }), - ); - writeJson(res, 400, { - error: "temporarily_unavailable", - error_description: "OIDC is not configured", - }); - return; - } - const url = new URL(req.url ?? "/oauth/callback", "http://codemem.local"); - const completed = await completeOidcAuthorization( - url.searchParams, - oidcPendingStore, - oidcConfig, - publicMcpUrl, - ); - if ("status" in completed) { - auditEmit( - buildOAuthAuditEvent("oidc_callback", { - outcome: "denied", - reason: completed.body.error, - remoteAddress, - }), - ); - writeJson(res, completed.status, completed.body); - return; - } - const oauthClientId = completed.oauthParams.get("client_id") ?? undefined; - const result = authorizeMcpOAuthClient(completed.oauthParams, clientsStore, codeStore); - if (result.status === 302) { - auditEmit( - buildOAuthAuditEvent("oidc_callback", { - outcome: "success", - reason: "code_issued", - clientId: oauthClientId, - remoteAddress, - }), - ); - res.statusCode = result.status; - res.setHeader("location", result.location); - res.end(); - return; - } - auditEmit( - buildOAuthAuditEvent("oidc_callback", { - outcome: "denied", - reason: result.body.error, - clientId: oauthClientId, - remoteAddress, - }), - ); - writeJson(res, result.status, result.body); - return; - } - - if (pathname === "/token") { - if (!prepareOAuthRoute(req, res, ["POST", "OPTIONS"], boundPort, configuredPublicMcpUrl)) - return; - const params = await readFormBody(req).catch(() => new URLSearchParams()); - const tokenClientId = params.get("client_id") ?? undefined; - const result = exchangeMcpOAuthAuthorizationCode( - params, - clientsStore, - codeStore, - tokenStore, - ); - if (result.status === 200) { - auditEmit( - buildOAuthAuditEvent("token", { - outcome: "success", - clientId: tokenClientId, - remoteAddress, - }), - ); - } else { - auditEmit( - buildOAuthAuditEvent("token", { - outcome: "denied", - reason: result.body.error, - clientId: tokenClientId, - remoteAddress, - }), - ); - } - writeJson(res, result.status, result.body); - return; - } - - if (pathname === "/oauth/revoke") { - if (!prepareOAuthRoute(req, res, ["POST", "OPTIONS"], boundPort, configuredPublicMcpUrl)) - return; - const params = await readFormBody(req).catch(() => new URLSearchParams()); - const result = revokeMcpOAuthAccessToken(params, tokenStore); - auditEmit( - buildOAuthAuditEvent("revocation", { - outcome: "success", - remoteAddress, - }), - ); - writeJson(res, result.status, result.body); - return; - } - - if (pathname !== "/mcp") { - writeText(res, 404, "Not found"); - return; - } - if (!allowMethod(req, res, ["POST"])) return; - if (!isAllowedMcpHttpRequest(req, getBoundPort(server), configuredPublicMcpUrl)) { - writeText(res, 403, "Forbidden"); - return; - } - if (shouldRequireMcpBearer) { - const verification = verifyMcpBearerAuthorization(req.headers.authorization, tokenStore); - if (!verification.ok) { - auditEmit( - buildOAuthAuditEvent("bearer", { - outcome: "denied", - reason: verification.reason, - remoteAddress, - }), - ); - writeBearerUnauthorized(res); - return; - } - auditEmit( - buildOAuthAuditEvent("bearer", { - outcome: "success", - clientId: verification.clientId, - remoteAddress, - }), - ); - } - - const mcpServer = createCodememMcpServer(store); - const transport = new StreamableHTTPServerTransport({ sessionIdGenerator: undefined }); - const activeRequest = { mcpServer }; - activeRequests.add(activeRequest); - try { - await mcpServer.connect(transport); - await transport.handleRequest(req, res); - } finally { - activeRequests.delete(activeRequest); - await mcpServer.close(); - } - } catch (err) { - if (!res.headersSent) writeText(res, 500, "MCP request failed"); - console.error("codemem MCP HTTP request failed:", err); - } - }); + const app = express(); + app.disable("x-powered-by"); + const server = createServer(app); await new Promise((resolve, reject) => { server.once("error", reject); @@ -470,6 +217,159 @@ export async function startCodememMcpHttpServer( }); }); + const publicMcpUrl = configuredPublicMcpUrl?.href ?? getServerUrl(server, host); + const publicMcpUrlObject = new URL(publicMcpUrl); + const issuerUrl = getSdkIssuerUrl(publicMcpUrlObject, getBoundPort(server)); + const provider = new MemoryOAuthServerProvider({ + clientsStore, + codeStore, + tokenStore, + publicMcpUrl, + ...(oidcConfig ? { oidc: { config: oidcConfig, pendingStore: oidcPendingStore } } : {}), + }); + const resourceMetadataUrl = getOAuthProtectedResourceMetadataUrl(publicMcpUrlObject); + + app.use((req, res, next) => { + const boundPort = getBoundPort(server); + const pathname = normalizeRoutePath(req.path); + if (pathname === "/mcp") { + if (isAllowedMcpHttpRequest(req, boundPort, configuredPublicMcpUrl)) return next(); + res.status(403).type("text/plain").send("Forbidden"); + return; + } + if (isOAuthOrMetadataPath(pathname)) { + if (isAllowedOAuthHttpRequest(req, boundPort, configuredPublicMcpUrl)) return next(); + res.status(403).type("text/plain").send("Forbidden"); + return; + } + next(); + }); + + app.use(auditOAuthRouteResponses(auditEmit)); + + app.get("/oauth/callback", async (req, res) => { + const remoteAddress = req.socket.remoteAddress ?? undefined; + if (!oidcConfig) { + auditEmit( + buildOAuthAuditEvent("oidc_callback", { + outcome: "denied", + reason: "oidc_not_configured", + remoteAddress, + }), + ); + res.status(400).json({ + error: "temporarily_unavailable", + error_description: "OIDC is not configured", + }); + return; + } + const completed = await completeOidcAuthorization( + new URLSearchParams(req.query as Record), + oidcPendingStore, + oidcConfig, + publicMcpUrl, + ); + if ("status" in completed) { + auditEmit( + buildOAuthAuditEvent("oidc_callback", { + outcome: "denied", + reason: completed.body.error, + remoteAddress, + }), + ); + res.status(completed.status).json(completed.body); + return; + } + const oauthClientId = completed.oauthParams.get("client_id") ?? undefined; + const code = codeStore.issueCode({ + clientId: completed.oauthParams.get("client_id") ?? "", + redirectUri: completed.oauthParams.get("redirect_uri") ?? "", + codeChallenge: completed.oauthParams.get("code_challenge") ?? "", + ...(completed.oauthParams.get("resource") + ? { resource: completed.oauthParams.get("resource") ?? undefined } + : {}), + expiresAt: Date.now() + 5 * 60 * 1000, + }); + if (!code) { + auditEmit( + buildOAuthAuditEvent("oidc_callback", { + outcome: "denied", + reason: "temporarily_unavailable", + clientId: oauthClientId, + remoteAddress, + }), + ); + res.status(400).json({ + error: "temporarily_unavailable", + error_description: "Too many active authorization codes", + }); + return; + } + const redirect = new URL(completed.oauthParams.get("redirect_uri") ?? ""); + redirect.searchParams.set("code", code); + const state = completed.oauthParams.get("state"); + if (state) redirect.searchParams.set("state", state); + auditEmit( + buildOAuthAuditEvent("oidc_callback", { + outcome: "success", + reason: "code_issued", + clientId: oauthClientId, + remoteAddress, + }), + ); + res.redirect(302, redirect.href); + }); + + app.use( + mcpAuthRouter({ + provider, + issuerUrl, + baseUrl: issuerUrl, + resourceServerUrl: publicMcpUrlObject, + resourceName: MCP_OAUTH_RESOURCE_NAME, + clientRegistrationOptions: { clientSecretExpirySeconds: 0, rateLimit: false }, + authorizationOptions: { rateLimit: false }, + tokenOptions: { rateLimit: false }, + revocationOptions: { rateLimit: false }, + }), + ); + + const bearerMiddleware = shouldRequireMcpBearer + ? requireBearerAuth({ verifier: provider, resourceMetadataUrl }) + : (_req: Request, _res: Response, next: NextFunction) => next(); + const bearerAuditMiddleware = shouldRequireMcpBearer + ? auditBearerPreflight(auditEmit, tokenStore) + : (_req: Request, _res: Response, next: NextFunction) => next(); + app.post("/mcp", bearerAuditMiddleware, bearerMiddleware, async (req, res) => { + if (req.auth) { + auditEmit( + buildOAuthAuditEvent("bearer", { + outcome: "success", + clientId: req.auth.clientId, + remoteAddress: req.socket.remoteAddress ?? undefined, + }), + ); + } + const mcpServer = createCodememMcpServer(store); + const transport = new StreamableHTTPServerTransport({ sessionIdGenerator: undefined }); + const activeRequest = { mcpServer }; + activeRequests.add(activeRequest); + try { + await mcpServer.connect(transport); + await transport.handleRequest(req, res); + } finally { + activeRequests.delete(activeRequest); + await mcpServer.close(); + } + }); + app.all("/mcp", (_req, res) => { + res.setHeader("Allow", "POST"); + res.status(405).type("text/plain").send("Method not allowed"); + }); + app.use((_req, res) => { + res.status(404).type("text/plain").send("Not found"); + }); + const close = () => { closePromise ??= (async () => { await Promise.allSettled([...activeRequests].map(({ mcpServer }) => mcpServer.close())); @@ -496,87 +396,6 @@ export async function startCodememMcpHttpServer( }; } -function getRequestPathname(req: IncomingMessage): string { - return new URL(req.url ?? "/", "http://codemem.local").pathname; -} - -function allowMethod(req: IncomingMessage, res: ServerResponse, methods: string[]): boolean { - if (req.method && methods.includes(req.method)) return true; - res.setHeader("Allow", methods.join(", ")); - writeText(res, 405, "Method not allowed"); - return false; -} - -function prepareHttpRoute(req: IncomingMessage, res: ServerResponse, methods: string[]): boolean { - if (!allowMethod(req, res, methods)) return false; - if (req.method === "OPTIONS") { - writeCorsNoContent(res); - return false; - } - return true; -} - -function prepareOAuthRoute( - req: IncomingMessage, - res: ServerResponse, - methods: string[], - expectedPort: number, - publicMcpUrl: URL | undefined, -): boolean { - if (!prepareHttpRoute(req, res, methods)) return false; - if (isAllowedOAuthHttpRequest(req, expectedPort, publicMcpUrl)) return true; - writeText(res, 403, "Forbidden"); - return false; -} - -function writeJson(res: ServerResponse, statusCode: number, body: unknown): void { - res.statusCode = statusCode; - res.setHeader("access-control-allow-origin", "*"); - res.setHeader("cache-control", "no-store"); - res.setHeader("content-type", "application/json; charset=utf-8"); - res.end(JSON.stringify(body)); -} - -function writeCorsNoContent(res: ServerResponse): void { - res.statusCode = 204; - res.setHeader("access-control-allow-origin", "*"); - res.setHeader("access-control-allow-headers", "authorization, content-type"); - res.setHeader("access-control-allow-methods", "GET, POST, OPTIONS"); - res.end(); -} - -async function readJsonBody(req: IncomingMessage): Promise { - const body = await readRequestBody(req); - if (!body) return {}; - return JSON.parse(body) as unknown; -} - -async function readFormBody(req: IncomingMessage): Promise { - const contentType = req.headers["content-type"] ?? ""; - if (!contentType.toString().toLowerCase().startsWith("application/x-www-form-urlencoded")) { - throw new Error("OAuth token request body must be form-encoded"); - } - return new URLSearchParams(await readRequestBody(req)); -} - -async function readRequestBody(req: IncomingMessage): Promise { - let body = ""; - for await (const chunk of req) { - body += chunk; - if (body.length > 64 * 1024) throw new Error("OAuth request body is too large"); - } - return body; -} - -function isInvalidJsonBody(body: unknown): body is { __codememInvalidJson: string } { - return ( - typeof body === "object" && - body !== null && - "__codememInvalidJson" in body && - typeof body.__codememInvalidJson === "string" - ); -} - function isAllowedMcpHttpRequest( req: IncomingMessage, expectedPort: number, @@ -603,8 +422,12 @@ function isAllowedOAuthHttpRequest( expectedPort: number, publicMcpUrl: URL | undefined, ): boolean { - if (!publicMcpUrl) return isAllowedMcpHttpRequest(req, expectedPort, undefined); - return isAllowedPublicOrigin(req.headers.origin, publicMcpUrl); + if (isAllowedLocalMcpHttpRequest(req, expectedPort)) return true; + if (!publicMcpUrl) return false; + return ( + isAllowedPublicRequestHost(req.headers.host, publicMcpUrl) && + isAllowedPublicOrigin(req.headers.origin, publicMcpUrl) + ); } function isAllowedPublicOrigin(header: string | undefined, publicMcpUrl: URL): boolean { @@ -632,31 +455,6 @@ function isAllowedPublicRequestHost(header: string | undefined, publicMcpUrl: UR ); } -function verifyMcpBearerAuthorization( - header: string | undefined, - tokenStore: OAuthAccessTokenStore, -): { ok: true; clientId: string } | { ok: false; reason: BearerDenyReason } { - if (!header) return { ok: false, reason: "missing_authorization_header" }; - const [scheme, token, extra] = header.trim().split(/\s+/); - if (!scheme || !token || extra || scheme.toLowerCase() !== "bearer") { - return { ok: false, reason: "malformed_authorization_header" }; - } - const result = tokenStore.verifyToken(token); - if (!result.ok) return { ok: false, reason: result.reason }; - return { ok: true, clientId: result.record.clientId }; -} - -function writeBearerUnauthorized(res: ServerResponse): void { - res.setHeader("www-authenticate", 'Bearer realm="codemem-mcp", error="invalid_token"'); - writeText(res, 401, "Unauthorized"); -} - -function writeText(res: ServerResponse, statusCode: number, body: string): void { - res.statusCode = statusCode; - res.setHeader("content-type", "text/plain; charset=utf-8"); - res.end(body); -} - function getBoundPort(server: Server): number { const address = server.address(); if (!address || typeof address === "string") return DEFAULT_MCP_HTTP_PORT; @@ -671,6 +469,142 @@ function getServerUrl(server: Server, host: string): string { return `http://${formatHostForUrl(host)}:${getBoundPort(server)}/mcp`; } +function getSdkIssuerUrl(publicMcpUrl: URL, port: number): URL { + if (publicMcpUrl.protocol === "https:") return new URL(publicMcpUrl.origin); + if (isLoopbackHost(publicMcpUrl.hostname) && normalizeHost(publicMcpUrl.hostname) !== "::1") { + return new URL(publicMcpUrl.origin); + } + return new URL(`http://localhost:${port}/`); +} + +function isOAuthOrMetadataPath(pathname: string): boolean { + return ( + pathname === "/register" || + pathname === "/authorize" || + pathname === "/token" || + pathname === "/revoke" || + pathname === "/oauth/callback" || + pathname === "/.well-known/oauth-authorization-server" || + pathname === "/.well-known/oauth-protected-resource/mcp" + ); +} + +function normalizeRoutePath(pathname: string): string { + return pathname.length > 1 ? pathname.replace(/\/+$/, "") : pathname; +} + +function auditOAuthRouteResponses( + auditEmit: (event: ReturnType) => void, +) { + return (req: Request, res: Response, next: NextFunction) => { + const kind = oauthAuditKindForPath(req.path); + if (!kind) return next(); + let responseBody: unknown; + const originalJson = res.json.bind(res); + res.json = ((body: unknown) => { + responseBody = body; + return originalJson(body); + }) as typeof res.json; + res.on("finish", () => { + if (req.path === "/oauth/callback") return; + const error = getOAuthResponseError(res, responseBody); + const clientId = getRequestClientId(req, responseBody); + auditEmit( + buildOAuthAuditEvent(kind, { + outcome: error ? "denied" : "success", + ...(error ? { reason: error } : {}), + ...(clientId ? { clientId } : {}), + remoteAddress: req.socket.remoteAddress ?? undefined, + }), + ); + }); + next(); + }; +} + +function oauthAuditKindForPath( + pathname: string, +): "registration" | "authorize" | "token" | "revocation" | null { + if (pathname === "/register") return "registration"; + if (pathname === "/authorize") return "authorize"; + if (pathname === "/token") return "token"; + if (pathname === "/revoke") return "revocation"; + return null; +} + +function getOAuthResponseError(res: Response, body: unknown): string | undefined { + const bodyError = getOAuthBodyError(body); + if (bodyError) return bodyError; + const location = res.getHeader("location"); + if (typeof location !== "string") return undefined; + try { + const redirect = new URL(location); + return redirect.searchParams.get("error") ?? undefined; + } catch { + return undefined; + } +} + +function getOAuthBodyError(body: unknown): string | undefined { + if ( + typeof body === "object" && + body !== null && + "error" in body && + typeof body.error === "string" + ) { + return body.error; + } + return undefined; +} + +function getRequestClientId(req: Request, body: unknown): string | undefined { + if ( + typeof body === "object" && + body !== null && + "client_id" in body && + typeof body.client_id === "string" + ) { + return body.client_id; + } + const bodyClientId = typeof req.body?.client_id === "string" ? req.body.client_id : undefined; + if (bodyClientId) return bodyClientId; + const queryClientId = typeof req.query.client_id === "string" ? req.query.client_id : undefined; + return queryClientId; +} + +function auditBearerPreflight( + auditEmit: (event: ReturnType) => void, + tokenStore: OAuthAccessTokenStore, +) { + return (req: Request, _res: Response, next: NextFunction) => { + const reason = getBearerPreflightDenyReason(req.headers.authorization, tokenStore); + if (reason) { + auditEmit( + buildOAuthAuditEvent("bearer", { + outcome: "denied", + reason, + remoteAddress: req.socket.remoteAddress ?? undefined, + }), + ); + } + next(); + }; +} + +function getBearerPreflightDenyReason( + header: string | undefined, + tokenStore: OAuthAccessTokenStore, +): BearerDenyReason | null { + if (!header) return "missing_authorization_header"; + const [scheme, token, extra] = header.trim().split(/\s+/); + if (!scheme || !token || extra || scheme.toLowerCase() !== "bearer") { + return "malformed_authorization_header"; + } + const verification = tokenStore.verifyToken(token); + if (!verification.ok) return verification.reason; + return null; +} + function isEntrypoint(): boolean { const script = process.argv[1]; return script ? import.meta.url === pathToFileURL(script).href : false; diff --git a/packages/mcp-server/src/oauth.test.ts b/packages/mcp-server/src/oauth.test.ts index 66ac8e1d..8fd695d5 100644 --- a/packages/mcp-server/src/oauth.test.ts +++ b/packages/mcp-server/src/oauth.test.ts @@ -123,7 +123,7 @@ describe("MCP OAuth metadata and dynamic client registration", () => { }, clientsStore, ), - ).toMatchObject({ status: 400, body: { error: "invalid_client_metadata" } }); + ).toMatchObject({ status: 201 }); }); it("issues authorization codes and exchanges them with PKCE S256", () => { diff --git a/packages/mcp-server/src/oauth.ts b/packages/mcp-server/src/oauth.ts index ab92532c..9674c568 100644 --- a/packages/mcp-server/src/oauth.ts +++ b/packages/mcp-server/src/oauth.ts @@ -1,5 +1,6 @@ import { createHash, createHmac, randomBytes, randomUUID, timingSafeEqual } from "node:crypto"; import type { OAuthRegisteredClientsStore } from "@modelcontextprotocol/sdk/server/auth/clients.js"; +import { InvalidClientMetadataError } from "@modelcontextprotocol/sdk/server/auth/errors.js"; import type { OAuthServerProvider } from "@modelcontextprotocol/sdk/server/auth/provider.js"; import { createOAuthMetadata } from "@modelcontextprotocol/sdk/server/auth/router.js"; import { @@ -19,7 +20,7 @@ export const MCP_OAUTH_RESOURCE_NAME = "codemem MCP"; const CLAUDE_HOSTED_CALLBACK = "https://claude.ai/api/mcp/auth_callback"; const LOCAL_CALLBACK_HOSTS = new Set(["localhost", "127.0.0.1", "::1"]); -const SUPPORTED_GRANT_TYPES = new Set(["authorization_code"]); +const SUPPORTED_GRANT_TYPES = new Set(["authorization_code", "refresh_token"]); const SUPPORTED_RESPONSE_TYPES = new Set(["code"]); const AUTHORIZATION_CODE_TTL_MS = 5 * 60 * 1000; const ACCESS_TOKEN_TTL_SECONDS = 60 * 60; @@ -118,12 +119,33 @@ export class InMemoryOAuthClientsStore implements OAuthRegisteredClientsStore { registerClient( client: Omit, ): OAuthClientInformationFull { + const redirectError = validateRedirectUris(client.redirect_uris); + if (redirectError) throw new InvalidClientMetadataError(redirectError); + if (client.client_secret || client.token_endpoint_auth_method !== "none") { + throw new InvalidClientMetadataError( + "Only public clients with token_endpoint_auth_method=none are supported", + ); + } + if (client.grant_types && !isSupportedList(client.grant_types, SUPPORTED_GRANT_TYPES)) { + throw new InvalidClientMetadataError( + "Only authorization_code and refresh_token grant_types are supported", + ); + } + if ( + client.response_types && + !isSupportedList(client.response_types, SUPPORTED_RESPONSE_TYPES) + ) { + throw new InvalidClientMetadataError("Only code response_type is supported"); + } if (this.#clients.size >= 100) { const oldestClientId = this.#clients.keys().next().value; if (oldestClientId) this.#clients.delete(oldestClientId); } const registered = { ...client, + token_endpoint_auth_method: "none" as const, + grant_types: client.grant_types ?? ["authorization_code"], + response_types: client.response_types ?? ["code"], client_id: randomUUID(), client_id_issued_at: Math.floor(Date.now() / 1000), }; @@ -294,7 +316,9 @@ export function registerMcpOAuthClient( clientMetadata.grant_types && !isSupportedList(clientMetadata.grant_types, SUPPORTED_GRANT_TYPES) ) { - return invalidClientMetadata("Only authorization_code grant_type is supported in this slice"); + return invalidClientMetadata( + "Only authorization_code and refresh_token grant_types are supported", + ); } if (