diff --git a/server/src/index.ts b/server/src/index.ts index c0fb3797a..3f739f39e 100644 --- a/server/src/index.ts +++ b/server/src/index.ts @@ -22,12 +22,6 @@ import mcpProxy from "./mcpProxy.js"; import { randomUUID, randomBytes, timingSafeEqual } from "node:crypto"; const DEFAULT_MCP_PROXY_LISTEN_PORT = "6277"; -const SSE_HEADERS_PASSTHROUGH = ["authorization"]; -const STREAMABLE_HTTP_HEADERS_PASSTHROUGH = [ - "authorization", - "mcp-session-id", - "last-event-id", -]; const defaultEnvironment = { ...getDefaultEnvironment(), @@ -46,43 +40,84 @@ const { values } = parseArgs({ }); // Function to get HTTP headers. -// Supports only "sse" and "streamable-http" transport types. -const getHttpHeaders = ( - req: express.Request, - transportType: string, -): HeadersInit => { - const headers: HeadersInit = { - Accept: - transportType === "sse" - ? "text/event-stream" - : "text/event-stream, application/json", - }; - const defaultHeaders = - transportType === "sse" - ? SSE_HEADERS_PASSTHROUGH - : STREAMABLE_HTTP_HEADERS_PASSTHROUGH; - - for (const key of defaultHeaders) { - if (req.headers[key] === undefined) { - continue; +const getHttpHeaders = (req: express.Request): Record => { + const headers: Record = {}; + + // Iterate over all headers in the request + for (const key in req.headers) { + const lowerKey = key.toLowerCase(); + + // Check if the header is one we want to forward + if ( + lowerKey.startsWith("mcp-") || + lowerKey === "authorization" || + lowerKey === "last-event-id" + ) { + // Exclude the proxy's own authentication header + if (lowerKey !== "x-mcp-proxy-auth") { + const value = req.headers[key]; + + if (typeof value === "string") { + // If the value is a string, use it directly + headers[key] = value; + } else if (Array.isArray(value)) { + // If the value is an array, use the last element + const lastValue = value.at(-1); + if (lastValue !== undefined) { + headers[key] = lastValue; + } + } + // If value is undefined, it's skipped, which is correct. + } } - - const value = req.headers[key]; - headers[key] = Array.isArray(value) ? value[value.length - 1] : value; } - // If the header "x-custom-auth-header" is present, use its value as the custom header name. - if (req.headers["x-custom-auth-header"] !== undefined) { - const customHeaderName = req.headers["x-custom-auth-header"] as string; - const lowerCaseHeaderName = customHeaderName.toLowerCase(); - if (req.headers[lowerCaseHeaderName] !== undefined) { - const value = req.headers[lowerCaseHeaderName]; - headers[customHeaderName] = value as string; + // Handle the custom auth header separately. We expect `x-custom-auth-header` + // to be a string containing the name of the actual authentication header. + const customAuthHeaderName = req.headers["x-custom-auth-header"]; + if (typeof customAuthHeaderName === "string") { + const lowerCaseHeaderName = customAuthHeaderName.toLowerCase(); + const value = req.headers[lowerCaseHeaderName]; + + if (typeof value === "string") { + headers[customAuthHeaderName] = value; + } else if (Array.isArray(value)) { + // If the actual auth header was sent multiple times, use the last value. + const lastValue = value.at(-1); + if (lastValue !== undefined) { + headers[customAuthHeaderName] = lastValue; + } } } + return headers; }; +/** + * Updates a headers object in-place, preserving the original Accept header. + * This is necessary to ensure that transports holding a reference to the headers + * object see the updates. + * @param currentHeaders The headers object to update. + * @param newHeaders The new headers to apply. + */ +const updateHeadersInPlace = ( + currentHeaders: Record, + newHeaders: Record, +) => { + // Preserve the Accept header, which is set at transport creation and + // is not present in subsequent client requests. + const accept = currentHeaders["Accept"]; + + // Clear the old headers and apply the new ones. + Object.keys(currentHeaders).forEach((key) => delete currentHeaders[key]); + Object.assign(currentHeaders, newHeaders); + + // Restore the Accept header. + if (accept) { + currentHeaders["Accept"] = accept; + } +}; + const app = express(); app.use(cors()); app.use((req, res, next) => { @@ -92,6 +127,7 @@ app.use((req, res, next) => { const webAppTransports: Map = new Map(); // Web app transports by web app sessionId const serverTransports: Map = new Map(); // Server Transports by web app sessionId +const sessionHeaderHolders: Map = new Map(); // For dynamic header updates // Use provided token from environment or generate a new one const sessionToken = @@ -174,7 +210,38 @@ const authMiddleware = ( next(); }; -const createTransport = async (req: express.Request): Promise => { +/** + * Creates a `fetch` function that merges dynamic session headers with the + * headers from the actual request, ensuring that request-specific headers like + * `Content-Type` are preserved. + */ +const createCustomFetch = (headerHolder: { headers: HeadersInit }) => { + return (input: RequestInfo | URL, init?: RequestInit): Promise => { + // Determine the headers from the original request/init. + // The SDK may pass a Request object or a URL and an init object. + const originalHeaders = + input instanceof Request ? input.headers : init?.headers; + + // Start with our dynamic session headers. + const finalHeaders = new Headers(headerHolder.headers); + + // Merge the SDK's request-specific headers, letting them overwrite. + // This is crucial for preserving Content-Type on POST requests. + new Headers(originalHeaders).forEach((value, key) => { + finalHeaders.set(key, value); + }); + + // This works for both `fetch(url, init)` and `fetch(request)` style calls. + return fetch(input, { ...init, headers: finalHeaders }); + }; +}; + +const createTransport = async ( + req: express.Request, +): Promise<{ + transport: Transport; + headerHolder?: { headers: HeadersInit }; +}> => { const query = req.query; console.log("Query parameters:", JSON.stringify(query)); @@ -198,11 +265,13 @@ const createTransport = async (req: express.Request): Promise => { }); await transport.start(); - return transport; + return { transport }; } else if (transportType === "sse") { const url = query.url as string; - const headers = getHttpHeaders(req, transportType); + const headers = getHttpHeaders(req); + headers["Accept"] = "text/event-stream"; + const headerHolder = { headers }; console.log( `SSE transport: url=${url}, headers=${JSON.stringify(headers)}`, @@ -210,27 +279,28 @@ const createTransport = async (req: express.Request): Promise => { const transport = new SSEClientTransport(new URL(url), { eventSourceInit: { - fetch: (url, init) => fetch(url, { ...init, headers }), + fetch: createCustomFetch(headerHolder), }, requestInit: { - headers, + headers: headerHolder.headers, }, }); await transport.start(); - return transport; + return { transport, headerHolder }; } else if (transportType === "streamable-http") { - const headers = getHttpHeaders(req, transportType); + const headers = getHttpHeaders(req); + headers["Accept"] = "text/event-stream, application/json"; + const headerHolder = { headers }; const transport = new StreamableHTTPClientTransport( new URL(query.url as string), { - requestInit: { - headers, - }, + // Pass a custom fetch to inject the latest headers on each request + fetch: createCustomFetch(headerHolder), }, ); await transport.start(); - return transport; + return { transport, headerHolder }; } else { console.error(`Invalid transport type: ${transportType}`); throw new Error("Invalid transport type specified"); @@ -244,6 +314,15 @@ app.get( async (req, res) => { const sessionId = req.headers["mcp-session-id"] as string; console.log(`Received GET message for sessionId ${sessionId}`); + + const headerHolder = sessionHeaderHolders.get(sessionId); + if (headerHolder) { + updateHeadersInPlace( + headerHolder.headers as Record, + getHttpHeaders(req), + ); + } + try { const transport = webAppTransports.get( sessionId, @@ -267,34 +346,54 @@ app.post( authMiddleware, async (req, res) => { const sessionId = req.headers["mcp-session-id"] as string | undefined; - let serverTransport: Transport | undefined; - if (!sessionId) { - try { - console.log("New StreamableHttp connection request"); - try { - serverTransport = await createTransport(req); - } catch (error) { - if (error instanceof SseError && error.code === 401) { - console.error( - "Received 401 Unauthorized from MCP server:", - error.message, - ); - res.status(401).json(error); - return; - } - throw error; - } + if (sessionId) { + console.log(`Received POST message for sessionId ${sessionId}`); + const headerHolder = sessionHeaderHolders.get(sessionId); + if (headerHolder) { + updateHeadersInPlace( + headerHolder.headers as Record, + getHttpHeaders(req), + ); + } - console.log("Created StreamableHttp server transport"); + try { + const transport = webAppTransports.get( + sessionId, + ) as StreamableHTTPServerTransport; + if (!transport) { + res.status(404).end("Transport not found for sessionId " + sessionId); + } else { + await (transport as StreamableHTTPServerTransport).handleRequest( + req, + res, + ); + } + } catch (error) { + console.error("Error in /mcp route:", error); + res.status(500).json(error); + } + } else { + console.log("New StreamableHttp connection request"); + try { + const { transport: serverTransport, headerHolder } = + await createTransport(req); const webAppTransport = new StreamableHTTPServerTransport({ sessionIdGenerator: randomUUID, onsessioninitialized: (sessionId) => { webAppTransports.set(sessionId, webAppTransport); - serverTransports.set(sessionId, serverTransport!); + serverTransports.set(sessionId, serverTransport!); // eslint-disable-line @typescript-eslint/no-non-null-assertion + if (headerHolder) { + sessionHeaderHolders.set(sessionId, headerHolder); + } console.log("Client <-> Proxy sessionId: " + sessionId); }, + onsessionclosed: (sessionId) => { + webAppTransports.delete(sessionId); + serverTransports.delete(sessionId); + sessionHeaderHolders.delete(sessionId); + }, }); console.log("Created StreamableHttp client transport"); @@ -311,25 +410,15 @@ app.post( req.body, ); } catch (error) { - console.error("Error in /mcp POST route:", error); - res.status(500).json(error); - } - } else { - console.log(`Received POST message for sessionId ${sessionId}`); - try { - const transport = webAppTransports.get( - sessionId, - ) as StreamableHTTPServerTransport; - if (!transport) { - res.status(404).end("Transport not found for sessionId " + sessionId); - } else { - await (transport as StreamableHTTPServerTransport).handleRequest( - req, - res, + if (error instanceof SseError && error.code === 401) { + console.error( + "Received 401 Unauthorized from MCP server:", + error.message, ); + res.status(401).json(error); + return; } - } catch (error) { - console.error("Error in /mcp route:", error); + console.error("Error in /mcp POST route:", error); res.status(500).json(error); } } @@ -343,20 +432,18 @@ app.delete( async (req, res) => { const sessionId = req.headers["mcp-session-id"] as string | undefined; console.log(`Received DELETE message for sessionId ${sessionId}`); - let serverTransport: Transport | undefined; if (sessionId) { try { - serverTransport = serverTransports.get( + const serverTransport = serverTransports.get( sessionId, ) as StreamableHTTPClientTransport; if (!serverTransport) { res.status(404).end("Transport not found for sessionId " + sessionId); } else { - await ( - serverTransport as StreamableHTTPClientTransport - ).terminateSession(); + await serverTransport.terminateSession(); webAppTransports.delete(sessionId); serverTransports.delete(sessionId); + sessionHeaderHolders.delete(sessionId); console.log(`Transports removed for sessionId ${sessionId}`); } res.status(200).end(); @@ -375,20 +462,7 @@ app.get( async (req, res) => { try { console.log("New STDIO connection request"); - let serverTransport: Transport | undefined; - try { - serverTransport = await createTransport(req); - } catch (error) { - if (error instanceof SseError && error.code === 401) { - console.error( - "Received 401 Unauthorized from MCP server. Authentication failure.", - ); - res.status(401).json(error); - return; - } - - throw error; - } + const { transport: serverTransport } = await createTransport(req); const proxyFullAddress = (req.query.proxyFullAddress as string) || ""; const prefix = proxyFullAddress || ""; @@ -422,6 +496,7 @@ app.get( serverTransport.close(); webAppTransports.delete(webAppTransport.sessionId); serverTransports.delete(webAppTransport.sessionId); + sessionHeaderHolders.delete(webAppTransport.sessionId); console.error(message); } else { // Inspect message and attempt to assign a RFC 5424 Syslog Protocol level @@ -475,6 +550,13 @@ app.get( transportToServer: serverTransport, }); } catch (error) { + if (error instanceof SseError && error.code === 401) { + console.error( + "Received 401 Unauthorized from MCP server. Authentication failure.", + ); + res.status(401).json(error); + return; + } console.error("Error in /stdio route:", error); res.status(500).json(error); } @@ -490,50 +572,46 @@ app.get( console.log( "New SSE connection request. NOTE: The SSE transport is deprecated and has been replaced by StreamableHttp", ); - let serverTransport: Transport | undefined; - try { - serverTransport = await createTransport(req); - } catch (error) { - if (error instanceof SseError && error.code === 401) { - console.error( - "Received 401 Unauthorized from MCP server. Authentication failure.", - ); - res.status(401).json(error); - return; - } else if (error instanceof SseError && error.code === 404) { - console.error( - "Received 404 not found from MCP server. Does the MCP server support SSE?", - ); - res.status(404).json(error); - return; - } else if (JSON.stringify(error).includes("ECONNREFUSED")) { - console.error("Connection refused. Is the MCP server running?"); - res.status(500).json(error); - } else { - throw error; - } - } + const { transport: serverTransport, headerHolder } = + await createTransport(req); - if (serverTransport) { - const proxyFullAddress = (req.query.proxyFullAddress as string) || ""; - const prefix = proxyFullAddress || ""; - const endpoint = `${prefix}/message`; + const proxyFullAddress = (req.query.proxyFullAddress as string) || ""; + const prefix = proxyFullAddress || ""; + const endpoint = `${prefix}/message`; - const webAppTransport = new SSEServerTransport(endpoint, res); - webAppTransports.set(webAppTransport.sessionId, webAppTransport); - console.log("Created client transport"); + const webAppTransport = new SSEServerTransport(endpoint, res); + webAppTransports.set(webAppTransport.sessionId, webAppTransport); + console.log("Created client transport"); - serverTransports.set(webAppTransport.sessionId, serverTransport!); - console.log("Created server transport"); + serverTransports.set(webAppTransport.sessionId, serverTransport!); // eslint-disable-line @typescript-eslint/no-non-null-assertion + if (headerHolder) { + sessionHeaderHolders.set(webAppTransport.sessionId, headerHolder); + } + console.log("Created server transport"); - await webAppTransport.start(); + await webAppTransport.start(); - mcpProxy({ - transportToClient: webAppTransport, - transportToServer: serverTransport, - }); - } + mcpProxy({ + transportToClient: webAppTransport, + transportToServer: serverTransport, + }); } catch (error) { + if (error instanceof SseError && error.code === 401) { + console.error( + "Received 401 Unauthorized from MCP server. Authentication failure.", + ); + res.status(401).json(error); + return; + } else if (error instanceof SseError && error.code === 404) { + console.error( + "Received 404 not found from MCP server. Does the MCP server support SSE?", + ); + res.status(404).json(error); + return; + } else if (JSON.stringify(error).includes("ECONNREFUSED")) { + console.error("Connection refused. Is the MCP server running?"); + res.status(500).json(error); + } console.error("Error in /sse route:", error); res.status(500).json(error); } @@ -546,12 +624,18 @@ app.post( authMiddleware, async (req, res) => { try { - const sessionId = req.query.sessionId; + const sessionId = req.query.sessionId as string; console.log(`Received POST message for sessionId ${sessionId}`); - const transport = webAppTransports.get( - sessionId as string, - ) as SSEServerTransport; + const headerHolder = sessionHeaderHolders.get(sessionId); + if (headerHolder) { + updateHeadersInPlace( + headerHolder.headers as Record, + getHttpHeaders(req), + ); + } + + const transport = webAppTransports.get(sessionId) as SSEServerTransport; if (!transport) { res.status(404).end("Session not found"); return;