diff --git a/router-tests/mcp_test.go b/router-tests/mcp_test.go index 0f21bd16af..be7cc60521 100644 --- a/router-tests/mcp_test.go +++ b/router-tests/mcp_test.go @@ -874,5 +874,112 @@ input UserInput { // ultimately reaches the subgraphs. }) }) + + t.Run("Hop-by-hop and filtered headers are not forwarded", func(t *testing.T) { + var capturedSubgraphRequest *http.Request + var subgraphMutex sync.Mutex + + testenv.Run(t, &testenv.Config{ + MCP: config.MCPConfiguration{ + Enabled: true, + Session: config.MCPSessionConfig{ + Stateless: true, + }, + }, + RouterOptions: []core.Option{ + // Forward all headers + core.WithHeaderRules(config.HeaderRules{ + All: &config.GlobalHeaderRule{ + Request: []*config.RequestHeaderRule{ + { + Operation: config.HeaderRuleOperationPropagate, + Matching: ".*", + }, + }, + }, + }), + }, + Subgraphs: testenv.SubgraphsConfig{ + GlobalMiddleware: func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + subgraphMutex.Lock() + capturedSubgraphRequest = r.Clone(r.Context()) + subgraphMutex.Unlock() + handler.ServeHTTP(w, r) + }) + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + mcpAddr := xEnv.GetMCPServerAddr() + + mcpRequest := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "execute_operation_my_employees", + "arguments": map[string]interface{}{ + "criteria": map[string]interface{}{}, + }, + }, + } + + requestBody, err := json.Marshal(mcpRequest) + require.NoError(t, err) + + req, err := http.NewRequest("POST", mcpAddr, strings.NewReader(string(requestBody))) + require.NoError(t, err) + + // Set headers that should be filtered + req.Header.Set("Proxy-Authenticate", "Basic") + req.Header.Set("Proxy-Authorization", "Basic YWxhZGRpbjpvcGVuc2VzYW1l") + req.Header.Set("Content-Type", "application/json; foo=bar") // Custom param that should be stripped + req.Header.Set("Accept", "application/json") + req.Header.Set("Accept-Encoding", "br") // Request brotli (which go client doesn't support by default) + req.Header.Set("Alt-Svc", "h2=\":443\"; ma=2592000") + req.Header.Set("Proxy-Connection", "keep-alive") + + // Set a header that SHOULD be forwarded for control + req.Header.Set("X-Allowed-Header", "allowed") + + resp, err := xEnv.RouterClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Logf("Response Status: %d", resp.StatusCode) + } + require.Equal(t, http.StatusOK, resp.StatusCode) + + subgraphMutex.Lock() + defer subgraphMutex.Unlock() + + require.NotNil(t, capturedSubgraphRequest) + + // Check control header + assert.Equal(t, "allowed", capturedSubgraphRequest.Header.Get("X-Allowed-Header")) + + // Check filtered headers + assert.NotEqual(t, "Basic", capturedSubgraphRequest.Header.Get("Proxy-Authenticate")) + assert.NotEqual(t, "Basic YWxhZGRpbjpvcGVuc2VzYW1l", capturedSubgraphRequest.Header.Get("Proxy-Authorization")) + + // Content-Type should be set by MCP server to application/json (and stripped of custom params) + ct := capturedSubgraphRequest.Header.Get("Content-Type") + assert.True(t, strings.HasPrefix(ct, "application/json"), "Content-Type should start with application/json") + assert.False(t, strings.Contains(ct, "foo=bar"), "Content-Type should not contain forwarded parameters") + + // Accept should be set by MCP server + assert.Equal(t, "application/json", capturedSubgraphRequest.Header.Get("Accept")) + + // Accept-Encoding should be set by the Go HTTP client (gzip), not what we sent (br) + ae := capturedSubgraphRequest.Header.Get("Accept-Encoding") + assert.Contains(t, ae, "gzip", "Accept-Encoding should contain gzip (set by Go client)") + assert.NotContains(t, ae, "br", "Accept-Encoding should not contain br (filtered from request)") + + // Other headers should be missing + assert.Empty(t, capturedSubgraphRequest.Header.Get("Alt-Svc")) + assert.Empty(t, capturedSubgraphRequest.Header.Get("Proxy-Connection")) + }) + }) }) } diff --git a/router/pkg/mcpserver/server.go b/router/pkg/mcpserver/server.go index fc5812869e..d1f8802b3c 100644 --- a/router/pkg/mcpserver/server.go +++ b/router/pkg/mcpserver/server.go @@ -37,6 +37,29 @@ func requestHeadersFromRequest(ctx context.Context, r *http.Request) context.Con return withRequestHeaders(ctx, headers) } +var skippedHeaders = map[string]struct{}{ + "Connection": {}, + "Keep-Alive": {}, + "Proxy-Authenticate": {}, + "Proxy-Authorization": {}, + "Te": {}, + "Trailer": {}, + "Transfer-Encoding": {}, + "Upgrade": {}, + "Host": {}, + "Content-Length": {}, + "Content-Type": {}, + "Accept": {}, + "Accept-Encoding": {}, + "Accept-Charset": {}, + "Alt-Svc": {}, + "Proxy-Connection": {}, + "Sec-Websocket-Extensions": {}, + "Sec-Websocket-Key": {}, + "Sec-Websocket-Protocol": {}, + "Sec-Websocket-Version": {}, +} + // headersFromContext extracts the request headers from the context. func headersFromContext(ctx context.Context) (http.Header, error) { headers, ok := ctx.Value(requestHeadersKey{}).(http.Header) @@ -686,6 +709,10 @@ func (s *GraphQLSchemaServer) executeGraphQLQuery(ctx context.Context, query str } else { // Copy all headers from the MCP request for key, values := range headers { + // Skip headers that should not be forwarded + if _, ok := skippedHeaders[key]; ok { + continue + } for _, value := range values { req.Header.Add(key, value) }