diff --git a/router-tests/header_propagation_test.go b/router-tests/header_propagation_test.go index c44000d31c..fa60030ac3 100644 --- a/router-tests/header_propagation_test.go +++ b/router-tests/header_propagation_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "go.uber.org/zap/zapcore" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/wundergraph/cosmo/router-tests/testenv" @@ -1156,4 +1158,359 @@ func TestHeaderPropagation(t *testing.T) { }) }) }) + + t.Run("Router Response Header Rules", func(t *testing.T) { + t.Parallel() + + t.Run("should set router response headers from static expressions", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithHeaderRules(config.HeaderRules{ + Router: config.RouterHeaderRules{ + Response: []*config.RouterResponseHeaderRule{ + { + Name: "X-Static-Header", + Expression: `"static-value"`, + }, + { + Name: "X-Another-Header", + Expression: `"another-value"`, + }, + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithNoHobby, + }) + require.Equal(t, "static-value", res.Response.Header.Get("X-Static-Header")) + require.Equal(t, "another-value", res.Response.Header.Get("X-Another-Header")) + }) + }) + + t.Run("should set router response headers from request headers", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithHeaderRules(config.HeaderRules{ + Router: config.RouterHeaderRules{ + Response: []*config.RouterResponseHeaderRule{ + { + Name: "X-Echo-Header", + Expression: `request.header.Get("X-Custom-Input")`, + }, + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: queryEmployeeWithNoHobby, + Header: map[string][]string{ + "X-Custom-Input": {"input-value"}, + }, + }) + require.NoError(t, err) + require.Equal(t, "input-value", res.Response.Header.Get("X-Echo-Header")) + }) + }) + + t.Run("should work alongside response header propagation", func(t *testing.T) { + t.Parallel() + + t.Run("when there is a separate header", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithHeaderRules(config.HeaderRules{ + All: &config.GlobalHeaderRule{ + Response: []*config.ResponseHeaderRule{ + { + Operation: config.HeaderRuleOperationPropagate, + Named: "X-Custom-Header", + Algorithm: config.ResponseHeaderRuleAlgorithmFirstWrite, + }, + }, + }, + Router: config.RouterHeaderRules{ + Response: []*config.RouterResponseHeaderRule{ + { + Name: "X-Client-Header", + Expression: `"client-value"`, + }, + }, + }, + }), + }, + Subgraphs: subgraphsPropagateCustomHeader, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + // Check that both router response header and propagated response header are present + require.Equal(t, "client-value", res.Response.Header.Get("X-Client-Header")) + require.Equal(t, employeeVal, res.Response.Header.Get("X-Custom-Header")) + }) + }) + + t.Run("when the same header is in use", func(t *testing.T) { + t.Parallel() + + t.Run("ensure router response header overrides", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithHeaderRules(config.HeaderRules{ + All: &config.GlobalHeaderRule{ + Response: []*config.ResponseHeaderRule{ + { + Operation: config.HeaderRuleOperationPropagate, + Named: "X-Custom-Header", + Algorithm: config.ResponseHeaderRuleAlgorithmFirstWrite, + }, + }, + }, + Router: config.RouterHeaderRules{ + Response: []*config.RouterResponseHeaderRule{ + { + Name: "X-Custom-Header", + Expression: `"client-value"`, + }, + }, + }, + }), + }, + Subgraphs: subgraphsPropagateCustomHeader, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + require.Equal(t, "client-value", res.Response.Header.Get("X-Custom-Header")) + }) + }) + }) + }) + + t.Run("should work alongside request header propagation", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithHeaderRules(config.HeaderRules{ + All: &config.GlobalHeaderRule{ + Request: []*config.RequestHeaderRule{ + { + Operation: config.HeaderRuleOperationPropagate, + Named: "X-Request-Header", + }, + }, + }, + Router: config.RouterHeaderRules{ + Response: []*config.RouterResponseHeaderRule{ + { + Name: "X-Router-Header", + Expression: `request.header.Get("X-Request-Header")`, + }, + }, + }, + }), + }, + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + Middleware: func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify the request header was propagated to the subgraph + require.Equal(t, "request-value", r.Header.Get("X-Request-Header")) + handler.ServeHTTP(w, r) + }) + }, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: queryEmployeeWithNoHobby, + Header: map[string][]string{ + "X-Request-Header": {"request-value"}, + }, + }) + require.NoError(t, err) + require.Equal(t, "request-value", res.Response.Header.Get("X-Router-Header")) + }) + }) + + t.Run("should work alongside both request and response header propagation", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithHeaderRules(config.HeaderRules{ + All: &config.GlobalHeaderRule{ + Request: []*config.RequestHeaderRule{ + { + Operation: config.HeaderRuleOperationPropagate, + Named: "X-Request-Header", + }, + }, + Response: []*config.ResponseHeaderRule{ + { + Operation: config.HeaderRuleOperationPropagate, + Named: "X-Verification", + Algorithm: config.ResponseHeaderRuleAlgorithmFirstWrite, + }, + }, + }, + Router: config.RouterHeaderRules{ + Response: []*config.RouterResponseHeaderRule{ + { + Name: "X-Router-Header", + Expression: `request.header.Get("X-Request-Header")`, + }, + }, + }, + }), + }, + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + Middleware: func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify the request header was propagated to the subgraph + if r.Header.Get("X-Request-Header") == "request-value" { + w.Header().Set("X-Verification", "passed") + } + handler.ServeHTTP(w, r) + }) + }, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: queryEmployeeWithNoHobby, + Header: map[string][]string{ + "X-Request-Header": {"request-value"}, + }, + }) + require.NoError(t, err) + require.Equal(t, "passed", res.Response.Header.Get("X-Verification")) + require.Equal(t, "request-value", res.Response.Header.Get("X-Router-Header")) + }) + }) + + t.Run("should ignore rules that resolve to empty string", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithHeaderRules(config.HeaderRules{ + Router: config.RouterHeaderRules{ + Response: []*config.RouterResponseHeaderRule{ + { + Name: "X-Empty-Header", + Expression: `""`, + }, + { + Name: "X-Missing-Header", + Expression: `request.header.Get("X-Does-Not-Exist")`, + }, + { + Name: "X-Valid-Header", + Expression: `"valid-value"`, + }, + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithNoHobby, + }) + // Empty headers should not be set + require.Equal(t, "", res.Response.Header.Get("X-Empty-Header")) + require.Equal(t, "", res.Response.Header.Get("X-Missing-Header")) + // Valid header should be set + require.Equal(t, "valid-value", res.Response.Header.Get("X-Valid-Header")) + }) + }) + + t.Run("should work with errors in response", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithHeaderRules(config.HeaderRules{ + Router: config.RouterHeaderRules{ + Response: []*config.RouterResponseHeaderRule{ + { + Name: "X-Router-Header", + Expression: `"router-value"`, + }, + { + Name: "X-Error-Header", + Expression: `request.error != nil ? "error" : "success"`, + }, + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `{ employee(id: 1) { id rootFieldThrowsError } }`, + }) + // Router response header should still be set even with errors + require.Equal(t, "router-value", res.Response.Header.Get("X-Router-Header")) + require.Equal(t, "error", res.Response.Header.Get("X-Error-Header")) + }) + }) + + t.Run("should log errors (but not error out) when router response header rule evaluation fails at runtime", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.ErrorLevel, + }, + RouterOptions: []core.Option{ + core.WithHeaderRules(config.HeaderRules{ + Router: config.RouterHeaderRules{ + Response: []*config.RouterResponseHeaderRule{ + { + Name: "X-Valid-Header", + Expression: `"valid-value"`, + }, + { + Name: "X-Invalid-Header", + Expression: `string(int("a"))`, + }, + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithNoHobby, + }) + + require.Equal(t, "valid-value", res.Response.Header.Get("X-Valid-Header")) + + _, headerExists := res.Response.Header["X-Invalid-Header"] + require.False(t, headerExists) + + require.Equal(t, http.StatusOK, res.Response.StatusCode) + require.Contains(t, res.Body, `"data"`) + + logs := xEnv.Observer() + require.NotNil(t, logs) + + errorLogs := logs.FilterMessage("Failed to apply router response header rules").All() + require.Len(t, errorLogs, 1) + + errorLog := errorLogs[0] + require.Equal(t, zapcore.ErrorLevel, errorLog.Level) + require.Equal(t, "Failed to apply router response header rules", errorLog.Message) + require.NotEmpty(t, errorLog.Context) + }) + }) + }) } diff --git a/router/core/graph_server.go b/router/core/graph_server.go index 87aa96331c..d56cec4836 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -1389,6 +1389,7 @@ func (s *graphServer) buildGraphMux( Log: s.logger, EnableCacheResponseHeaders: s.engineExecutionConfiguration.Debug.EnableCacheResponseHeaders, EnableResponseHeaderPropagation: s.headerRules != nil, + HeaderPropagation: s.headerPropagation, EngineStats: s.engineStats, TracerProvider: s.tracerProvider, Authorizer: NewCosmoAuthorizer(authorizerOptions), diff --git a/router/core/graphql_handler.go b/router/core/graphql_handler.go index f5475997de..fba1967894 100644 --- a/router/core/graphql_handler.go +++ b/router/core/graphql_handler.go @@ -79,6 +79,7 @@ type HandlerOptions struct { EnableCacheResponseHeaders bool EnableResponseHeaderPropagation bool + HeaderPropagation *HeaderPropagation ApolloSubscriptionMultipartPrintBoundary bool } @@ -92,6 +93,7 @@ func NewGraphQLHandler(opts HandlerOptions) *GraphQLHandler { executor: opts.Executor, enableCacheResponseHeaders: opts.EnableCacheResponseHeaders, enableResponseHeaderPropagation: opts.EnableResponseHeaderPropagation, + headerPropagation: opts.HeaderPropagation, engineStats: opts.EngineStats, tracer: tracer, authorizer: opts.Authorizer, @@ -128,6 +130,7 @@ type GraphQLHandler struct { enableCacheResponseHeaders bool enableResponseHeaderPropagation bool + headerPropagation *HeaderPropagation apolloSubscriptionMultipartPrintBoundary bool } @@ -190,6 +193,13 @@ func (h *GraphQLHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("Cache-Control", "no-store, no-cache, must-revalidate") } + // Apply any final router response header rules + if h.headerPropagation != nil { + if err := h.headerPropagation.ApplyRouterResponseHeaderRules(w, reqCtx); err != nil { + reqCtx.logger.Error("Failed to apply router response header rules", zap.Error(err)) + } + } + // Write contents of buf to the header propagation writer hpw := HeaderPropagationWriter(w, resolveCtx.Context()) _, err = respBuf.WriteTo(hpw) diff --git a/router/core/header_rule_engine.go b/router/core/header_rule_engine.go index 51af54b39e..faae1c1cb0 100644 --- a/router/core/header_rule_engine.go +++ b/router/core/header_rule_engine.go @@ -121,11 +121,12 @@ func (h *headerPropagationWriter) Write(p []byte) (n int, err error) { // HeaderPropagation is a pre-origin handler that can be used to propagate and // manipulate headers from the client request to the upstream type HeaderPropagation struct { - regex map[string]*regexp.Regexp - rules *config.HeaderRules - compiledRules map[string]*vm.Program - hasRequestRules bool - hasResponseRules bool + regex map[string]*regexp.Regexp + rules *config.HeaderRules + compiledRequestRules map[string]*vm.Program + compiledRouterResponseRules map[string]*vm.Program + hasRequestRules bool + hasResponseRules bool } func initHeaderRules(rules *config.HeaderRules) { @@ -144,12 +145,13 @@ func NewHeaderPropagation(rules *config.HeaderRules) (*HeaderPropagation, error) initHeaderRules(rules) hf := HeaderPropagation{ - rules: rules, - regex: map[string]*regexp.Regexp{}, - compiledRules: map[string]*vm.Program{}, + rules: rules, + regex: map[string]*regexp.Regexp{}, + compiledRequestRules: map[string]*vm.Program{}, + compiledRouterResponseRules: map[string]*vm.Program{}, } - rhrs, rhrrs := hf.getAllRules() + rhrs, rhrrs, rrs := hf.getAllRules() hf.hasRequestRules = len(rhrs) > 0 hf.hasResponseRules = len(rhrrs) > 0 @@ -157,7 +159,7 @@ func NewHeaderPropagation(rules *config.HeaderRules) (*HeaderPropagation, error) return nil, err } - if err := hf.compileExpressionRules(rhrs); err != nil { + if err := hf.compileExpressionRules(rhrs, rrs); err != nil { return nil, err } @@ -199,7 +201,7 @@ func AddCacheControlPolicyToRules(rules *config.HeaderRules, cacheControl config return rules } -func (hf *HeaderPropagation) getAllRules() ([]*config.RequestHeaderRule, []*config.ResponseHeaderRule) { +func (hf *HeaderPropagation) getAllRules() ([]*config.RequestHeaderRule, []*config.ResponseHeaderRule, []*config.RouterResponseHeaderRule) { rhrs := hf.rules.All.Request for _, subgraph := range hf.rules.Subgraphs { rhrs = append(rhrs, subgraph.Request...) @@ -210,7 +212,7 @@ func (hf *HeaderPropagation) getAllRules() ([]*config.RequestHeaderRule, []*conf rhrrs = append(rhrrs, subgraph.Response...) } - return rhrs, rhrrs + return rhrs, rhrrs, hf.rules.Router.Response } func (hf *HeaderPropagation) processRule(rule config.HeaderRule, index int) error { @@ -246,21 +248,35 @@ func (hf *HeaderPropagation) collectRuleMatchers(rhrs []*config.RequestHeaderRul return nil } -func (hf *HeaderPropagation) compileExpressionRules(rules []*config.RequestHeaderRule) error { +func (hf *HeaderPropagation) compileExpressionRules(subgraphRequestRules []*config.RequestHeaderRule, routerRequestRules []*config.RouterResponseHeaderRule) error { manager := expr.CreateNewExprManager() - for _, rule := range rules { - if rule.Expression == "" { - continue + for _, rule := range subgraphRequestRules { + if err := processExpression(rule.Expression, hf.compiledRequestRules, manager); err != nil { + return fmt.Errorf("error compiling header %s: %w", rule.Name, err) } - if _, ok := hf.compiledRules[rule.Expression]; ok { - continue - } - program, err := manager.CompileExpression(rule.Expression, reflect.String) - if err != nil { - return fmt.Errorf("error compiling expression %s for header rule %s: %w", rule.Expression, rule.Name, err) + } + + for _, rule := range routerRequestRules { + if err := processExpression(rule.Expression, hf.compiledRouterResponseRules, manager); err != nil { + return fmt.Errorf("error compiling header %s: %w", rule.Name, err) } - hf.compiledRules[rule.Expression] = program } + + return nil +} + +func processExpression(expression string, hf map[string]*vm.Program, manager *expr.Manager) error { + if expression == "" { + return nil + } + if _, ok := hf[expression]; ok { + return nil + } + program, err := manager.CompileExpression(expression, reflect.String) + if err != nil { + return fmt.Errorf("error compiling expression %s for header rule: %w", expression, err) + } + hf[expression] = program return nil } @@ -602,13 +618,28 @@ func (h *HeaderPropagation) getRequestRuleExpressionValue(rule *config.RequestHe if reqCtx == nil { return "", fmt.Errorf("context cannot be nil") } - program, ok := h.compiledRules[rule.Expression] + program, ok := h.compiledRequestRules[rule.Expression] + if !ok { + return "", fmt.Errorf("expression %s not found in compiled rules for header rule %s", rule.Expression, rule.Name) + } + value, err = expr.ResolveStringExpression(program, reqCtx.expressionContext) + if err != nil { + return "", fmt.Errorf("unable to resolve expression %q for header rule %s: %w", rule.Expression, rule.Name, err) + } + return +} + +func (h *HeaderPropagation) getRouterResponseRuleExpressionValue(rule *config.RouterResponseHeaderRule, reqCtx *requestContext) (value string, err error) { + if reqCtx == nil { + return "", fmt.Errorf("context cannot be nil") + } + program, ok := h.compiledRouterResponseRules[rule.Expression] if !ok { return "", fmt.Errorf("expression %s not found in compiled rules for header rule %s", rule.Expression, rule.Name) } value, err = expr.ResolveStringExpression(program, reqCtx.expressionContext) if err != nil { - return "", fmt.Errorf("unable to resolve expression %q for header rule %s: %s", rule.Expression, rule.Name, err.Error()) + return "", fmt.Errorf("unable to resolve expression %q for header rule %s: %w", rule.Expression, rule.Name, err) } return } @@ -672,6 +703,24 @@ func createMostRestrictivePolicy(policies []*cachedirective.Object) (*cachedirec return &result, cacheControlHeader } +// ApplyRouterResponseHeaderRules applies router response header rules to the response writer +func (h *HeaderPropagation) ApplyRouterResponseHeaderRules(w http.ResponseWriter, reqCtx *requestContext) error { + for _, rule := range h.rules.Router.Response { + if rule.Expression == "" { + continue + } + value, err := h.getRouterResponseRuleExpressionValue(rule, reqCtx) + if err != nil { + return fmt.Errorf("failed to evaluate router response header expression for %s: %w", rule.Name, err) + } + if value != "" { + w.Header().Set(rule.Name, value) + } + } + + return nil +} + // SubgraphRules returns the list of header rules for the subgraph with the given name func SubgraphRules(rules *config.HeaderRules, subgraphName string) []*config.RequestHeaderRule { if rules == nil { diff --git a/router/core/header_rule_engine_test.go b/router/core/header_rule_engine_test.go index 33ebebd68d..e309d79084 100644 --- a/router/core/header_rule_engine_test.go +++ b/router/core/header_rule_engine_test.go @@ -1056,3 +1056,199 @@ func TestExpression(t *testing.T) { assert.Equal(t, "Other-Value", updatedClientReq.Header.Get("X-Test-Header")) }) } + +func TestRouterResponseHeaderRules(t *testing.T) { + t.Run("Should set router response header with static expression", func(t *testing.T) { + ht, err := NewHeaderPropagation(&config.HeaderRules{ + Router: config.RouterHeaderRules{ + Response: []*config.RouterResponseHeaderRule{ + { + Name: "X-Client-Header", + Expression: "\"static-value\"", + }, + }, + }, + }) + assert.NoError(t, err) + assert.NotNil(t, ht) + + rr := httptest.NewRecorder() + + reqCtx := &requestContext{ + logger: zap.NewNop(), + responseWriter: rr, + operation: &operationContext{}, + subgraphResolver: NewSubgraphResolver([]Subgraph{}), + } + clientCtx := withRequestContext(context.Background(), reqCtx) + clientReq, err := http.NewRequestWithContext(clientCtx, "POST", "http://localhost", nil) + require.NoError(t, err) + reqCtx.expressionContext = expr.Context{Request: expr.LoadRequest(clientReq)} + reqCtx.request = clientReq + + err = ht.ApplyRouterResponseHeaderRules(rr, reqCtx) + assert.NoError(t, err) + + assert.Equal(t, "static-value", rr.Header().Get("X-Client-Header")) + }) + + t.Run("Should set router response header with expression from request header", func(t *testing.T) { + ht, err := NewHeaderPropagation(&config.HeaderRules{ + Router: config.RouterHeaderRules{ + Response: []*config.RouterResponseHeaderRule{ + { + Name: "X-Client-ID", + Expression: "request.header.Get(\"X-User-ID\")", + }, + }, + }, + }) + assert.NoError(t, err) + assert.NotNil(t, ht) + + rr := httptest.NewRecorder() + + reqCtx := &requestContext{ + logger: zap.NewNop(), + responseWriter: rr, + operation: &operationContext{}, + subgraphResolver: NewSubgraphResolver([]Subgraph{}), + } + clientCtx := withRequestContext(context.Background(), reqCtx) + clientReq, err := http.NewRequestWithContext(clientCtx, "POST", "http://localhost", nil) + require.NoError(t, err) + clientReq.Header.Set("X-User-ID", "user-123") + reqCtx.expressionContext = expr.Context{Request: expr.LoadRequest(clientReq)} + reqCtx.request = clientReq + + err = ht.ApplyRouterResponseHeaderRules(rr, reqCtx) + assert.NoError(t, err) + + assert.Equal(t, "user-123", rr.Header().Get("X-Client-ID")) + }) + + t.Run("Should set multiple router response headers", func(t *testing.T) { + ht, err := NewHeaderPropagation(&config.HeaderRules{ + Router: config.RouterHeaderRules{ + Response: []*config.RouterResponseHeaderRule{ + { + Name: "X-Client-Header-1", + Expression: "\"value-1\"", + }, + { + Name: "X-Client-Header-2", + Expression: "\"value-2\"", + }, + }, + }, + }) + assert.NoError(t, err) + assert.NotNil(t, ht) + + rr := httptest.NewRecorder() + + reqCtx := &requestContext{ + logger: zap.NewNop(), + responseWriter: rr, + operation: &operationContext{}, + subgraphResolver: NewSubgraphResolver([]Subgraph{}), + } + clientCtx := withRequestContext(context.Background(), reqCtx) + clientReq, err := http.NewRequestWithContext(clientCtx, "POST", "http://localhost", nil) + require.NoError(t, err) + reqCtx.expressionContext = expr.Context{Request: expr.LoadRequest(clientReq)} + reqCtx.request = clientReq + + err = ht.ApplyRouterResponseHeaderRules(rr, reqCtx) + assert.NoError(t, err) + + assert.Equal(t, "value-1", rr.Header().Get("X-Client-Header-1")) + assert.Equal(t, "value-2", rr.Header().Get("X-Client-Header-2")) + }) + + t.Run("Should set router response header with complex expression", func(t *testing.T) { + ht, err := NewHeaderPropagation(&config.HeaderRules{ + Router: config.RouterHeaderRules{ + Response: []*config.RouterResponseHeaderRule{ + { + Name: "X-Combined-Header", + Expression: "request.header.Get(\"X-User-ID\") + \"-\" + request.header.Get(\"X-Session-ID\")", + }, + }, + }, + }) + assert.NoError(t, err) + assert.NotNil(t, ht) + + rr := httptest.NewRecorder() + + reqCtx := &requestContext{ + logger: zap.NewNop(), + responseWriter: rr, + operation: &operationContext{}, + subgraphResolver: NewSubgraphResolver([]Subgraph{}), + } + clientCtx := withRequestContext(context.Background(), reqCtx) + clientReq, err := http.NewRequestWithContext(clientCtx, "POST", "http://localhost", nil) + require.NoError(t, err) + clientReq.Header.Set("X-User-ID", "user-123") + clientReq.Header.Set("X-Session-ID", "session-456") + reqCtx.expressionContext = expr.Context{Request: expr.LoadRequest(clientReq)} + reqCtx.request = clientReq + + err = ht.ApplyRouterResponseHeaderRules(rr, reqCtx) + assert.NoError(t, err) + + assert.Equal(t, "user-123-session-456", rr.Header().Get("X-Combined-Header")) + }) + + t.Run("Should return error when router response header expression is invalid", func(t *testing.T) { + ht, err := NewHeaderPropagation(&config.HeaderRules{ + Router: config.RouterHeaderRules{ + Response: []*config.RouterResponseHeaderRule{ + { + Name: "X-Invalid", + Expression: "invalid expression syntax", + }, + }, + }, + }) + assert.Nil(t, ht) + assert.Error(t, err) + }) + + t.Run("Should ignore router response header rules which resolve to \"\"", func(t *testing.T) { + ht, err := NewHeaderPropagation(&config.HeaderRules{ + Router: config.RouterHeaderRules{ + Response: []*config.RouterResponseHeaderRule{ + { + Name: "X-Client-ID", + Expression: "\"\"", + }, + }, + }, + }) + assert.NoError(t, err) + assert.NotNil(t, ht) + + rr := httptest.NewRecorder() + + reqCtx := &requestContext{ + logger: zap.NewNop(), + responseWriter: rr, + operation: &operationContext{}, + subgraphResolver: NewSubgraphResolver([]Subgraph{}), + } + clientCtx := withRequestContext(context.Background(), reqCtx) + clientReq, err := http.NewRequestWithContext(clientCtx, "POST", "http://localhost", nil) + require.NoError(t, err) + reqCtx.expressionContext = expr.Context{Request: expr.LoadRequest(clientReq)} + reqCtx.request = clientReq + + err = ht.ApplyRouterResponseHeaderRules(rr, reqCtx) + assert.NoError(t, err) + + // Should not set the header since the expression resolves to empty string + assert.Empty(t, rr.Header().Get("X-Client-ID")) + }) +} diff --git a/router/core/router.go b/router/core/router.go index ad4b77cc33..5723a54e3f 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -312,6 +312,7 @@ func NewRouter(opts ...Option) (*Router, error) { if err != nil { return nil, err } + r.headerPropagation = hr if hr.HasRequestRules() { r.preOriginHandlers = append(r.preOriginHandlers, hr.OnOriginRequest) diff --git a/router/core/router_config.go b/router/core/router_config.go index 319216a18a..9aecaa9190 100644 --- a/router/core/router_config.go +++ b/router/core/router_config.go @@ -97,6 +97,7 @@ type Config struct { preOriginHandlers []TransportPreHandler postOriginHandlers []TransportPostHandler headerRules *config.HeaderRules + headerPropagation *HeaderPropagation subgraphTransportOptions *SubgraphTransportOptions subgraphCircuitBreakerOptions *SubgraphCircuitBreakerOptions graphqlMetricsConfig *GraphQLMetricsConfig diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index 8eb71bc5f1..654583c080 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -263,6 +263,12 @@ type HeaderRules struct { All *GlobalHeaderRule `yaml:"all,omitempty"` Subgraphs map[string]*GlobalHeaderRule `yaml:"subgraphs,omitempty"` CookieWhitelist []string `yaml:"cookie_whitelist,omitempty"` + Router RouterHeaderRules `yaml:"router,omitempty"` +} + +type RouterHeaderRules struct { + // All is a set of rules that apply to all response + Response []*RouterResponseHeaderRule `yaml:"response,omitempty"` } type GlobalHeaderRule struct { @@ -329,6 +335,12 @@ const ( ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl ResponseHeaderRuleAlgorithm = "most_restrictive_cache_control" ) +type RouterResponseHeaderRule struct { + // Set header options + Name string `yaml:"name"` + Expression string `yaml:"expression"` +} + type ResponseHeaderRule struct { // Operation describes the header operation to perform e.g. "propagate" Operation HeaderRuleOperation `yaml:"op"` diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index a531fa4af3..6366fec377 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -1665,6 +1665,18 @@ "items": { "type": "string" } + }, + "router": { + "type": "object", + "properties": { + "response": { + "type": "array", + "description": "A list of header rules to apply to router responses.", + "items": { + "$ref": "#/$defs/router_response_header_rule" + } + } + } } } }, @@ -3536,6 +3548,23 @@ }, "required": ["op", "algorithm"] }, + "router_response_header_rule": { + "type": "object", + "description": "The configuration for router response headers. This is used to set headers in response from the router to clients.", + "additionalProperties": false, + "properties": { + "name": { + "type": "string", + "description": "The name of the header to set.", + "examples": ["X-Custom-Header"] + }, + "expression": { + "type": "string", + "description": "The template expression to evaluate for the header value. The expression must return a string value." + } + }, + "required": ["name", "expression"] + }, "set_header_rule": { "type": "object", "description": "The configuration for setting headers. This is used to set specific headers in requests or responses.", diff --git a/router/pkg/config/fixtures/full.yaml b/router/pkg/config/fixtures/full.yaml index 13628a925a..b089b39eac 100644 --- a/router/pkg/config/fixtures/full.yaml +++ b/router/pkg/config/fixtures/full.yaml @@ -268,6 +268,12 @@ headers: cookie_whitelist: - 'cookie1' - 'cookie2' + router: + response: + - name: 'X-Client-Version' + expression: "request.header.Get('User-Agent')" + - name: 'X-Request-ID' + expression: "request.header.Get('X-Request-ID') ?? 'default-request-id'" # Authentication and Authorization # See https://cosmo-docs.wundergraph.com/router/authentication-and-authorization for more information diff --git a/router/pkg/config/testdata/config_defaults.json b/router/pkg/config/testdata/config_defaults.json index b4ddad685e..002aa1f512 100644 --- a/router/pkg/config/testdata/config_defaults.json +++ b/router/pkg/config/testdata/config_defaults.json @@ -145,7 +145,10 @@ "Headers": { "All": null, "Subgraphs": null, - "CookieWhitelist": null + "CookieWhitelist": null, + "Router": { + "Response": null + } }, "TrafficShaping": { "All": { diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index d4707aa1a8..e54c0d50b7 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -310,7 +310,19 @@ "CookieWhitelist": [ "cookie1", "cookie2" - ] + ], + "Router": { + "Response": [ + { + "Name": "X-Client-Version", + "Expression": "request.header.Get('User-Agent')" + }, + { + "Name": "X-Request-ID", + "Expression": "request.header.Get('X-Request-ID') ?? 'default-request-id'" + } + ] + } }, "TrafficShaping": { "All": {