diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index a9cb854e96..8cd3643dff 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -11,11 +11,13 @@ import ( "encoding/binary" "fmt" "net/http" + "strings" "github.com/cespare/xxhash/v2" "github.com/tidwall/gjson" "golang.org/x/sync/errgroup" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" "github.com/wundergraph/astjson" "github.com/wundergraph/go-arena" @@ -90,6 +92,8 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D // It processes the input JSON data to make gRPC calls and returns // the response data. // +// Headers are converted to gRPC metadata and part of gRPC calls. +// // The input is expected to contain the necessary information to make // a gRPC call, including service name, method name, and request data. func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { @@ -111,6 +115,19 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte return builder.writeErrorBytes(fmt.Errorf("gRPC datasource needs to be enabled to be used")), nil } + // convert headers to grpc metadata and attach to ctx + if len(headers) > 0 { + // assume that each header has exactly one value for default pairs size + pairs := make([]string, 0, len(headers)*2) + for headerName, headerValues := range headers { + headerName = strings.ToLower(headerName) + for _, v := range headerValues { + pairs = append(pairs, headerName, v) + } + } + ctx = metadata.AppendToOutgoingContext(ctx, pairs...) + } + graph := NewDependencyGraph(d.plan) root := astjson.ObjectValue(nil) diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index 1f3b53a6d4..5846a2ab64 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -6,6 +6,7 @@ import ( "fmt" "math" "net" + "net/http" "strings" "testing" @@ -13,6 +14,7 @@ import ( "github.com/tidwall/gjson" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/test/bufconn" "google.golang.org/protobuf/encoding/protojson" protoref "google.golang.org/protobuf/reflect/protoreflect" @@ -5237,3 +5239,276 @@ func Test_Datasource_Load_WithFieldResolvers(t *testing.T) { }) } } + +func Test_Datasource_Load_WithHeaders(t *testing.T) { + conn, cleanup := setupTestGRPCServer(t) + t.Cleanup(cleanup) + + type graphqlError struct { + Message string `json:"message"` + } + type graphqlResponse struct { + Data map[string]interface{} `json:"data"` + Errors []graphqlError `json:"errors,omitempty"` + } + + testCases := []struct { + name string + query string + vars string + headers http.Header + validate func(t *testing.T, data map[string]interface{}) + validateError func(t *testing.T, errData []graphqlError) + }{ + { + name: "QueryUser with header override", + query: `query UserQuery($id: ID!) { user(id: $id) { id name } }`, + vars: `{"variables":{"id":"original-user-123"}}`, + headers: func() http.Header { + h := make(http.Header) + h.Set("X-User-ID", "header-user-42") + return h + }(), + validate: func(t *testing.T, data map[string]interface{}) { + user, ok := data["user"].(map[string]interface{}) + require.True(t, ok, "user should be an object") + require.Equal(t, "header-user-42", user["id"], "user ID should come from header") + require.Equal(t, "User header-user-42", user["name"], "user name should use header-derived ID") + }, + validateError: func(t *testing.T, errData []graphqlError) { + require.Empty(t, errData) + }, + }, + { + name: "QueryUser with header triggering error", + query: `query UserQuery($id: ID!) { user(id: $id) { id name } }`, + vars: `{"variables":{"id":"valid-user-123"}}`, + headers: func() http.Header { + h := make(http.Header) + h.Set("X-User-ID", "error-user") + return h + }(), + validate: func(t *testing.T, data map[string]interface{}) { + // Data might be present but should have errors + }, + validateError: func(t *testing.T, errData []graphqlError) { + require.NotEmpty(t, errData, "should have errors") + require.Contains(t, errData[0].Message, "user not found: error-user") + }, + }, + { + name: "QueryUser without headers (nil) - baseline behavior", + query: `query UserQuery($id: ID!) { user(id: $id) { id name } }`, + vars: `{"variables":{"id":"baseline-user-99"}}`, + headers: nil, + validate: func(t *testing.T, data map[string]interface{}) { + user, ok := data["user"].(map[string]interface{}) + require.True(t, ok, "user should be an object") + require.Equal(t, "baseline-user-99", user["id"], "user ID should come from query variable") + require.Equal(t, "User baseline-user-99", user["name"], "user name should use variable-derived ID") + }, + validateError: func(t *testing.T, errData []graphqlError) { + require.Empty(t, errData) + }, + }, + { + name: "QueryUsers with custom prefix header", + query: `query UsersQuery { users { id name } }`, + vars: `{"variables":{}}`, + headers: func() http.Header { + h := make(http.Header) + h.Set("X-User-Prefix", "Admin") + return h + }(), + validate: func(t *testing.T, data map[string]interface{}) { + users, ok := data["users"].([]interface{}) + require.True(t, ok, "users should be an array") + require.Len(t, users, 3, "should return 3 users") + + for i, u := range users { + user, ok := u.(map[string]interface{}) + require.True(t, ok, "each user should be an object") + require.Equal(t, fmt.Sprintf("user-%d", i+1), user["id"]) + require.Equal(t, fmt.Sprintf("Admin %d", i+1), user["name"], "user name should use custom prefix from header") + } + }, + validateError: func(t *testing.T, errData []graphqlError) { + require.Empty(t, errData) + }, + }, + { + name: "MutationCreateUser with name override header", + query: `mutation CreateUser($input: UserInput!) { createUser(input: $input) { id name } }`, + vars: `{"variables":{"input":{"name":"OriginalName"}}}`, + headers: func() http.Header { + h := make(http.Header) + h.Set("X-Custom-Name", "HeaderName") + return h + }(), + validate: func(t *testing.T, data map[string]interface{}) { + createUser, ok := data["createUser"].(map[string]interface{}) + require.True(t, ok, "createUser should be an object") + require.NotEmpty(t, createUser["id"], "created user should have an ID") + require.Equal(t, "HeaderName", createUser["name"], "created user name should come from header") + }, + validateError: func(t *testing.T, errData []graphqlError) { + require.Empty(t, errData) + }, + }, + { + name: "Categories with productCount field resolver and header offset", + query: `query CategoriesWithProductCount($filters: ProductCountFilter) { categories { id name kind productCount(filters: $filters) } }`, + vars: `{"variables":{"filters":{"minPrice":100}}}`, + headers: func() http.Header { + h := make(http.Header) + h.Set("X-Count-Offset", "100") + return h + }(), + validate: func(t *testing.T, data map[string]interface{}) { + categories, ok := data["categories"].([]interface{}) + require.True(t, ok, "categories should be an array") + require.Len(t, categories, 4, "should return 4 categories") + + // Verify that productCount for each category is offset by 100 + expectedCounts := []float64{100, 101, 102, 103} + for i, c := range categories { + category, ok := c.(map[string]interface{}) + require.True(t, ok, "category should be an object") + require.NotEmpty(t, category["id"]) + require.NotEmpty(t, category["name"]) + require.Equal(t, expectedCounts[i], category["productCount"], "productCount should be offset by header value") + } + }, + validateError: func(t *testing.T, errData []graphqlError) { + require.Empty(t, errData) + }, + }, + { + name: "Categories with productCount without headers - baseline behavior", + query: `query CategoriesWithProductCount($filters: ProductCountFilter) { categories { id name kind productCount(filters: $filters) } }`, + vars: `{"variables":{"filters":{"minPrice":100}}}`, + headers: nil, + validate: func(t *testing.T, data map[string]interface{}) { + categories, ok := data["categories"].([]interface{}) + require.True(t, ok, "categories should be an array") + require.Len(t, categories, 4, "should return 4 categories") + + // Verify default productCount values (no offset) + expectedCounts := []float64{0, 1, 2, 3} + for i, c := range categories { + category, ok := c.(map[string]interface{}) + require.True(t, ok, "category should be an object") + require.NotEmpty(t, category["id"]) + require.NotEmpty(t, category["name"]) + require.Equal(t, expectedCounts[i], category["productCount"], "productCount should use default values without header") + } + }, + validateError: func(t *testing.T, errData []graphqlError) { + require.Empty(t, errData) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Parse the GraphQL schema + schemaDoc := grpctest.MustGraphQLSchema(t) + + // Parse the GraphQL query + queryDoc, report := astparser.ParseGraphqlDocumentString(tc.query) + require.False(t, report.HasErrors(), "failed to parse query: %s", report.Error()) + + compiler, err := NewProtoCompiler(grpctest.MustProtoSchema(t), testMapping()) + require.NoError(t, err) + + // Create the datasource + ds, err := NewDataSource(conn, DataSourceConfig{ + Operation: &queryDoc, + Definition: &schemaDoc, + SubgraphName: "Products", + Mapping: testMapping(), + Compiler: compiler, + }) + require.NoError(t, err) + + // Execute the query with headers + input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) + output, err := ds.Load(context.Background(), tc.headers, []byte(input)) + require.NoError(t, err) + + // Parse the response + var resp graphqlResponse + err = json.Unmarshal(output, &resp) + require.NoError(t, err, "Failed to unmarshal response") + + tc.validate(t, resp.Data) + tc.validateError(t, resp.Errors) + }) + } +} + +func Test_Datasource_Load_PreservesExistingContextMetadata(t *testing.T) { + conn, cleanup := setupTestGRPCServer(t) + t.Cleanup(cleanup) + + type graphqlError struct { + Message string `json:"message"` + } + type graphqlResponse struct { + Data map[string]interface{} `json:"data"` + Errors []graphqlError `json:"errors,omitempty"` + } + + // Parse the GraphQL schema + schemaDoc := grpctest.MustGraphQLSchema(t) + + query := `query UserQuery($id: ID!) { user(id: $id) { id name } }` + vars := `{"variables":{"id":"test-user-123"}}` + + // Parse the GraphQL query + queryDoc, report := astparser.ParseGraphqlDocumentString(query) + require.False(t, report.HasErrors(), "failed to parse query: %s", report.Error()) + + compiler, err := NewProtoCompiler(grpctest.MustProtoSchema(t), testMapping()) + require.NoError(t, err) + + // Create the datasource + ds, err := NewDataSource(conn, DataSourceConfig{ + Operation: &queryDoc, + Definition: &schemaDoc, + SubgraphName: "Products", + Mapping: testMapping(), + Compiler: compiler, + }) + require.NoError(t, err) + + // Create a context with existing metadata + ctx := metadata.NewOutgoingContext( + context.Background(), + metadata.Pairs("x-existing-key", "existing-value"), + ) + + // Create HTTP headers to be forwarded + headers := make(http.Header) + headers.Set("X-User-ID", "header-user-456") + + // Execute the query with both existing context metadata and new HTTP headers + input := fmt.Sprintf(`{"query":%q,"body":%s}`, query, vars) + output, err := ds.Load(ctx, headers, []byte(input)) + require.NoError(t, err) + + // Parse the response + var resp graphqlResponse + err = json.Unmarshal(output, &resp) + require.NoError(t, err, "Failed to unmarshal response") + + // Verify no errors + require.Empty(t, resp.Errors, "Should not have GraphQL errors") + + // Verify the response includes both the header-derived ID and the existing metadata value + user, ok := resp.Data["user"].(map[string]interface{}) + require.True(t, ok, "user should be an object") + require.Equal(t, "header-user-456", user["id"], "user ID should come from HTTP header") + require.Equal(t, "User header-user-456 (existing: existing-value)", user["name"], + "user name should include both header-derived ID and existing context metadata") +} diff --git a/v2/pkg/grpctest/mockservice.go b/v2/pkg/grpctest/mockservice.go index ae9d45ffbc..7dbc5b9605 100644 --- a/v2/pkg/grpctest/mockservice.go +++ b/v2/pkg/grpctest/mockservice.go @@ -7,6 +7,7 @@ import ( "strconv" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/wrapperspb" @@ -209,11 +210,22 @@ func (s *MockService) QuerySearch(ctx context.Context, in *productv1.QuerySearch func (s *MockService) QueryUsers(ctx context.Context, in *productv1.QueryUsersRequest) (*productv1.QueryUsersResponse, error) { var results []*productv1.User + // Default prefix for user names and IDs + prefix := "User" + + // Allow header to override the prefix + md, ok := metadata.FromIncomingContext(ctx) + if ok { + if values := md.Get("x-user-prefix"); len(values) > 0 { + prefix = values[0] + } + } + // Generate 3 random users for i := 1; i <= 3; i++ { results = append(results, &productv1.User{ Id: fmt.Sprintf("user-%d", i), - Name: fmt.Sprintf("User %d", i), + Name: fmt.Sprintf("%s %d", prefix, i), }) } @@ -224,16 +236,33 @@ func (s *MockService) QueryUsers(ctx context.Context, in *productv1.QueryUsersRe func (s *MockService) QueryUser(ctx context.Context, in *productv1.QueryUserRequest) (*productv1.QueryUserResponse, error) { userId := in.GetId() + existingValue := "" + md, ok := metadata.FromIncomingContext(ctx) + if ok { + if values := md.Get("x-user-id"); len(values) > 0 { + userId = values[0] + } + // Check for existing metadata that was in the context before headers were added + if values := md.Get("x-existing-key"); len(values) > 0 { + existingValue = values[0] + } + } // Return a gRPC status error for a specific test case if userId == "error-user" { return nil, status.Errorf(codes.NotFound, "user not found: %s", userId) } + // Include existing metadata value in the name if present + userName := fmt.Sprintf("User %s", userId) + if existingValue != "" { + userName = fmt.Sprintf("User %s (existing: %s)", userId, existingValue) + } + return &productv1.QueryUserResponse{ User: &productv1.User{ Id: userId, - Name: fmt.Sprintf("User %s", userId), + Name: userName, }, }, nil } @@ -515,11 +544,20 @@ func (s *MockService) QueryAllPets(ctx context.Context, in *productv1.QueryAllPe // Implementation for CreateUser mutation func (s *MockService) MutationCreateUser(ctx context.Context, in *productv1.MutationCreateUserRequest) (*productv1.MutationCreateUserResponse, error) { input := in.GetInput() + name := input.GetName() + + // Allow header to override the name + md, ok := metadata.FromIncomingContext(ctx) + if ok { + if values := md.Get("x-custom-name"); len(values) > 0 { + name = values[0] + } + } // Create a new user with the input name and a random ID user := &productv1.User{ Id: fmt.Sprintf("user-%d", rand.Intn(1000)), - Name: input.GetName(), + Name: name, } return &productv1.MutationCreateUserResponse{ diff --git a/v2/pkg/grpctest/mockservice_resolve.go b/v2/pkg/grpctest/mockservice_resolve.go index 44e0aa92aa..fea44506af 100644 --- a/v2/pkg/grpctest/mockservice_resolve.go +++ b/v2/pkg/grpctest/mockservice_resolve.go @@ -3,7 +3,9 @@ package grpctest import ( "context" "fmt" + "strconv" + "google.golang.org/grpc/metadata" "google.golang.org/protobuf/types/known/wrapperspb" "github.com/wundergraph/graphql-go-tools/v2/pkg/grpctest/productv1" @@ -828,11 +830,25 @@ func (s *MockService) ResolveSubcategoryItemCount(_ context.Context, req *produc } // ResolveCategoryProductCount implements productv1.ProductServiceServer. -func (s *MockService) ResolveCategoryProductCount(_ context.Context, req *productv1.ResolveCategoryProductCountRequest) (*productv1.ResolveCategoryProductCountResponse, error) { +func (s *MockService) ResolveCategoryProductCount(ctx context.Context, req *productv1.ResolveCategoryProductCountRequest) (*productv1.ResolveCategoryProductCountResponse, error) { results := make([]*productv1.ResolveCategoryProductCountResult, 0, len(req.GetContext())) + + // Default offset is 0 + offset := int32(0) + + // Allow header to override the offset + md, ok := metadata.FromIncomingContext(ctx) + if ok { + if values := md.Get("x-count-offset"); len(values) > 0 { + if v, err := strconv.Atoi(values[0]); err == nil { + offset = int32(v) + } + } + } + for i := range req.GetContext() { results = append(results, &productv1.ResolveCategoryProductCountResult{ - ProductCount: int32(i), + ProductCount: int32(i) + offset, }) }