diff --git a/router-tests/mcp_test.go b/router-tests/mcp_test.go index 83ef5a8ae5..ffe2ddf082 100644 --- a/router-tests/mcp_test.go +++ b/router-tests/mcp_test.go @@ -5,12 +5,14 @@ import ( "fmt" "net/http" "strings" + "sync" "testing" "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" "github.com/wundergraph/cosmo/router/pkg/config" ) @@ -212,7 +214,6 @@ func TestMCP(t *testing.T) { t.Run("Execute Query", func(t *testing.T) { t.Run("Execute operation of type query with valid input", func(t *testing.T) { testenv.Run(t, &testenv.Config{ - EnableNats: true, MCP: config.MCPConfiguration{ Enabled: true, }, @@ -553,4 +554,124 @@ func TestMCP(t *testing.T) { }) }) }) + + t.Run("Header Forwarding", func(t *testing.T) { + t.Run("All request headers are forwarded from MCP client through to subgraphs", func(t *testing.T) { + // This test validates that ALL headers sent by MCP clients are forwarded + // through the complete chain: MCP Client -> MCP Server -> Router -> Subgraphs + // + // The router's header forwarding rules (configured with wildcard `.*`) determine + // what gets propagated to subgraphs. The MCP server acts as a transparent proxy, + // forwarding all headers without filtering. + // + // Note: We use direct HTTP POST requests instead of the mcp-go client library + // because transport.WithHTTPHeaders() in mcp-go sets headers at the SSE connection + // level, not on individual tool execution requests. Direct HTTP requests allow us + // to test per-request headers, which is what real MCP clients (like Claude Desktop) send. + + var capturedSubgraphRequest *http.Request + var subgraphMutex sync.Mutex + + testenv.Run(t, &testenv.Config{ + MCP: config.MCPConfiguration{ + Enabled: true, + Session: config.MCPSessionConfig{ + Stateless: true, // Enable stateless mode so we don't need session IDs + }, + }, + RouterOptions: []core.Option{ + // Forward all headers including custom ones + core.WithHeaderRules(config.HeaderRules{ + All: &config.GlobalHeaderRule{ + Request: []*config.RequestHeaderRule{ + { + Operation: config.HeaderRuleOperationPropagate, + Matching: ".*", // Forward all headers + }, + }, + }, + }), + }, + Subgraphs: testenv.SubgraphsConfig{ + GlobalMiddleware: func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + subgraphMutex.Lock() + capturedSubgraphRequest = r.Clone(r.Context()) + subgraphMutex.Unlock() + handler.ServeHTTP(w, r) + }) + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // With stateless mode enabled, we can make direct HTTP POST requests + // without needing to establish a session first + mcpAddr := xEnv.GetMCPServerAddr() + + // Make a direct HTTP POST request with custom headers + // This simulates a real MCP client sending custom headers on tool calls + mcpRequest := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "execute_operation_my_employees", + "arguments": map[string]interface{}{ + "criteria": map[string]interface{}{}, + }, + }, + } + + requestBody, err := json.Marshal(mcpRequest) + require.NoError(t, err) + + req, err := http.NewRequest("POST", mcpAddr, strings.NewReader(string(requestBody))) + require.NoError(t, err) + + // Add various headers to test forwarding + req.Header.Set("Content-Type", "application/json") + req.Header.Set("foo", "bar") // Non-standard header + req.Header.Set("X-Custom-Header", "custom-value") // Custom X- header + req.Header.Set("X-Trace-Id", "trace-123") // Tracing header + req.Header.Set("Authorization", "Bearer test-token") // Auth header + + // Make the request + resp, err := xEnv.RouterClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // With stateless mode, the request should succeed + t.Logf("Response Status: %d", resp.StatusCode) + require.Equal(t, http.StatusOK, resp.StatusCode, "Request should succeed in stateless mode") + + // Verify headers reached subgraph + subgraphMutex.Lock() + defer subgraphMutex.Unlock() + + require.NotNil(t, capturedSubgraphRequest, "Subgraph should have received a request") + + // Log all headers that the subgraph received + t.Logf("Headers received by subgraph:") + for key, values := range capturedSubgraphRequest.Header { + for _, value := range values { + t.Logf(" %s: %s", key, value) + } + } + + // Verify that all headers were forwarded through the entire chain: + // MCP Client -> MCP Server -> Router -> Subgraph + assert.Equal(t, "bar", capturedSubgraphRequest.Header.Get("Foo"), + "'foo' header should be forwarded to subgraph") + assert.Equal(t, "custom-value", capturedSubgraphRequest.Header.Get("X-Custom-Header"), + "X-Custom-Header should be forwarded to subgraph") + assert.Equal(t, "trace-123", capturedSubgraphRequest.Header.Get("X-Trace-Id"), + "X-Trace-Id should be forwarded to subgraph") + assert.Equal(t, "Bearer test-token", capturedSubgraphRequest.Header.Get("Authorization"), + "Authorization header should be forwarded to subgraph") + + // This test proves that ALL headers sent by MCP clients are forwarded + // through the complete chain. The router's header rules determine what + // ultimately reaches the subgraphs. + }) + }) + }) } diff --git a/router/pkg/mcpserver/server.go b/router/pkg/mcpserver/server.go index c5cf5b6399..2140966de8 100644 --- a/router/pkg/mcpserver/server.go +++ b/router/pkg/mcpserver/server.go @@ -22,27 +22,28 @@ import ( "go.uber.org/zap" ) -// authKey is a custom context key for storing the auth token. -type authKey struct{} +// requestHeadersKey is a custom context key for storing request headers. +type requestHeadersKey struct{} -// withAuthKey adds an auth key to the context. -func withAuthKey(ctx context.Context, auth string) context.Context { - return context.WithValue(ctx, authKey{}, auth) +// withRequestHeaders adds request headers to the context. +func withRequestHeaders(ctx context.Context, headers http.Header) context.Context { + return context.WithValue(ctx, requestHeadersKey{}, headers) } -// authFromRequest extracts the auth token from the request headers. -func authFromRequest(ctx context.Context, r *http.Request) context.Context { - return withAuthKey(ctx, r.Header.Get("Authorization")) +// requestHeadersFromRequest extracts all headers from the request and stores them in context. +func requestHeadersFromRequest(ctx context.Context, r *http.Request) context.Context { + // Clone the headers to avoid any mutation issues + headers := r.Header.Clone() + return withRequestHeaders(ctx, headers) } -// tokenFromContext extracts the auth token from the context. -// This can be used by clients to pass the auth token to the server. -func tokenFromContext(ctx context.Context) (string, error) { - auth, ok := ctx.Value(authKey{}).(string) +// headersFromContext extracts the request headers from the context. +func headersFromContext(ctx context.Context) (http.Header, error) { + headers, ok := ctx.Value(requestHeadersKey{}).(http.Header) if !ok { - return "", fmt.Errorf("missing auth") + return nil, fmt.Errorf("missing request headers") } - return auth, nil + return headers, nil } // Options represents configuration options for the GraphQLSchemaServer @@ -223,6 +224,11 @@ func NewGraphQLSchemaServer(routerGraphQLEndpoint string, opts ...func(*Options) return gs, nil } +// SetHTTPClient allows setting a custom HTTP client (useful for testing) +func (s *GraphQLSchemaServer) SetHTTPClient(client *http.Client) { + s.httpClient = client +} + // WithGraphName sets the graph name func WithGraphName(graphName string) func(*Options) { return func(o *Options) { @@ -299,7 +305,7 @@ func (s *GraphQLSchemaServer) Serve() (*server.StreamableHTTPServer, error) { server.WithStreamableHTTPServer(httpServer), server.WithLogger(NewZapAdapter(s.logger.With(zap.String("component", "mcp-server")))), server.WithStateLess(s.stateless), - server.WithHTTPContextFunc(authFromRequest), + server.WithHTTPContextFunc(requestHeadersFromRequest), server.WithHeartbeatInterval(10*time.Second), ) @@ -672,17 +678,23 @@ func (s *GraphQLSchemaServer) executeGraphQLQuery(ctx context.Context, query str return nil, fmt.Errorf("failed to create request: %w", err) } - req.Header.Set("Accept", "application/json") - req.Header.Set("Content-Type", "application/json; charset=utf-8") - - token, err := tokenFromContext(ctx) + // Forward all headers from the original MCP request to the GraphQL server + // The router's header forwarding rules will then determine what gets sent to subgraphs + headers, err := headersFromContext(ctx) if err != nil { - s.logger.Debug("failed to get token from context", zap.Error(err)) - } else if token != "" { - req.Header.Set("Authorization", token) + s.logger.Debug("failed to get headers from context", zap.Error(err)) + } else { + // Copy all headers from the MCP request + for key, values := range headers { + for _, value := range values { + req.Header.Add(key, value) + } + } } - // Forward Authorization header if provided + // Override specific headers that must be set for GraphQL requests + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/json; charset=utf-8") resp, err := s.httpClient.Do(req) if err != nil {