Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions router-tests/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -874,5 +874,112 @@ input UserInput {
// ultimately reaches the subgraphs.
})
})

t.Run("Hop-by-hop and filtered headers are not forwarded", func(t *testing.T) {
var capturedSubgraphRequest *http.Request
var subgraphMutex sync.Mutex

testenv.Run(t, &testenv.Config{
MCP: config.MCPConfiguration{
Enabled: true,
Session: config.MCPSessionConfig{
Stateless: true,
},
},
RouterOptions: []core.Option{
// Forward all headers
core.WithHeaderRules(config.HeaderRules{
All: &config.GlobalHeaderRule{
Request: []*config.RequestHeaderRule{
{
Operation: config.HeaderRuleOperationPropagate,
Matching: ".*",
},
},
},
}),
},
Subgraphs: testenv.SubgraphsConfig{
GlobalMiddleware: func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
subgraphMutex.Lock()
capturedSubgraphRequest = r.Clone(r.Context())
subgraphMutex.Unlock()
handler.ServeHTTP(w, r)
})
},
},
}, func(t *testing.T, xEnv *testenv.Environment) {
mcpAddr := xEnv.GetMCPServerAddr()

mcpRequest := map[string]interface{}{
"jsonrpc": "2.0",
"id": 1,
"method": "tools/call",
"params": map[string]interface{}{
"name": "execute_operation_my_employees",
"arguments": map[string]interface{}{
"criteria": map[string]interface{}{},
},
},
}

requestBody, err := json.Marshal(mcpRequest)
require.NoError(t, err)

req, err := http.NewRequest("POST", mcpAddr, strings.NewReader(string(requestBody)))
require.NoError(t, err)

// Set headers that should be filtered
req.Header.Set("Proxy-Authenticate", "Basic")
req.Header.Set("Proxy-Authorization", "Basic YWxhZGRpbjpvcGVuc2VzYW1l")
req.Header.Set("Content-Type", "application/json; foo=bar") // Custom param that should be stripped
req.Header.Set("Accept", "application/json")
req.Header.Set("Accept-Encoding", "br") // Request brotli (which go client doesn't support by default)
req.Header.Set("Alt-Svc", "h2=\":443\"; ma=2592000")
req.Header.Set("Proxy-Connection", "keep-alive")

// Set a header that SHOULD be forwarded for control
req.Header.Set("X-Allowed-Header", "allowed")

resp, err := xEnv.RouterClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
t.Logf("Response Status: %d", resp.StatusCode)
}
require.Equal(t, http.StatusOK, resp.StatusCode)

subgraphMutex.Lock()
defer subgraphMutex.Unlock()

require.NotNil(t, capturedSubgraphRequest)

// Check control header
assert.Equal(t, "allowed", capturedSubgraphRequest.Header.Get("X-Allowed-Header"))

// Check filtered headers
assert.NotEqual(t, "Basic", capturedSubgraphRequest.Header.Get("Proxy-Authenticate"))
assert.NotEqual(t, "Basic YWxhZGRpbjpvcGVuc2VzYW1l", capturedSubgraphRequest.Header.Get("Proxy-Authorization"))

// Content-Type should be set by MCP server to application/json (and stripped of custom params)
ct := capturedSubgraphRequest.Header.Get("Content-Type")
assert.True(t, strings.HasPrefix(ct, "application/json"), "Content-Type should start with application/json")
assert.False(t, strings.Contains(ct, "foo=bar"), "Content-Type should not contain forwarded parameters")

// Accept should be set by MCP server
assert.Equal(t, "application/json", capturedSubgraphRequest.Header.Get("Accept"))

// Accept-Encoding should be set by the Go HTTP client (gzip), not what we sent (br)
ae := capturedSubgraphRequest.Header.Get("Accept-Encoding")
assert.Contains(t, ae, "gzip", "Accept-Encoding should contain gzip (set by Go client)")
assert.NotContains(t, ae, "br", "Accept-Encoding should not contain br (filtered from request)")

// Other headers should be missing
assert.Empty(t, capturedSubgraphRequest.Header.Get("Alt-Svc"))
assert.Empty(t, capturedSubgraphRequest.Header.Get("Proxy-Connection"))
})
})
})
}
27 changes: 27 additions & 0 deletions router/pkg/mcpserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,29 @@ func requestHeadersFromRequest(ctx context.Context, r *http.Request) context.Con
return withRequestHeaders(ctx, headers)
}

var skippedHeaders = map[string]struct{}{
"Connection": {},
"Keep-Alive": {},
"Proxy-Authenticate": {},
"Proxy-Authorization": {},
"Te": {},
"Trailer": {},
"Transfer-Encoding": {},
"Upgrade": {},
"Host": {},
"Content-Length": {},
"Content-Type": {},
"Accept": {},
"Accept-Encoding": {},
"Accept-Charset": {},
"Alt-Svc": {},
"Proxy-Connection": {},
"Sec-Websocket-Extensions": {},
"Sec-Websocket-Key": {},
"Sec-Websocket-Protocol": {},
"Sec-Websocket-Version": {},
}

// headersFromContext extracts the request headers from the context.
func headersFromContext(ctx context.Context) (http.Header, error) {
headers, ok := ctx.Value(requestHeadersKey{}).(http.Header)
Expand Down Expand Up @@ -686,6 +709,10 @@ func (s *GraphQLSchemaServer) executeGraphQLQuery(ctx context.Context, query str
} else {
// Copy all headers from the MCP request
for key, values := range headers {
// Skip headers that should not be forwarded
if _, ok := skippedHeaders[key]; ok {
continue
}
for _, value := range values {
req.Header.Add(key, value)
}
Expand Down
Loading