Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 122 additions & 1 deletion router-tests/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ import (
"fmt"
"net/http"
"strings"
"sync"
"testing"

"github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/wundergraph/cosmo/router-tests/testenv"
"github.com/wundergraph/cosmo/router/core"
"github.com/wundergraph/cosmo/router/pkg/config"
)

Expand Down Expand Up @@ -212,7 +214,6 @@ func TestMCP(t *testing.T) {
t.Run("Execute Query", func(t *testing.T) {
t.Run("Execute operation of type query with valid input", func(t *testing.T) {
testenv.Run(t, &testenv.Config{
EnableNats: true,
MCP: config.MCPConfiguration{
Enabled: true,
},
Expand Down Expand Up @@ -553,4 +554,124 @@ func TestMCP(t *testing.T) {
})
})
})

t.Run("Header Forwarding", func(t *testing.T) {
t.Run("All request headers are forwarded from MCP client through to subgraphs", func(t *testing.T) {
// This test validates that ALL headers sent by MCP clients are forwarded
// through the complete chain: MCP Client -> MCP Server -> Router -> Subgraphs
//
// The router's header forwarding rules (configured with wildcard `.*`) determine
// what gets propagated to subgraphs. The MCP server acts as a transparent proxy,
// forwarding all headers without filtering.
//
// Note: We use direct HTTP POST requests instead of the mcp-go client library
// because transport.WithHTTPHeaders() in mcp-go sets headers at the SSE connection
// level, not on individual tool execution requests. Direct HTTP requests allow us
// to test per-request headers, which is what real MCP clients (like Claude Desktop) send.

var capturedSubgraphRequest *http.Request
var subgraphMutex sync.Mutex

testenv.Run(t, &testenv.Config{
MCP: config.MCPConfiguration{
Enabled: true,
Session: config.MCPSessionConfig{
Stateless: true, // Enable stateless mode so we don't need session IDs
},
},
RouterOptions: []core.Option{
// Forward all headers including custom ones
core.WithHeaderRules(config.HeaderRules{
All: &config.GlobalHeaderRule{
Request: []*config.RequestHeaderRule{
{
Operation: config.HeaderRuleOperationPropagate,
Matching: ".*", // Forward all headers
},
},
},
}),
},
Subgraphs: testenv.SubgraphsConfig{
GlobalMiddleware: func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
subgraphMutex.Lock()
capturedSubgraphRequest = r.Clone(r.Context())
subgraphMutex.Unlock()
handler.ServeHTTP(w, r)
})
},
},
}, func(t *testing.T, xEnv *testenv.Environment) {
// With stateless mode enabled, we can make direct HTTP POST requests
// without needing to establish a session first
mcpAddr := xEnv.GetMCPServerAddr()

// Make a direct HTTP POST request with custom headers
// This simulates a real MCP client sending custom headers on tool calls
mcpRequest := map[string]interface{}{
"jsonrpc": "2.0",
"id": 1,
"method": "tools/call",
"params": map[string]interface{}{
"name": "execute_operation_my_employees",
"arguments": map[string]interface{}{
"criteria": map[string]interface{}{},
},
},
}

requestBody, err := json.Marshal(mcpRequest)
require.NoError(t, err)

req, err := http.NewRequest("POST", mcpAddr, strings.NewReader(string(requestBody)))
require.NoError(t, err)

// Add various headers to test forwarding
req.Header.Set("Content-Type", "application/json")
req.Header.Set("foo", "bar") // Non-standard header
req.Header.Set("X-Custom-Header", "custom-value") // Custom X- header
req.Header.Set("X-Trace-Id", "trace-123") // Tracing header
req.Header.Set("Authorization", "Bearer test-token") // Auth header

// Make the request
resp, err := xEnv.RouterClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

// With stateless mode, the request should succeed
t.Logf("Response Status: %d", resp.StatusCode)
require.Equal(t, http.StatusOK, resp.StatusCode, "Request should succeed in stateless mode")

// Verify headers reached subgraph
subgraphMutex.Lock()
defer subgraphMutex.Unlock()

require.NotNil(t, capturedSubgraphRequest, "Subgraph should have received a request")

// Log all headers that the subgraph received
t.Logf("Headers received by subgraph:")
for key, values := range capturedSubgraphRequest.Header {
for _, value := range values {
t.Logf(" %s: %s", key, value)
}
}

// Verify that all headers were forwarded through the entire chain:
// MCP Client -> MCP Server -> Router -> Subgraph
assert.Equal(t, "bar", capturedSubgraphRequest.Header.Get("Foo"),
"'foo' header should be forwarded to subgraph")
assert.Equal(t, "custom-value", capturedSubgraphRequest.Header.Get("X-Custom-Header"),
"X-Custom-Header should be forwarded to subgraph")
assert.Equal(t, "trace-123", capturedSubgraphRequest.Header.Get("X-Trace-Id"),
"X-Trace-Id should be forwarded to subgraph")
assert.Equal(t, "Bearer test-token", capturedSubgraphRequest.Header.Get("Authorization"),
"Authorization header should be forwarded to subgraph")

// This test proves that ALL headers sent by MCP clients are forwarded
// through the complete chain. The router's header rules determine what
// ultimately reaches the subgraphs.
})
})
})
}
58 changes: 35 additions & 23 deletions router/pkg/mcpserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,28 @@ import (
"go.uber.org/zap"
)

// authKey is a custom context key for storing the auth token.
type authKey struct{}
// requestHeadersKey is a custom context key for storing request headers.
type requestHeadersKey struct{}

// withAuthKey adds an auth key to the context.
func withAuthKey(ctx context.Context, auth string) context.Context {
return context.WithValue(ctx, authKey{}, auth)
// withRequestHeaders adds request headers to the context.
func withRequestHeaders(ctx context.Context, headers http.Header) context.Context {
return context.WithValue(ctx, requestHeadersKey{}, headers)
}

// authFromRequest extracts the auth token from the request headers.
func authFromRequest(ctx context.Context, r *http.Request) context.Context {
return withAuthKey(ctx, r.Header.Get("Authorization"))
// requestHeadersFromRequest extracts all headers from the request and stores them in context.
func requestHeadersFromRequest(ctx context.Context, r *http.Request) context.Context {
// Clone the headers to avoid any mutation issues
headers := r.Header.Clone()
return withRequestHeaders(ctx, headers)
}

// tokenFromContext extracts the auth token from the context.
// This can be used by clients to pass the auth token to the server.
func tokenFromContext(ctx context.Context) (string, error) {
auth, ok := ctx.Value(authKey{}).(string)
// headersFromContext extracts the request headers from the context.
func headersFromContext(ctx context.Context) (http.Header, error) {
headers, ok := ctx.Value(requestHeadersKey{}).(http.Header)
if !ok {
return "", fmt.Errorf("missing auth")
return nil, fmt.Errorf("missing request headers")
}
return auth, nil
return headers, nil
}

// Options represents configuration options for the GraphQLSchemaServer
Expand Down Expand Up @@ -223,6 +224,11 @@ func NewGraphQLSchemaServer(routerGraphQLEndpoint string, opts ...func(*Options)
return gs, nil
}

// SetHTTPClient allows setting a custom HTTP client (useful for testing)
func (s *GraphQLSchemaServer) SetHTTPClient(client *http.Client) {
s.httpClient = client
}

// WithGraphName sets the graph name
func WithGraphName(graphName string) func(*Options) {
return func(o *Options) {
Expand Down Expand Up @@ -299,7 +305,7 @@ func (s *GraphQLSchemaServer) Serve() (*server.StreamableHTTPServer, error) {
server.WithStreamableHTTPServer(httpServer),
server.WithLogger(NewZapAdapter(s.logger.With(zap.String("component", "mcp-server")))),
server.WithStateLess(s.stateless),
server.WithHTTPContextFunc(authFromRequest),
server.WithHTTPContextFunc(requestHeadersFromRequest),
server.WithHeartbeatInterval(10*time.Second),
)

Expand Down Expand Up @@ -672,17 +678,23 @@ func (s *GraphQLSchemaServer) executeGraphQLQuery(ctx context.Context, query str
return nil, fmt.Errorf("failed to create request: %w", err)
}

req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/json; charset=utf-8")

token, err := tokenFromContext(ctx)
// Forward all headers from the original MCP request to the GraphQL server
// The router's header forwarding rules will then determine what gets sent to subgraphs
headers, err := headersFromContext(ctx)
if err != nil {
s.logger.Debug("failed to get token from context", zap.Error(err))
} else if token != "" {
req.Header.Set("Authorization", token)
s.logger.Debug("failed to get headers from context", zap.Error(err))
} else {
// Copy all headers from the MCP request
for key, values := range headers {
for _, value := range values {
req.Header.Add(key, value)
}
}
}

// Forward Authorization header if provided
// Override specific headers that must be set for GraphQL requests
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/json; charset=utf-8")

resp, err := s.httpClient.Do(req)
if err != nil {
Expand Down
Loading