diff --git a/router-tests/header_propagation_test.go b/router-tests/header_propagation_test.go new file mode 100644 index 0000000000..3f90d81110 --- /dev/null +++ b/router-tests/header_propagation_test.go @@ -0,0 +1,552 @@ +package integration + +import ( + "net/http" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/config" +) + +func TestHeaderPropagation(t *testing.T) { + t.Parallel() + + const ( + customHeader = "X-Custom-Header" + employeeVal = "employee-value" + hobbyVal = "hobby-value" + ) + + const queryEmployeeWithHobby = `{ + employee(id: 1) { + id + hobbies { + ... on Gaming { + name + } + } + } + }` + + const queryEmployeeWithNoHobby = `{ + employee(id: 1) { + id + } + }` + + getRule := func(alg config.ResponseHeaderRuleAlgorithm, named, defaultVal string) *config.ResponseHeaderRule { + rule := &config.ResponseHeaderRule{ + Operation: config.HeaderRuleOperationPropagate, + Algorithm: alg, + } + if named != "" { + rule.Named = named + } + if defaultVal != "" { + rule.Default = defaultVal + } + return rule + } + + global := func(alg config.ResponseHeaderRuleAlgorithm, named, defaultVal string) []core.Option { + return []core.Option{ + core.WithHeaderRules(config.HeaderRules{ + All: &config.GlobalHeaderRule{ + Response: []*config.ResponseHeaderRule{ + getRule(alg, named, defaultVal), + }, + }, + }), + } + } + + partial := func(alg config.ResponseHeaderRuleAlgorithm, named, defaultVal string) []core.Option { + return []core.Option{ + core.WithHeaderRules(config.HeaderRules{ + Subgraphs: map[string]*config.GlobalHeaderRule{ + "employees": { + Response: []*config.ResponseHeaderRule{ + getRule(alg, named, defaultVal), + }, + }, + }, + }), + } + } + + local := func(alg config.ResponseHeaderRuleAlgorithm, named, defaultValA, defaultValB string) []core.Option { + return []core.Option{ + core.WithHeaderRules(config.HeaderRules{ + Subgraphs: map[string]*config.GlobalHeaderRule{ + "employees": { + Response: []*config.ResponseHeaderRule{ + getRule(alg, named, defaultValA), + }, + }, + "hobbies": { + Response: []*config.ResponseHeaderRule{ + getRule(alg, named, defaultValB), + }, + }, + }, + }), + } + } + + setSubgraphPropagateHeader := func(header, valA, valB string) testenv.SubgraphsConfig { + return testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + Middleware: func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(header, valA) + handler.ServeHTTP(w, r) + }) + }, + }, + Hobbies: testenv.SubgraphConfig{ + Middleware: func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(header, valB) + handler.ServeHTTP(w, r) + }) + }, + }, + } + } + + subgraphsWithExpiresHeader := testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + Middleware: func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expiresTime := time.Now().UTC().Add(10 * time.Minute).Format(http.TimeFormat) + w.Header().Set("Expires", expiresTime) + handler.ServeHTTP(w, r) + }) + }, + }, + Hobbies: testenv.SubgraphConfig{ + Middleware: func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expiresTime := time.Now().UTC().Add(5 * time.Minute).Format(http.TimeFormat) + w.Header().Set("Expires", expiresTime) // Earlier, more restrictive + handler.ServeHTTP(w, r) + }) + }, + }, + } + + cacheOptions := func(cacheControlEmployees, cacheControlHobbies string) testenv.SubgraphsConfig { + return setSubgraphPropagateHeader("Cache-Control", cacheControlEmployees, cacheControlHobbies) + } + + var ( + subgraphsPropagateCustomHeader = setSubgraphPropagateHeader(customHeader, employeeVal, hobbyVal) + ) + + t.Run(" no propagate", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + Subgraphs: subgraphsPropagateCustomHeader, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + ch := strings.Join(res.Response.Header.Values(customHeader), ",") + require.Equal(t, "", ch) + require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body) + }) + }) + + t.Run("LastWriteWins", func(t *testing.T) { + t.Run("global last write wins", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: global(config.ResponseHeaderRuleAlgorithmLastWrite, customHeader, ""), + Subgraphs: subgraphsPropagateCustomHeader, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + ch := strings.Join(res.Response.Header.Values(customHeader), ",") + require.Equal(t, hobbyVal, ch) + require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body) + }) + }) + + t.Run("local last write wins", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: local(config.ResponseHeaderRuleAlgorithmLastWrite, customHeader, "", ""), + Subgraphs: subgraphsPropagateCustomHeader, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + ch := strings.Join(res.Response.Header.Values(customHeader), ",") + require.Equal(t, hobbyVal, ch) + require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body) + }) + }) + + t.Run("partial last write wins", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: partial(config.ResponseHeaderRuleAlgorithmLastWrite, customHeader, ""), + Subgraphs: subgraphsPropagateCustomHeader, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + ch := strings.Join(res.Response.Header.Values(customHeader), ",") + require.Equal(t, employeeVal, ch) + require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body) + }) + }) + }) + + // Test for the First Write Wins Algorithm + t.Run("FirstWriteWins", func(t *testing.T) { + t.Run("global first write wins", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: global(config.ResponseHeaderRuleAlgorithmFirstWrite, customHeader, ""), + Subgraphs: subgraphsPropagateCustomHeader, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + ch := strings.Join(res.Response.Header.Values(customHeader), ",") + require.Equal(t, employeeVal, ch) // First write is "employee-value" + require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body) + }) + }) + + t.Run("local first write wins", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: local(config.ResponseHeaderRuleAlgorithmFirstWrite, customHeader, "", ""), + Subgraphs: subgraphsPropagateCustomHeader, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + ch := strings.Join(res.Response.Header.Values(customHeader), ",") + require.Equal(t, employeeVal, ch) // First write is "employee-value" + require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body) + }) + }) + + t.Run("partial first write wins", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: partial(config.ResponseHeaderRuleAlgorithmFirstWrite, customHeader, ""), + Subgraphs: subgraphsPropagateCustomHeader, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + ch := strings.Join(res.Response.Header.Values(customHeader), ",") + require.Equal(t, employeeVal, ch) // First write is "employee-value" + require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body) + }) + }) + }) + + // Test for the Append Algorithm + t.Run("AppendHeaders", func(t *testing.T) { + t.Run("global append headers", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: global(config.ResponseHeaderRuleAlgorithmAppend, customHeader, ""), + Subgraphs: subgraphsPropagateCustomHeader, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + ch := strings.Join(res.Response.Header.Values(customHeader), ",") + require.Equal(t, "employee-value,hobby-value", ch) // Headers are appended + require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body) + }) + }) + + t.Run("local append headers", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: local(config.ResponseHeaderRuleAlgorithmAppend, customHeader, "", ""), + Subgraphs: subgraphsPropagateCustomHeader, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + ch := strings.Join(res.Response.Header.Values(customHeader), ",") + require.Equal(t, "employee-value,hobby-value", ch) // Headers are appended + require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body) + }) + }) + + t.Run("partial append headers", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: partial(config.ResponseHeaderRuleAlgorithmAppend, customHeader, ""), + Subgraphs: subgraphsPropagateCustomHeader, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + ch := strings.Join(res.Response.Header.Values(customHeader), ",") + require.Equal(t, employeeVal, ch) // Only employee's header is appended + require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body) + }) + }) + }) + + t.Run("MostRestrictiveCacheControl", func(t *testing.T) { + // Global test: All subgraphs' responses are considered and most restrictive cache directive wins + t.Run("global most restrictive cache control", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: global(config.ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl, "", ""), + Subgraphs: cacheOptions("max-age=120", "max-age=60"), + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + cc := res.Response.Header.Get("Cache-Control") + require.Equal(t, "max-age=60", cc) // Most restrictive wins + require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body) + }) + }) + + // Local test: Cache control rules are applied per subgraph (employees and hobbies) + t.Run("local most restrictive cache control", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: local(config.ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl, "", "", ""), + Subgraphs: cacheOptions("max-age=120", "max-age=60"), + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + cc := res.Response.Header.Get("Cache-Control") + require.Equal(t, "max-age=60", cc) // Most restrictive wins + require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body) + }) + }) + + // Partial test: Only one subgraph's response is considered (e.g., employees) + t.Run("partial most restrictive cache control", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: partial(config.ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl, "", ""), + Subgraphs: cacheOptions("max-age=120", "max-age=60"), + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + cc := res.Response.Header.Get("Cache-Control") + require.Equal(t, "max-age=120", cc) // Only employee subgraph is considered + require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body) + }) + }) + + // Test case for no-store being the most restrictive + t.Run("global no-store wins", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: global(config.ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl, "", ""), + Subgraphs: cacheOptions("no-store", "max-age=300"), + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + cc := res.Response.Header.Get("Cache-Control") + require.Equal(t, "no-store", cc) // no-store wins + require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body) + }) + }) + + // Test case for no-cache being more restrictive than max-age + t.Run("global no-cache wins", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: global(config.ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl, "", ""), + Subgraphs: cacheOptions("no-cache", "max-age=300"), + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + cc := res.Response.Header.Get("Cache-Control") + require.Equal(t, "no-cache", cc) // no-cache wins over max-age + require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body) + }) + }) + + t.Run("global no-cache wins against no value", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: global(config.ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl, "", ""), + Subgraphs: cacheOptions("no-cache", ""), + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + cc := res.Response.Header.Get("Cache-Control") + require.Equal(t, "no-cache", cc) // no-cache wins over max-age + require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body) + }) + }) + + // Test case for max-age: shortest max-age wins + t.Run("global shortest max-age wins", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: global(config.ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl, "", ""), + Subgraphs: cacheOptions("max-age=600", "max-age=300"), + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + cc := res.Response.Header.Get("Cache-Control") + require.Equal(t, "max-age=300", cc) // Shorter max-age wins + require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body) + }) + }) + + // Test case for Expires header: earliest expiration wins + t.Run("global earliest Expires wins", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: global(config.ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl, "", ""), + Subgraphs: subgraphsWithExpiresHeader, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + expires := res.Response.Header.Get("Expires") + require.NotEmpty(t, expires) + + // Parse the Expires header and convert both times to UTC for comparison + parsedExpires, err := http.ParseTime(expires) + require.NoError(t, err) + + now := time.Now().Add(5 * time.Minute) // Example expiration + require.WithinDuration(t, now, parsedExpires, 20*time.Second) // Ensure expiration is within expected range + require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body) + }) + }) + + t.Run("default value adds max age", func(t *testing.T) { + t.Parallel() + + t.Run("global default age sets for all requests", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: global(config.ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl, "", "max-age=300"), + Subgraphs: cacheOptions("", "max-age=600"), + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + cc := res.Response.Header.Get("Cache-Control") + require.Equal(t, "max-age=300", cc) // Shorter max-age wins + require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body) + }) + }) + + t.Run("global no-cache sets for all requests", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: global(config.ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl, "", "no-cache"), + Subgraphs: cacheOptions("", "max-age=600"), + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + cc := res.Response.Header.Get("Cache-Control") + require.Equal(t, "no-cache", cc) // Shorter max-age wins + require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body) + }) + }) + + t.Run("global default age sets for all requests", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: global(config.ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl, "", "no-cache"), + Subgraphs: cacheOptions("max-age=60", "max-age=300"), + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + cc := res.Response.Header.Get("Cache-Control") + require.Equal(t, "no-cache", cc) // Shorter max-age wins + require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body) + }) + }) + + t.Run("allows subgraph to override default", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: global(config.ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl, "", "max-age=300"), + Subgraphs: cacheOptions("max-age=60", "max-age=180"), + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + cc := res.Response.Header.Get("Cache-Control") + require.Equal(t, "max-age=60", cc) // Shorter max-age wins + require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body) + }) + }) + + t.Run("partial default age sets for requests with information", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: local(config.ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl, "", "", "max-age=300"), + Subgraphs: cacheOptions("", ""), + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithHobby, + }) + cc := res.Response.Header.Get("Cache-Control") + require.Equal(t, "max-age=300", cc) // Shorter max-age wins + require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body) + }) + }) + + t.Run("partial default age doesn't set for unassociated requests", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: local(config.ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl, "", "", "max-age=300"), + Subgraphs: cacheOptions("", ""), + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryEmployeeWithNoHobby, + }) + cc := res.Response.Header.Get("Cache-Control") + require.Equal(t, "", cc) + require.Equal(t, `{"data":{"employee":{"id":1}}}`, res.Body) + }) + }) + + t.Run("no-cache is set for all mutations", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + RouterOptions: global(config.ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl, "", "max-age=300"), + Subgraphs: cacheOptions("", ""), + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `mutation { updateEmployeeTag(id: 1, tag: "test") { id tag } }`, + }) + cc := res.Response.Header.Get("Cache-Control") + require.Equal(t, "no-cache", cc) + require.Equal(t, http.StatusOK, res.Response.StatusCode) + require.Equal(t, `{"data":{"updateEmployeeTag":{"id":1,"tag":"test"}}}`, res.Body) + }) + }) + }) + }) +} diff --git a/router-tests/headers_test.go b/router-tests/headers_test.go index 1d718acd18..bc58be58ee 100644 --- a/router-tests/headers_test.go +++ b/router-tests/headers_test.go @@ -27,17 +27,17 @@ func TestForwardHeaders(t *testing.T) { ) headerRules := config.HeaderRules{ - All: config.GlobalHeaderRule{ - Request: []config.RequestHeaderRule{ + All: &config.GlobalHeaderRule{ + Request: []*config.RequestHeaderRule{ { Operation: config.HeaderRuleOperationPropagate, Named: headerNameInGlobalRule, }, }, }, - Subgraphs: map[string]config.GlobalHeaderRule{ + Subgraphs: map[string]*config.GlobalHeaderRule{ "test1": { - Request: []config.RequestHeaderRule{ + Request: []*config.RequestHeaderRule{ { Operation: config.HeaderRuleOperationPropagate, Matching: "(?i)^bar.*", @@ -246,8 +246,8 @@ func TestForwardRenamedHeaders(t *testing.T) { ) headerRules := config.HeaderRules{ - All: config.GlobalHeaderRule{ - Request: []config.RequestHeaderRule{ + All: &config.GlobalHeaderRule{ + Request: []*config.RequestHeaderRule{ { Operation: config.HeaderRuleOperationPropagate, Named: headerNameInGlobalRule, @@ -255,9 +255,9 @@ func TestForwardRenamedHeaders(t *testing.T) { }, }, }, - Subgraphs: map[string]config.GlobalHeaderRule{ + Subgraphs: map[string]*config.GlobalHeaderRule{ "test1": { - Request: []config.RequestHeaderRule{ + Request: []*config.RequestHeaderRule{ { Operation: config.HeaderRuleOperationPropagate, Matching: "(?i)^bar.*", diff --git a/router-tests/singleflight_test.go b/router-tests/singleflight_test.go index be398864b6..309a4eab7c 100644 --- a/router-tests/singleflight_test.go +++ b/router-tests/singleflight_test.go @@ -198,8 +198,8 @@ func TestSingleFlightDifferentHeaders(t *testing.T) { }, RouterOptions: []core.Option{ core.WithHeaderRules(config.HeaderRules{ - All: config.GlobalHeaderRule{ - Request: []config.RequestHeaderRule{ + All: &config.GlobalHeaderRule{ + Request: []*config.RequestHeaderRule{ { Named: "Authorization", Operation: config.HeaderRuleOperationPropagate, @@ -245,8 +245,8 @@ func TestSingleFlightSameHeaders(t *testing.T) { }, RouterOptions: []core.Option{ core.WithHeaderRules(config.HeaderRules{ - All: config.GlobalHeaderRule{ - Request: []config.RequestHeaderRule{ + All: &config.GlobalHeaderRule{ + Request: []*config.RequestHeaderRule{ { Named: "Authorization", Operation: config.HeaderRuleOperationPropagate, diff --git a/router-tests/testenv/testenv.go b/router-tests/testenv/testenv.go index 9a4b4c9ea8..f1a8419358 100644 --- a/router-tests/testenv/testenv.go +++ b/router-tests/testenv/testenv.go @@ -10,7 +10,6 @@ import ( "encoding/json" "errors" "fmt" - "go.uber.org/zap/zaptest/observer" "io" "log" "math/rand" @@ -25,6 +24,8 @@ import ( "testing" "time" + "go.uber.org/zap/zaptest/observer" + "github.com/wundergraph/cosmo/router/pkg/controlplane/configpoller" "github.com/nats-io/nats.go/jetstream" diff --git a/router-tests/websocket_test.go b/router-tests/websocket_test.go index 9ea9de80ac..5a8d7bc41f 100644 --- a/router-tests/websocket_test.go +++ b/router-tests/websocket_test.go @@ -516,8 +516,8 @@ func TestWebSockets(t *testing.T) { t.Run("subscription with header propagation", func(t *testing.T) { t.Parallel() headerRules := config.HeaderRules{ - All: config.GlobalHeaderRule{ - Request: []config.RequestHeaderRule{ + All: &config.GlobalHeaderRule{ + Request: []*config.RequestHeaderRule{ { Operation: config.HeaderRuleOperationPropagate, Named: "Authorization", @@ -652,8 +652,8 @@ func TestWebSockets(t *testing.T) { t.Run("empty allow lists should allow all headers and query args", func(t *testing.T) { t.Parallel() headerRules := config.HeaderRules{ - All: config.GlobalHeaderRule{ - Request: []config.RequestHeaderRule{ + All: &config.GlobalHeaderRule{ + Request: []*config.RequestHeaderRule{ { Operation: config.HeaderRuleOperationPropagate, Named: "Authorization", @@ -795,8 +795,8 @@ func TestWebSockets(t *testing.T) { t.Run("subscription with header propagation sse subgraph post", func(t *testing.T) { t.Parallel() headerRules := config.HeaderRules{ - All: config.GlobalHeaderRule{ - Request: []config.RequestHeaderRule{ + All: &config.GlobalHeaderRule{ + Request: []*config.RequestHeaderRule{ { Operation: config.HeaderRuleOperationPropagate, Named: "Authorization", @@ -910,8 +910,8 @@ func TestWebSockets(t *testing.T) { t.Run("subscription with header propagation sse subgraph get", func(t *testing.T) { t.Parallel() headerRules := config.HeaderRules{ - All: config.GlobalHeaderRule{ - Request: []config.RequestHeaderRule{ + All: &config.GlobalHeaderRule{ + Request: []*config.RequestHeaderRule{ { Operation: config.HeaderRuleOperationPropagate, Named: "Authorization", diff --git a/router/core/factoryresolver.go b/router/core/factoryresolver.go index 2c779950c6..6fabfaccd9 100644 --- a/router/core/factoryresolver.go +++ b/router/core/factoryresolver.go @@ -133,7 +133,7 @@ func (l *Loader) LoadInternedString(engineConfig *nodev1.EngineConfiguration, st type RouterEngineConfiguration struct { Execution config.EngineExecutionConfiguration - Headers config.HeaderRules + Headers *config.HeaderRules Events config.EventsConfiguration SubgraphErrorPropagation config.SubgraphErrorPropagationConfiguration } @@ -314,7 +314,7 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod } } - dataSourceRules := FetchURLRules(&routerEngineConfig.Headers, subgraphs, subscriptionUrl) + dataSourceRules := FetchURLRules(routerEngineConfig.Headers, subgraphs, subscriptionUrl) forwardedClientHeaders, forwardedClientRegexps, err := PropagatedHeaders(dataSourceRules) if err != nil { return nil, fmt.Errorf("error parsing header rules for data source %s: %w", in.Id, err) diff --git a/router/core/graph_server.go b/router/core/graph_server.go index 5786cce7a0..62867eabfe 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -611,6 +611,7 @@ func (s *graphServer) buildGraphMux(ctx context.Context, EnableExecutionPlanCacheResponseHeader: s.engineExecutionConfiguration.EnableExecutionPlanCacheResponseHeader, EnablePersistedOperationCacheResponseHeader: s.engineExecutionConfiguration.Debug.EnablePersistedOperationsCacheResponseHeader, EnableNormalizationCacheResponseHeader: s.engineExecutionConfiguration.Debug.EnableNormalizationCacheResponseHeader, + EnableResponseHeaderPropagation: s.headerRules != nil, WebSocketStats: s.websocketStats, TracerProvider: s.tracerProvider, Authorizer: NewCosmoAuthorizer(authorizerOptions), diff --git a/router/core/graphql_handler.go b/router/core/graphql_handler.go index 32b35a0d07..01387a54c7 100644 --- a/router/core/graphql_handler.go +++ b/router/core/graphql_handler.go @@ -80,6 +80,7 @@ type HandlerOptions struct { EnableExecutionPlanCacheResponseHeader bool EnablePersistedOperationCacheResponseHeader bool EnableNormalizationCacheResponseHeader bool + EnableResponseHeaderPropagation bool WebSocketStats WebSocketsStatistics TracerProvider trace.TracerProvider Authorizer *CosmoAuthorizer @@ -96,6 +97,7 @@ func NewGraphQLHandler(opts HandlerOptions) *GraphQLHandler { enableExecutionPlanCacheResponseHeader: opts.EnableExecutionPlanCacheResponseHeader, enablePersistedOperationCacheResponseHeader: opts.EnablePersistedOperationCacheResponseHeader, enableNormalizationCacheResponseHeader: opts.EnableNormalizationCacheResponseHeader, + enableResponseHeaderPropagation: opts.EnableResponseHeaderPropagation, websocketStats: opts.WebSocketStats, tracer: opts.TracerProvider.Tracer( "wundergraph/cosmo/router/graphql_handler", @@ -135,6 +137,7 @@ type GraphQLHandler struct { enableExecutionPlanCacheResponseHeader bool enablePersistedOperationCacheResponseHeader bool enableNormalizationCacheResponseHeader bool + enableResponseHeaderPropagation bool } func (h *GraphQLHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -182,7 +185,10 @@ func (h *GraphQLHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { case *plan.SynchronousResponsePlan: w.Header().Set("Content-Type", "application/json") h.setDebugCacheHeaders(w, operationCtx) - resp, err := h.executor.Resolver.ResolveGraphQLResponse(ctx, p.Response, nil, w) + if h.enableResponseHeaderPropagation { + ctx = WithResponseHeaderPropagation(ctx) + } + resp, err := h.executor.Resolver.ResolveGraphQLResponse(ctx, p.Response, nil, HeaderPropagationWriter(w, ctx.Context())) if err != nil { requestLogger.Error("unable to resolve response", zap.Error(err)) trackResponseError(ctx.Context(), err) diff --git a/router/core/header_rule_engine.go b/router/core/header_rule_engine.go index 67901a83b8..af790c29db 100644 --- a/router/core/header_rule_engine.go +++ b/router/core/header_rule_engine.go @@ -1,18 +1,27 @@ package core import ( + "context" "fmt" + "github.com/wundergraph/cosmo/router/pkg/otel" + rtrace "github.com/wundergraph/cosmo/router/pkg/trace" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "io" "net/http" "regexp" "slices" + "sync" + "time" - "github.com/wundergraph/cosmo/router/pkg/config" - + cachedirective "github.com/pquerna/cachecontrol/cacheobject" nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" + "github.com/wundergraph/cosmo/router/pkg/config" ) var ( - _ EnginePreOriginHandler = (*HeaderRuleEngine)(nil) + _ EnginePreOriginHandler = (*HeaderPropagation)(nil) ignoredHeaders = []string{ "Alt-Svc", "Connection", @@ -37,158 +46,480 @@ var ( } ) -// HeaderRuleEngine is a pre-origin handler that can be used to propagate and +type responseHeaderPropagationKey struct{} + +type responseHeaderPropagation struct { + header http.Header + m *sync.Mutex + previousCacheControl *cachedirective.Object +} + +func WithResponseHeaderPropagation(ctx *resolve.Context) *resolve.Context { + return ctx.WithContext(context.WithValue(ctx.Context(), responseHeaderPropagationKey{}, &responseHeaderPropagation{ + header: make(http.Header), + m: &sync.Mutex{}, + })) +} + +func getResponseHeaderPropagation(ctx context.Context) *responseHeaderPropagation { + v := ctx.Value(responseHeaderPropagationKey{}) + if v == nil { + return nil + } + return v.(*responseHeaderPropagation) +} + +func HeaderPropagationWriter(w http.ResponseWriter, ctx context.Context) io.Writer { + propagation := getResponseHeaderPropagation(ctx) + if propagation == nil { + return w + } + return &headerPropagationWriter{ + writer: w, + headerPropagation: propagation, + propagateHeaders: true, + } +} + +type headerPropagationWriter struct { + writer http.ResponseWriter + headerPropagation *responseHeaderPropagation + propagateHeaders bool +} + +func (h *headerPropagationWriter) Write(p []byte) (n int, err error) { + if h.propagateHeaders { + for k, v := range h.headerPropagation.header { + for _, el := range v { + h.writer.Header().Add(k, el) + } + } + h.propagateHeaders = false + } + return h.writer.Write(p) +} + +// HeaderPropagation is a pre-origin handler that can be used to propagate and // manipulate headers from the client request to the upstream -type HeaderRuleEngine struct { - regex map[string]regexp.Regexp - rules config.HeaderRules +type HeaderPropagation struct { + regex map[string]*regexp.Regexp + rules *config.HeaderRules + hasRequestRules bool + hasResponseRules bool } -func NewHeaderTransformer(rules config.HeaderRules) (*HeaderRuleEngine, error) { - hf := HeaderRuleEngine{ +func NewHeaderPropagation(rules *config.HeaderRules) (*HeaderPropagation, error) { + if rules == nil { + return nil, nil + } + + if rules.All == nil { + rules.All = &config.GlobalHeaderRule{} + } + if rules.Subgraphs == nil { + rules.Subgraphs = make(map[string]*config.GlobalHeaderRule) + } + + hf := HeaderPropagation{ rules: rules, - regex: map[string]regexp.Regexp{}, + regex: map[string]*regexp.Regexp{}, } - var rhrs []config.RequestHeaderRule + rhrs, rhrrs := hf.getAllRules() + hf.hasRequestRules = len(rhrs) > 0 + hf.hasResponseRules = len(rhrrs) > 0 - rhrs = append(rhrs, rules.All.Request...) + if err := hf.collectRuleMatchers(rhrs, rhrrs); err != nil { + return nil, err + } - for _, subgraph := range rules.Subgraphs { + return &hf, nil +} + +func (hf *HeaderPropagation) getAllRules() ([]*config.RequestHeaderRule, []*config.ResponseHeaderRule) { + rhrs := hf.rules.All.Request + for _, subgraph := range hf.rules.Subgraphs { rhrs = append(rhrs, subgraph.Request...) } - for i, rule := range rhrs { - switch rule.Operation { - case config.HeaderRuleOperationPropagate: - if rule.Matching != "" { - regex, err := regexp.Compile(rule.Matching) - if err != nil { - return nil, fmt.Errorf("invalid regex '%s' for header rule %d: %w", rule.Matching, i, err) - } - hf.regex[rule.Matching] = *regex + rhrrs := hf.rules.All.Response + for _, subgraph := range hf.rules.Subgraphs { + rhrrs = append(rhrrs, subgraph.Response...) + } + + return rhrs, rhrrs +} + +func (hf *HeaderPropagation) processRule(rule config.HeaderRule, index int) error { + switch rule.GetOperation() { + case config.HeaderRuleOperationPropagate: + if rule.GetMatching() != "" { + regex, err := regexp.Compile(rule.GetMatching()) + if err != nil { + return fmt.Errorf("invalid regex '%s' for header rule %d: %w", rule.GetMatching(), index, err) } - default: - return nil, fmt.Errorf("unhandled operation '%s' for header rule %+v", rule.Operation, rule) + hf.regex[rule.GetMatching()] = regex + } + default: + return fmt.Errorf("unhandled operation '%s' for header rule %+v", rule.GetOperation(), rule) + } + return nil +} + +func (hf *HeaderPropagation) collectRuleMatchers(rhrs []*config.RequestHeaderRule, rhrrs []*config.ResponseHeaderRule) error { + for i, rule := range rhrs { + if err := hf.processRule(rule, i); err != nil { + return err } } - return &hf, nil + for i, rule := range rhrrs { + if err := hf.processRule(rule, i); err != nil { + return err + } + } + + return nil +} + +func (h *HeaderPropagation) HasRequestRules() bool { + if h == nil { + return false + } + return h.hasRequestRules +} + +func (h *HeaderPropagation) HasResponseRules() bool { + if h == nil { + return false + } + return h.hasResponseRules } -func (h HeaderRuleEngine) OnOriginRequest(request *http.Request, ctx RequestContext) (*http.Request, *http.Response) { - requestRules := h.rules.All.Request +func (h *HeaderPropagation) OnOriginRequest(request *http.Request, ctx RequestContext) (*http.Request, *http.Response) { + for _, rule := range h.rules.All.Request { + h.applyRequestRule(ctx, request, rule) + } subgraph := ctx.ActiveSubgraph(request) if subgraph != nil { if subgraphRules, ok := h.rules.Subgraphs[subgraph.Name]; ok { - requestRules = append(requestRules, subgraphRules.Request...) + for _, rule := range subgraphRules.Request { + h.applyRequestRule(ctx, request, rule) + } } } - for _, rule := range requestRules { - if rule.Operation == config.HeaderRuleOperationPropagate { - - /** - * Rename the header before propagating and delete the original - */ + return request, nil +} - if rule.Rename != "" && rule.Named != "" { - // Ignore the rule when the target header is in the ignored list - if slices.Contains(ignoredHeaders, rule.Rename) { - continue - } +func (h *HeaderPropagation) OnOriginResponse(resp *http.Response, ctx RequestContext) *http.Response { + propagation := getResponseHeaderPropagation(resp.Request.Context()) + if propagation == nil { + return resp + } - value := ctx.Request().Header.Get(rule.Named) - if value != "" { - request.Header.Set(rule.Rename, ctx.Request().Header.Get(rule.Named)) - request.Header.Del(rule.Named) - continue - } else if rule.Default != "" { - request.Header.Set(rule.Rename, rule.Default) - request.Header.Del(rule.Named) - continue - } + for _, rule := range h.rules.All.Response { + h.applyResponseRule(propagation, resp, rule) + } - continue + subgraph := ctx.ActiveSubgraph(resp.Request) + if subgraph != nil { + if subgraphRules, ok := h.rules.Subgraphs[subgraph.Name]; ok { + for _, rule := range subgraphRules.Response { + h.applyResponseRule(propagation, resp, rule) } + } + } - /** - * Propagate the header as is - */ + return resp +} - if rule.Named != "" { - if slices.Contains(ignoredHeaders, rule.Named) { - continue - } +func (h *HeaderPropagation) applyResponseRule(propagation *responseHeaderPropagation, res *http.Response, rule *config.ResponseHeaderRule) { + if rule.Operation != config.HeaderRuleOperationPropagate { + return + } - value := ctx.Request().Header.Get(rule.Named) - if value != "" { - request.Header.Set(rule.Named, ctx.Request().Header.Get(rule.Named)) - } else if rule.Default != "" { - request.Header.Set(rule.Named, rule.Default) - } + if rule.Named != "" { + if slices.Contains(ignoredHeaders, rule.Named) { + return + } + + value := res.Header.Get(rule.Named) + if value != "" { + h.applyResponseRuleKeyValue(res, propagation, rule, rule.Named, value) + } else if rule.Default != "" { + h.applyResponseRuleKeyValue(res, propagation, rule, rule.Named, rule.Default) + } - continue + return + } else if rule.Matching != "" { + if regex, ok := h.regex[rule.Matching]; ok { + for name := range res.Header { + if regex.MatchString(name) { + if slices.Contains(ignoredHeaders, name) { + continue + } + h.applyResponseRuleKeyValue(res, propagation, rule, name, res.Header.Get(name)) + } } + } + } else if rule.Algorithm == config.ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl { + // Explicitly apply the CacheControl algorithm on the headers + h.applyResponseRuleKeyValue(res, propagation, rule, "", "") + } +} + +func (h *HeaderPropagation) applyResponseRuleKeyValue(res *http.Response, propagation *responseHeaderPropagation, rule *config.ResponseHeaderRule, key, value string) { + switch rule.Algorithm { + case config.ResponseHeaderRuleAlgorithmFirstWrite: + propagation.m.Lock() + if val := propagation.header.Get(key); val == "" { + propagation.header.Set(key, value) + } + propagation.m.Unlock() + case config.ResponseHeaderRuleAlgorithmLastWrite: + propagation.m.Lock() + propagation.header.Set(key, value) + propagation.m.Unlock() + case config.ResponseHeaderRuleAlgorithmAppend: + propagation.m.Lock() + propagation.header.Add(key, value) + propagation.m.Unlock() + case config.ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl: + h.applyResponseRuleMostRestrictiveCacheControl(res, propagation, rule) + } +} + +func (h *HeaderPropagation) applyRequestRule(ctx RequestContext, request *http.Request, rule *config.RequestHeaderRule) { + if rule.Operation != config.HeaderRuleOperationPropagate { + return + } + + /** + * Rename the header before propagating and delete the original + */ + + if rule.Rename != "" && rule.Named != "" { + // Ignore the rule when the target header is in the ignored list + if slices.Contains(ignoredHeaders, rule.Rename) { + return + } - /** - * Matching based on regex - */ - - if regex, ok := h.regex[rule.Matching]; ok { - for name := range ctx.Request().Header { - // Headers are case-insensitive, but Go canonicalize them - // Issue: https://github.com/golang/go/issues/37834 - if regex.MatchString(name) { - - /** - * Rename the header before propagating and delete the original - */ - if rule.Rename != "" && rule.Named == "" { - - if slices.Contains(ignoredHeaders, rule.Rename) { - continue - } - - value := ctx.Request().Header.Get(name) - if value != "" { - request.Header.Set(rule.Rename, ctx.Request().Header.Get(name)) - request.Header.Del(name) - } else if rule.Default != "" { - request.Header.Set(rule.Rename, rule.Default) - request.Header.Del(name) - } - - continue - } - - /** - * Propagate the header as is - */ - if slices.Contains(ignoredHeaders, name) { - continue - } - request.Header.Set(name, ctx.Request().Header.Get(name)) + value := ctx.Request().Header.Get(rule.Named) + if value != "" { + request.Header.Set(rule.Rename, ctx.Request().Header.Get(rule.Named)) + request.Header.Del(rule.Named) + return + } else if rule.Default != "" { + request.Header.Set(rule.Rename, rule.Default) + request.Header.Del(rule.Named) + return + } + + return + } + + /** + * Propagate the header as is + */ + + if rule.Named != "" { + if slices.Contains(ignoredHeaders, rule.Named) { + return + } + + value := ctx.Request().Header.Get(rule.Named) + if value != "" { + request.Header.Set(rule.Named, ctx.Request().Header.Get(rule.Named)) + } else if rule.Default != "" { + request.Header.Set(rule.Named, rule.Default) + } + + return + } + + /** + * Matching based on regex + */ + + if regex, ok := h.regex[rule.Matching]; ok { + for name := range ctx.Request().Header { + // Headers are case-insensitive, but Go canonicalize them + // Issue: https://github.com/golang/go/issues/37834 + if regex.MatchString(name) { + + /** + * Rename the header before propagating and delete the original + */ + if rule.Rename != "" && rule.Named == "" { + + if slices.Contains(ignoredHeaders, rule.Rename) { + continue } + + value := ctx.Request().Header.Get(name) + if value != "" { + request.Header.Set(rule.Rename, ctx.Request().Header.Get(name)) + request.Header.Del(name) + } else if rule.Default != "" { + request.Header.Set(rule.Rename, rule.Default) + request.Header.Del(name) + } + + continue } + + /** + * Propagate the header as is + */ + if slices.Contains(ignoredHeaders, name) { + continue + } + request.Header.Set(name, ctx.Request().Header.Get(name)) } } } +} - return request, nil +func (h *HeaderPropagation) applyResponseRuleMostRestrictiveCacheControl(res *http.Response, propagation *responseHeaderPropagation, rule *config.ResponseHeaderRule) { + cacheControlKey := "Cache-Control" + + ctx := res.Request.Context() + tracer := rtrace.TracerFromContext(ctx) + commonAttributes := []attribute.KeyValue{ + otel.WgOperationProtocol.String(OperationProtocolHTTP.String()), + } + + _, span := tracer.Start(ctx, "HeaderPropagation - RestrictiveCacheControl", + trace.WithSpanKind(trace.SpanKindInternal), + trace.WithAttributes(commonAttributes...), + ) + + // Set no-cache for all mutations, to ensure that requests to mutate data always work as expected (without returning cached data) + if resolve.SingleFlightDisallowed(ctx) { + var noCache = "no-cache" + propagation.header.Set(cacheControlKey, noCache) + return + } + + reqDir, _ := cachedirective.ParseRequestCacheControl(res.Request.Header.Get(cacheControlKey)) + resDir, _ := cachedirective.ParseResponseCacheControl(res.Header.Get(cacheControlKey)) + expiresHeader, _ := http.ParseTime(res.Header.Get("Expires")) + dateHeader, _ := http.ParseTime(res.Header.Get("Date")) + lastModifiedHeader, _ := http.ParseTime(res.Header.Get("Last-Modified")) + + obj := &cachedirective.Object{ + RespDirectives: resDir, + RespHeaders: res.Header, + RespStatusCode: res.StatusCode, + RespExpiresHeader: expiresHeader, + RespDateHeader: dateHeader, + RespLastModifiedHeader: lastModifiedHeader, + + ReqDirectives: reqDir, + ReqHeaders: res.Request.Header, + ReqMethod: res.Request.Method, + + NowUTC: time.Now().UTC(), + } + rv := cachedirective.ObjectResults{} + + cachedirective.CachableObject(obj, &rv) + cachedirective.ExpirationObject(obj, &rv) + + span.SetAttributes( + otel.WgResponseCacheControlReasons.String(fmt.Sprint(rv.OutReasons)), + otel.WgResponseCacheControlWarnings.String(fmt.Sprint(rv.OutWarnings)), + otel.WgResponseCacheControlExpiration.String(rv.OutExpirationTime.String()), + ) + + propagation.m.Lock() + defer propagation.m.Unlock() + + defaultResponseCache, _ := cachedirective.ParseResponseCacheControl(rule.Default) + defaultCacheControlObj := &cachedirective.Object{ + RespDirectives: defaultResponseCache, + } + + if propagation.previousCacheControl == nil { + if rule.Default != "" { + propagation.previousCacheControl = defaultCacheControlObj + propagation.header.Set(cacheControlKey, rule.Default) + } else { + propagation.previousCacheControl = obj + propagation.header.Set(cacheControlKey, res.Header.Get(cacheControlKey)) + return + } + } else if rule.Default != "" && isMoreRestrictive(defaultCacheControlObj, propagation.previousCacheControl) { + fmt.Println("Overwriting previous cache control with the current subgraph default") + propagation.previousCacheControl = defaultCacheControlObj + propagation.header.Set(cacheControlKey, rule.Default) + } + + if !expiresHeader.IsZero() && (propagation.previousCacheControl.RespExpiresHeader.IsZero() || expiresHeader.Before(propagation.previousCacheControl.RespExpiresHeader)) { + propagation.previousCacheControl = obj + propagation.header.Set("Expires", res.Header.Get("Expires")) + } + + // Compare the previous cache control with the current one to find the most restrictive + if isMoreRestrictive(propagation.previousCacheControl, obj) { + // Keep the previous cache control, which is more restrictive + fmt.Println("Keeping the previous cache control as it's more restrictive") + } else { + // The current cache control is more restrictive, so update it + fmt.Println("Updating to the current cache control as it's more restrictive") + propagation.previousCacheControl = obj + propagation.header.Set(cacheControlKey, res.Header.Get(cacheControlKey)) + } +} + +// isMoreRestrictive compares two cachedirective.Object instances and returns true if the first is more restrictive +func isMoreRestrictive(prev *cachedirective.Object, curr *cachedirective.Object) bool { + // Example comparison logic: check if "no-store" or "no-cache" are present, which are more restrictive + if prev.RespDirectives.NoStore || curr.RespDirectives.NoStore { + return true // No store is the most restrictive + } + if prev.RespDirectives.NoCachePresent && !curr.RespDirectives.NoCachePresent { + return true // No-cache is more restrictive than not having it + } + if curr.RespDirectives.NoCachePresent && !prev.RespDirectives.NoCachePresent { + return false // Current response has no-cache, which is more restrictive + } + + // Compare max-age: the shorter max-age is more restrictive + if prev.RespDirectives.MaxAge > 0 && curr.RespDirectives.MaxAge > 0 { + return prev.RespDirectives.MaxAge < curr.RespDirectives.MaxAge + } + + // If neither has max-age, but one has other expiration controls like Expires header, use that + if !prev.RespExpiresHeader.IsZero() && !curr.RespExpiresHeader.IsZero() { + return prev.RespExpiresHeader.Before(curr.RespExpiresHeader) + } + + // Fallback: if they are equal in restrictiveness, keep the previous one + return true } // SubgraphRules returns the list of header rules for the subgraph with the given name -func SubgraphRules(rules *config.HeaderRules, subgraphName string) []config.RequestHeaderRule { - var subgraphRules []config.RequestHeaderRule - subgraphRules = append(subgraphRules, rules.All.Request...) - subgraphRules = append(subgraphRules, rules.Subgraphs[subgraphName].Request...) +func SubgraphRules(rules *config.HeaderRules, subgraphName string) []*config.RequestHeaderRule { + if rules == nil { + return nil + } + var subgraphRules []*config.RequestHeaderRule + if rules.All != nil { + subgraphRules = append(subgraphRules, rules.All.Request...) + } + if rules.Subgraphs != nil { + if subgraphSpecificRules, ok := rules.Subgraphs[subgraphName]; ok { + subgraphRules = append(subgraphRules, subgraphSpecificRules.Request...) + } + } return subgraphRules } // FetchURLRules returns the list of header rules for first subgraph that matches the given URL -func FetchURLRules(rules *config.HeaderRules, subgraphs []*nodev1.Subgraph, routingURL string) []config.RequestHeaderRule { +func FetchURLRules(rules *config.HeaderRules, subgraphs []*nodev1.Subgraph, routingURL string) []*config.RequestHeaderRule { var subgraphName string for _, subgraph := range subgraphs { if subgraph.RoutingUrl == routingURL { @@ -201,7 +532,7 @@ func FetchURLRules(rules *config.HeaderRules, subgraphs []*nodev1.Subgraph, rout // PropagatedHeaders returns the list of header names and regular expressions // that will be propagated when applying the given rules. -func PropagatedHeaders(rules []config.RequestHeaderRule) (headerNames []string, headerNameRegexps []*regexp.Regexp, err error) { +func PropagatedHeaders(rules []*config.RequestHeaderRule) (headerNames []string, headerNameRegexps []*regexp.Regexp, err error) { for _, rule := range rules { switch rule.Operation { case config.HeaderRuleOperationPropagate: diff --git a/router/core/header_rule_engine_test.go b/router/core/header_rule_engine_test.go index 475836ea7a..08d42c6d03 100644 --- a/router/core/header_rule_engine_test.go +++ b/router/core/header_rule_engine_test.go @@ -18,9 +18,9 @@ func TestPropagateHeaderRule(t *testing.T) { t.Run("Should propagate with named header name / named", func(t *testing.T) { - ht, err := NewHeaderTransformer(config.HeaderRules{ - All: config.GlobalHeaderRule{ - Request: []config.RequestHeaderRule{ + ht, err := NewHeaderPropagation(&config.HeaderRules{ + All: &config.GlobalHeaderRule{ + Request: []*config.RequestHeaderRule{ { Operation: "propagate", Named: "X-Test-1", @@ -55,9 +55,9 @@ func TestPropagateHeaderRule(t *testing.T) { }) t.Run("Should propagate based on matching regex / matching", func(t *testing.T) { - ht, err := NewHeaderTransformer(config.HeaderRules{ - All: config.GlobalHeaderRule{ - Request: []config.RequestHeaderRule{ + ht, err := NewHeaderPropagation(&config.HeaderRules{ + All: &config.GlobalHeaderRule{ + Request: []*config.RequestHeaderRule{ { Operation: "propagate", Matching: "(?i)X-Test-.*", @@ -93,9 +93,9 @@ func TestPropagateHeaderRule(t *testing.T) { }) t.Run("Should propagate with default value / named + default", func(t *testing.T) { - ht, err := NewHeaderTransformer(config.HeaderRules{ - All: config.GlobalHeaderRule{ - Request: []config.RequestHeaderRule{ + ht, err := NewHeaderPropagation(&config.HeaderRules{ + All: &config.GlobalHeaderRule{ + Request: []*config.RequestHeaderRule{ { Operation: "propagate", Named: "X-Test-1", @@ -128,7 +128,7 @@ func TestPropagateHeaderRule(t *testing.T) { t.Run("Should not propagate as disallowed headers / named", func(t *testing.T) { - rules := []config.RequestHeaderRule{ + rules := []*config.RequestHeaderRule{ { Operation: "propagate", Named: "X-Test-1", @@ -136,14 +136,14 @@ func TestPropagateHeaderRule(t *testing.T) { } for _, name := range ignoredHeaders { - rules = append(rules, config.RequestHeaderRule{ + rules = append(rules, &config.RequestHeaderRule{ Operation: "propagate", Named: name, }) } - ht, err := NewHeaderTransformer(config.HeaderRules{ - All: config.GlobalHeaderRule{ + ht, err := NewHeaderPropagation(&config.HeaderRules{ + All: &config.GlobalHeaderRule{ Request: rules, }, }) @@ -180,9 +180,9 @@ func TestRenamePropagateHeaderRule(t *testing.T) { t.Run("Rename header / named", func(t *testing.T) { - ht, err := NewHeaderTransformer(config.HeaderRules{ - All: config.GlobalHeaderRule{ - Request: []config.RequestHeaderRule{ + ht, err := NewHeaderPropagation(&config.HeaderRules{ + All: &config.GlobalHeaderRule{ + Request: []*config.RequestHeaderRule{ { Operation: "propagate", Named: "X-Test-1", @@ -219,9 +219,9 @@ func TestRenamePropagateHeaderRule(t *testing.T) { t.Run("Rename based on matching regex pattern / matching", func(t *testing.T) { - ht, err := NewHeaderTransformer(config.HeaderRules{ - All: config.GlobalHeaderRule{ - Request: []config.RequestHeaderRule{ + ht, err := NewHeaderPropagation(&config.HeaderRules{ + All: &config.GlobalHeaderRule{ + Request: []*config.RequestHeaderRule{ { Operation: "propagate", Matching: "(?i)X-Test-.*", @@ -265,7 +265,7 @@ func TestRenamePropagateHeaderRule(t *testing.T) { t.Run("Should not rename to disallowed headers / named", func(t *testing.T) { - rules := []config.RequestHeaderRule{ + rules := []*config.RequestHeaderRule{ { Operation: "propagate", Named: "X-Test-Old", @@ -274,15 +274,15 @@ func TestRenamePropagateHeaderRule(t *testing.T) { } for _, name := range ignoredHeaders { - rules = append(rules, config.RequestHeaderRule{ + rules = append(rules, &config.RequestHeaderRule{ Operation: "propagate", Named: fmt.Sprintf("X-Test-%s", name), Rename: name, }) } - ht, err := NewHeaderTransformer(config.HeaderRules{ - All: config.GlobalHeaderRule{ + ht, err := NewHeaderPropagation(&config.HeaderRules{ + All: &config.GlobalHeaderRule{ Request: rules, }, }) @@ -316,9 +316,9 @@ func TestRenamePropagateHeaderRule(t *testing.T) { func TestSkipAllIgnoredHeaders(t *testing.T) { - ht, err := NewHeaderTransformer(config.HeaderRules{ - All: config.GlobalHeaderRule{ - Request: []config.RequestHeaderRule{ + ht, err := NewHeaderPropagation(&config.HeaderRules{ + All: &config.GlobalHeaderRule{ + Request: []*config.RequestHeaderRule{ { Operation: "propagate", Matching: "(?i).*", @@ -361,10 +361,10 @@ func TestSubgraphPropagateHeaderRule(t *testing.T) { t.Run("Should propagate set header / named", func(t *testing.T) { - ht, err := NewHeaderTransformer(config.HeaderRules{ - Subgraphs: map[string]config.GlobalHeaderRule{ + ht, err := NewHeaderPropagation(&config.HeaderRules{ + Subgraphs: map[string]*config.GlobalHeaderRule{ "subgraph-1": { - Request: []config.RequestHeaderRule{ + Request: []*config.RequestHeaderRule{ { Operation: "propagate", Named: "X-Test-Subgraph", @@ -410,10 +410,10 @@ func TestSubgraphPropagateHeaderRule(t *testing.T) { }) t.Run("Should propagate set header / matching", func(t *testing.T) { - ht, err := NewHeaderTransformer(config.HeaderRules{ - Subgraphs: map[string]config.GlobalHeaderRule{ + ht, err := NewHeaderPropagation(&config.HeaderRules{ + Subgraphs: map[string]*config.GlobalHeaderRule{ "subgraph-1": { - Request: []config.RequestHeaderRule{ + Request: []*config.RequestHeaderRule{ { Operation: "propagate", Matching: "(?i)X-Test-.*", @@ -458,7 +458,7 @@ func TestSubgraphPropagateHeaderRule(t *testing.T) { }) t.Run("Should not propagate disallowed header / named", func(t *testing.T) { - rules := []config.RequestHeaderRule{ + rules := []*config.RequestHeaderRule{ { Operation: "propagate", Named: "X-Test-Subgraph", @@ -466,14 +466,14 @@ func TestSubgraphPropagateHeaderRule(t *testing.T) { } for _, name := range ignoredHeaders { - rules = append(rules, config.RequestHeaderRule{ + rules = append(rules, &config.RequestHeaderRule{ Operation: "propagate", Named: name, }) } - ht, err := NewHeaderTransformer(config.HeaderRules{ - Subgraphs: map[string]config.GlobalHeaderRule{ + ht, err := NewHeaderPropagation(&config.HeaderRules{ + Subgraphs: map[string]*config.GlobalHeaderRule{ "subgraph-1": { Request: rules, }, @@ -521,7 +521,7 @@ func TestSubgraphPropagateHeaderRule(t *testing.T) { t.Run("Should not propagate disallowed headers / matching", func(t *testing.T) { - rules := []config.RequestHeaderRule{ + rules := []*config.RequestHeaderRule{ { Operation: "propagate", Matching: ".*", @@ -529,15 +529,15 @@ func TestSubgraphPropagateHeaderRule(t *testing.T) { } for _, name := range ignoredHeaders { - rules = append(rules, config.RequestHeaderRule{ + rules = append(rules, &config.RequestHeaderRule{ Operation: "propagate", Named: fmt.Sprintf("X-Test-%s", name), Rename: name, }) } - ht, err := NewHeaderTransformer(config.HeaderRules{ - Subgraphs: map[string]config.GlobalHeaderRule{ + ht, err := NewHeaderPropagation(&config.HeaderRules{ + Subgraphs: map[string]*config.GlobalHeaderRule{ "subgraph-1": { Request: rules, }, @@ -588,10 +588,10 @@ func TestSubgraphPropagateHeaderRule(t *testing.T) { func TestSubgraphRenamePropagateHeaderRule(t *testing.T) { t.Run("Should rename header / named", func(t *testing.T) { - ht, err := NewHeaderTransformer(config.HeaderRules{ - Subgraphs: map[string]config.GlobalHeaderRule{ + ht, err := NewHeaderPropagation(&config.HeaderRules{ + Subgraphs: map[string]*config.GlobalHeaderRule{ "subgraph-1": { - Request: []config.RequestHeaderRule{ + Request: []*config.RequestHeaderRule{ { Operation: "propagate", Named: "X-Test-Subgraph", @@ -637,10 +637,10 @@ func TestSubgraphRenamePropagateHeaderRule(t *testing.T) { }) t.Run("Should fallback to default value when header value is not set / named", func(t *testing.T) { - ht, err := NewHeaderTransformer(config.HeaderRules{ - Subgraphs: map[string]config.GlobalHeaderRule{ + ht, err := NewHeaderPropagation(&config.HeaderRules{ + Subgraphs: map[string]*config.GlobalHeaderRule{ "subgraph-1": { - Request: []config.RequestHeaderRule{ + Request: []*config.RequestHeaderRule{ { Operation: "propagate", Rename: "X-Test-Subgraph-Renamed-2", @@ -687,10 +687,10 @@ func TestSubgraphRenamePropagateHeaderRule(t *testing.T) { }) t.Run("Should rename header and don't fallback to default value when header is set / named", func(t *testing.T) { - ht, err := NewHeaderTransformer(config.HeaderRules{ - Subgraphs: map[string]config.GlobalHeaderRule{ + ht, err := NewHeaderPropagation(&config.HeaderRules{ + Subgraphs: map[string]*config.GlobalHeaderRule{ "subgraph-1": { - Request: []config.RequestHeaderRule{ + Request: []*config.RequestHeaderRule{ { Operation: "propagate", Rename: "X-Test-Subgraph-Renamed", @@ -737,10 +737,10 @@ func TestSubgraphRenamePropagateHeaderRule(t *testing.T) { }) t.Run("Should rename headers based / matching rule", func(t *testing.T) { - ht, err := NewHeaderTransformer(config.HeaderRules{ - Subgraphs: map[string]config.GlobalHeaderRule{ + ht, err := NewHeaderPropagation(&config.HeaderRules{ + Subgraphs: map[string]*config.GlobalHeaderRule{ "subgraph-1": { - Request: []config.RequestHeaderRule{ + Request: []*config.RequestHeaderRule{ { Operation: "propagate", Rename: "X-Test-Subgraph-Renamed", @@ -786,10 +786,10 @@ func TestSubgraphRenamePropagateHeaderRule(t *testing.T) { }) t.Run("Should rename headers and fallback to default value when header value is not set / matching rule", func(t *testing.T) { - ht, err := NewHeaderTransformer(config.HeaderRules{ - Subgraphs: map[string]config.GlobalHeaderRule{ + ht, err := NewHeaderPropagation(&config.HeaderRules{ + Subgraphs: map[string]*config.GlobalHeaderRule{ "subgraph-1": { - Request: []config.RequestHeaderRule{ + Request: []*config.RequestHeaderRule{ { Operation: "propagate", Rename: "X-Test-Subgraph-Default-Renamed", @@ -838,9 +838,9 @@ func TestSubgraphRenamePropagateHeaderRule(t *testing.T) { func TestInvalidRegex(t *testing.T) { - _, err := NewHeaderTransformer(config.HeaderRules{ - All: config.GlobalHeaderRule{ - Request: []config.RequestHeaderRule{ + _, err := NewHeaderPropagation(&config.HeaderRules{ + All: &config.GlobalHeaderRule{ + Request: []*config.RequestHeaderRule{ { Operation: "propagate", Matching: "[", diff --git a/router/core/router.go b/router/core/router.go index ec33eec2bf..b6a7ab040c 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -172,8 +172,7 @@ type ( routerMiddlewares []func(http.Handler) http.Handler preOriginHandlers []TransportPreHandler postOriginHandlers []TransportPostHandler - headerRuleEngine *HeaderRuleEngine - headerRules config.HeaderRules + headerRules *config.HeaderRules subgraphTransportOptions *SubgraphTransportOptions graphqlMetricsConfig *GraphQLMetricsConfig routerTrafficConfig *config.RouterTrafficConfiguration @@ -310,14 +309,17 @@ func NewRouter(opts ...Option) (*Router, error) { r.livenessCheckPath = "/health/live" } - hr, err := NewHeaderTransformer(r.headerRules) + hr, err := NewHeaderPropagation(r.headerRules) if err != nil { return nil, err } - r.headerRuleEngine = hr - - r.preOriginHandlers = append(r.preOriginHandlers, r.headerRuleEngine.OnOriginRequest) + if hr.HasRequestRules() { + r.preOriginHandlers = append(r.preOriginHandlers, hr.OnOriginRequest) + } + if hr.HasResponseRules() { + r.postOriginHandlers = append(r.postOriginHandlers, hr.OnOriginResponse) + } defaultHeaders := []string{ // Common headers @@ -1448,7 +1450,7 @@ func WithEvents(cfg config.EventsConfiguration) Option { func WithHeaderRules(headers config.HeaderRules) Option { return func(r *Router) { - r.headerRules = headers + r.headerRules = &headers } } diff --git a/router/go.mod b/router/go.mod index 4b1c9b00a3..da408dd872 100644 --- a/router/go.mod +++ b/router/go.mod @@ -71,6 +71,7 @@ require ( github.com/fsnotify/fsnotify v1.7.0 github.com/klauspost/compress v1.17.9 github.com/minio/minio-go/v7 v7.0.74 + github.com/pquerna/cachecontrol v0.2.0 github.com/santhosh-tekuri/jsonschema/v6 v6.0.1 github.com/wundergraph/astjson v0.0.0-20240910140849-bb15f94bd362 golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 diff --git a/router/go.sum b/router/go.sum index b005e5aaec..c64a00f247 100644 --- a/router/go.sum +++ b/router/go.sum @@ -190,6 +190,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= +github.com/pquerna/cachecontrol v0.2.0 h1:vBXSNuE5MYP9IJ5kjsdo8uq+w41jSPgvba2DEnkRx9k= +github.com/pquerna/cachecontrol v0.2.0/go.mod h1:NrUG3Z7Rdu85UNR3vm7SOsl1nFIeSiQnrHV5K9mBcUI= github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g= github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U= github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE= @@ -235,6 +237,7 @@ github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index ef39c46899..a59fb1edb7 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -164,13 +164,14 @@ type BackoffJitterRetry struct { type HeaderRules struct { // All is a set of rules that apply to all requests - All GlobalHeaderRule `yaml:"all,omitempty"` - Subgraphs map[string]GlobalHeaderRule `yaml:"subgraphs,omitempty"` + All *GlobalHeaderRule `yaml:"all,omitempty"` + Subgraphs map[string]*GlobalHeaderRule `yaml:"subgraphs,omitempty"` } type GlobalHeaderRule struct { // Request is a set of rules that apply to requests - Request []RequestHeaderRule `yaml:"request,omitempty"` + Request []*RequestHeaderRule `yaml:"request,omitempty"` + Response []*ResponseHeaderRule `yaml:"response,omitempty"` } type HeaderRuleOperation string @@ -179,6 +180,11 @@ const ( HeaderRuleOperationPropagate HeaderRuleOperation = "propagate" ) +type HeaderRule interface { + GetOperation() HeaderRuleOperation + GetMatching() string +} + type RequestHeaderRule struct { // Operation describes the header operation to perform e.g. "propagate" Operation HeaderRuleOperation `yaml:"op"` @@ -192,6 +198,50 @@ type RequestHeaderRule struct { Default string `yaml:"default"` } +func (r *RequestHeaderRule) GetOperation() HeaderRuleOperation { + return r.Operation +} + +func (r *RequestHeaderRule) GetMatching() string { + return r.Matching +} + +type ResponseHeaderRuleAlgorithm string + +const ( + // ResponseHeaderRuleAlgorithmFirstWrite propagates the first response header from a subgraph to the client + ResponseHeaderRuleAlgorithmFirstWrite ResponseHeaderRuleAlgorithm = "first_write" + // ResponseHeaderRuleAlgorithmLastWrite propagates the last response header from a subgraph to the client + ResponseHeaderRuleAlgorithmLastWrite ResponseHeaderRuleAlgorithm = "last_write" + // ResponseHeaderRuleAlgorithmAppend appends all response headers from all subgraphs to a comma separated list of values in the client response + ResponseHeaderRuleAlgorithmAppend ResponseHeaderRuleAlgorithm = "append" + // ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl propagates the most restrictive cache control header from all subgraph responses to the client + ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl ResponseHeaderRuleAlgorithm = "most_restrictive_cache_control" +) + +type ResponseHeaderRule struct { + // Operation describes the header operation to perform e.g. "propagate" + Operation HeaderRuleOperation `yaml:"op"` + // Matching is the regex to match the header name against + Matching string `yaml:"matching"` + // Named is the exact header name to match + Named string `yaml:"named"` + // Rename renames the header's key to the provided value + Rename string `yaml:"rename,omitempty"` + // Default is the default value to set if the header is not present + Default string `yaml:"default"` + // Algorithm is the algorithm to use when multiple headers are present + Algorithm ResponseHeaderRuleAlgorithm `yaml:"algorithm,omitempty"` +} + +func (r *ResponseHeaderRule) GetOperation() HeaderRuleOperation { + return r.Operation +} + +func (r *ResponseHeaderRule) GetMatching() string { + return r.Matching +} + type EngineDebugConfiguration struct { PrintOperationTransformations bool `envDefault:"false" env:"ENGINE_DEBUG_PRINT_OPERATION_TRANSFORMATIONS" yaml:"print_operation_transformations"` PrintOperationEnableASTRefs bool `envDefault:"false" env:"ENGINE_DEBUG_PRINT_OPERATION_ENABLE_AST_REFS" yaml:"print_operation_enable_ast_refs"` diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index 20b6a746d3..90fa04c48b 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -960,6 +960,12 @@ "items": { "$ref": "#/definitions/traffic_shaping_header_rule" } + }, + "response": { + "type": "array", + "items": { + "$ref": "#/definitions/traffic_shaping_header_response_rule" + } } } }, @@ -974,6 +980,12 @@ "items": { "$ref": "#/definitions/traffic_shaping_header_rule" } + }, + "response": { + "type": "array", + "items": { + "$ref": "#/definitions/traffic_shaping_header_response_rule" + } } } } @@ -1654,6 +1666,66 @@ } }, "required": ["op"] + }, + "traffic_shaping_header_response_rule": { + "type": "object", + "description": "The configuration for all subgraphs. The configuration is used to configure the traffic shaping for all subgraphs.", + "additionalProperties": false, + "properties": { + "op": { + "type": "string", + "enum": ["propagate"], + "examples": ["propagate"], + "description": "The operation to perform on the header. The supported operations are 'propagate'. The 'propagate' operation is used to propagate the header to the subgraphs." + }, + "matching": { + "type": "string", + "examples": ["(?i)^X-Custom-.*"], + "description": "The matching rule for the header. The matching rule is a regular expression that is used to match the header. Can't be used with 'named'." + }, + "named": { + "type": "string", + "examples": ["X-Test-Header"], + "description": "The name of the header to match. Use the canonical version e.g. X-Test-Header. Can't be used with 'matching'." + }, + "rename": { + "type": "string", + "examples": ["X-Rename-Test-Header"], + "description": "Rename is used to rename the named or the matching headers. It can be used with either the named or the matching." + }, + "default": { + "type": "string", + "examples": ["default-value"], + "description": "The default value of the header in case it is not present in the request." + }, + "algorithm": { + "type": "string", + "enum": ["first_write", "last_write", "append", "most_restrictive_cache_control"], + "examples": ["first_write"], + "description": "The algorith, to use when multiple headers are present. The supported operations are '\"first_write\", \"last_write\", \"append\", and \"most_restrictive_cache_control\"'. The 'first_write' retains the first value of a given header. The 'last_write' retains the last value of a given header. The 'append' appends all values of a given header. The 'most_restrictive_cache_control' specifically focuses on the 'Cache-Control'/'Expiration' headers, and applies the most restrictive value encountered from a subgraph" + } + }, + "required": ["op", "algorithm"], + "if": { + "properties": { + "algorithm": { "const": "most_restrictive_cache_control" } + } + }, + "then": { + "properties": { + "matching": { "not": {} }, + "named": { "not": {} }, + "rename": { "not": {} } + }, + "required": ["op", "algorithm"] + }, + "else": { + "properties": { + "matching": {}, + "named": {}, + "rename": {} + } + } } } } diff --git a/router/pkg/config/fixtures/full.yaml b/router/pkg/config/fixtures/full.yaml index 0191dea062..5ac236c774 100644 --- a/router/pkg/config/fixtures/full.yaml +++ b/router/pkg/config/fixtures/full.yaml @@ -151,13 +151,19 @@ headers: - op: "propagate" named: "X-User-Id" default: "123" # Set the value when the header was not set - + response: + - op: "propagate" + algorithm: "append" + named: "X-Custom-Header" subgraphs: specific-subgraph: # Will only affect this subgraph request: - op: "propagate" named: Subgraph-Secret default: "some-secret" + response: + - op: "propagate" + algorithm: "most_restrictive_cache_control" # Authentication and Authorization # See https://cosmo-docs.wundergraph.com/router/authentication-and-authorization for more information diff --git a/router/pkg/config/testdata/config_defaults.json b/router/pkg/config/testdata/config_defaults.json index 8e2773ea78..c05d6439c7 100644 --- a/router/pkg/config/testdata/config_defaults.json +++ b/router/pkg/config/testdata/config_defaults.json @@ -82,9 +82,7 @@ }, "Modules": null, "Headers": { - "All": { - "Request": null - }, + "All": null, "Subgraphs": null }, "TrafficShaping": { diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index 505313f78f..c036d28494 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -127,6 +127,16 @@ "Rename": "", "Default": "123" } + ], + "Response": [ + { + "Operation": "propagate", + "Matching": "", + "Named": "X-Custom-Header", + "Rename": "", + "Default": "", + "Algorithm": "append" + } ] }, "Subgraphs": { @@ -139,6 +149,16 @@ "Rename": "", "Default": "some-secret" } + ], + "Response": [ + { + "Operation": "propagate", + "Matching": "", + "Named": "", + "Rename": "", + "Default": "", + "Algorithm": "most_restrictive_cache_control" + } ] } } diff --git a/router/pkg/otel/attributes.go b/router/pkg/otel/attributes.go index bc042e465d..5fd4f8b82d 100644 --- a/router/pkg/otel/attributes.go +++ b/router/pkg/otel/attributes.go @@ -35,6 +35,9 @@ const ( WgVariablesValidationSkipped = attribute.Key("wg.engine.variables_validation_skipped") WgQueryDepth = attribute.Key("wg.operation.query_depth") WgQueryDepthCacheHit = attribute.Key("wg.operation.query_depth_cache_hit") + WgResponseCacheControlReasons = attribute.Key("wg.operation.cache_control_reasons") + WgResponseCacheControlWarnings = attribute.Key("wg.operation.cache_control_warnings") + WgResponseCacheControlExpiration = attribute.Key("wg.operation.cache_control_expiration") // HTTPRequestUploadFileCount is the number of files uploaded in a request (Not specified in the OpenTelemetry specification) HTTPRequestUploadFileCount = attribute.Key("http.request.upload.file_count") )