Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ import (
"encoding/binary"
"fmt"
"net/http"
"strings"

"github.com/cespare/xxhash/v2"
"github.com/tidwall/gjson"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"

"github.com/wundergraph/astjson"
"github.com/wundergraph/go-arena"
Expand Down Expand Up @@ -90,6 +92,8 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D
// It processes the input JSON data to make gRPC calls and returns
// the response data.
//
// Headers are converted to gRPC metadata and part of gRPC calls.
//
// The input is expected to contain the necessary information to make
// a gRPC call, including service name, method name, and request data.
func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) {
Expand All @@ -111,6 +115,19 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte
return builder.writeErrorBytes(fmt.Errorf("gRPC datasource needs to be enabled to be used")), nil
}

// convert headers to grpc metadata and attach to ctx
if len(headers) > 0 {
// assume that each header has exactly one value for default pairs size
Comment thread
jensneuse marked this conversation as resolved.
pairs := make([]string, 0, len(headers)*2)
for headerName, headerValues := range headers {
headerName = strings.ToLower(headerName)
for _, v := range headerValues {
pairs = append(pairs, headerName, v)
}
}
ctx = metadata.AppendToOutgoingContext(ctx, pairs...)
}

graph := NewDependencyGraph(d.plan)

root := astjson.ObjectValue(nil)
Expand Down
275 changes: 275 additions & 0 deletions v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ import (
"fmt"
"math"
"net"
"net/http"
"strings"
"testing"

"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/test/bufconn"
"google.golang.org/protobuf/encoding/protojson"
protoref "google.golang.org/protobuf/reflect/protoreflect"
Expand Down Expand Up @@ -5237,3 +5239,276 @@ func Test_Datasource_Load_WithFieldResolvers(t *testing.T) {
})
}
}

func Test_Datasource_Load_WithHeaders(t *testing.T) {
conn, cleanup := setupTestGRPCServer(t)
t.Cleanup(cleanup)

type graphqlError struct {
Message string `json:"message"`
}
type graphqlResponse struct {
Data map[string]interface{} `json:"data"`
Errors []graphqlError `json:"errors,omitempty"`
}

testCases := []struct {
name string
query string
vars string
headers http.Header
validate func(t *testing.T, data map[string]interface{})
validateError func(t *testing.T, errData []graphqlError)
}{
{
name: "QueryUser with header override",
query: `query UserQuery($id: ID!) { user(id: $id) { id name } }`,
vars: `{"variables":{"id":"original-user-123"}}`,
headers: func() http.Header {
h := make(http.Header)
h.Set("X-User-ID", "header-user-42")
return h
}(),
validate: func(t *testing.T, data map[string]interface{}) {
user, ok := data["user"].(map[string]interface{})
require.True(t, ok, "user should be an object")
require.Equal(t, "header-user-42", user["id"], "user ID should come from header")
require.Equal(t, "User header-user-42", user["name"], "user name should use header-derived ID")
},
validateError: func(t *testing.T, errData []graphqlError) {
require.Empty(t, errData)
},
},
{
name: "QueryUser with header triggering error",
query: `query UserQuery($id: ID!) { user(id: $id) { id name } }`,
vars: `{"variables":{"id":"valid-user-123"}}`,
headers: func() http.Header {
h := make(http.Header)
h.Set("X-User-ID", "error-user")
return h
}(),
validate: func(t *testing.T, data map[string]interface{}) {
// Data might be present but should have errors
},
validateError: func(t *testing.T, errData []graphqlError) {
require.NotEmpty(t, errData, "should have errors")
require.Contains(t, errData[0].Message, "user not found: error-user")
},
},
{
name: "QueryUser without headers (nil) - baseline behavior",
query: `query UserQuery($id: ID!) { user(id: $id) { id name } }`,
vars: `{"variables":{"id":"baseline-user-99"}}`,
headers: nil,
validate: func(t *testing.T, data map[string]interface{}) {
user, ok := data["user"].(map[string]interface{})
require.True(t, ok, "user should be an object")
require.Equal(t, "baseline-user-99", user["id"], "user ID should come from query variable")
require.Equal(t, "User baseline-user-99", user["name"], "user name should use variable-derived ID")
},
validateError: func(t *testing.T, errData []graphqlError) {
require.Empty(t, errData)
},
},
{
name: "QueryUsers with custom prefix header",
query: `query UsersQuery { users { id name } }`,
vars: `{"variables":{}}`,
headers: func() http.Header {
h := make(http.Header)
h.Set("X-User-Prefix", "Admin")
return h
}(),
validate: func(t *testing.T, data map[string]interface{}) {
users, ok := data["users"].([]interface{})
require.True(t, ok, "users should be an array")
require.Len(t, users, 3, "should return 3 users")

for i, u := range users {
user, ok := u.(map[string]interface{})
require.True(t, ok, "each user should be an object")
require.Equal(t, fmt.Sprintf("user-%d", i+1), user["id"])
require.Equal(t, fmt.Sprintf("Admin %d", i+1), user["name"], "user name should use custom prefix from header")
}
},
validateError: func(t *testing.T, errData []graphqlError) {
require.Empty(t, errData)
},
},
{
name: "MutationCreateUser with name override header",
query: `mutation CreateUser($input: UserInput!) { createUser(input: $input) { id name } }`,
vars: `{"variables":{"input":{"name":"OriginalName"}}}`,
headers: func() http.Header {
h := make(http.Header)
h.Set("X-Custom-Name", "HeaderName")
return h
}(),
validate: func(t *testing.T, data map[string]interface{}) {
createUser, ok := data["createUser"].(map[string]interface{})
require.True(t, ok, "createUser should be an object")
require.NotEmpty(t, createUser["id"], "created user should have an ID")
require.Equal(t, "HeaderName", createUser["name"], "created user name should come from header")
},
validateError: func(t *testing.T, errData []graphqlError) {
require.Empty(t, errData)
},
},
{
name: "Categories with productCount field resolver and header offset",
query: `query CategoriesWithProductCount($filters: ProductCountFilter) { categories { id name kind productCount(filters: $filters) } }`,
vars: `{"variables":{"filters":{"minPrice":100}}}`,
headers: func() http.Header {
h := make(http.Header)
h.Set("X-Count-Offset", "100")
return h
}(),
validate: func(t *testing.T, data map[string]interface{}) {
categories, ok := data["categories"].([]interface{})
require.True(t, ok, "categories should be an array")
require.Len(t, categories, 4, "should return 4 categories")

// Verify that productCount for each category is offset by 100
expectedCounts := []float64{100, 101, 102, 103}
for i, c := range categories {
category, ok := c.(map[string]interface{})
require.True(t, ok, "category should be an object")
require.NotEmpty(t, category["id"])
require.NotEmpty(t, category["name"])
require.Equal(t, expectedCounts[i], category["productCount"], "productCount should be offset by header value")
}
},
validateError: func(t *testing.T, errData []graphqlError) {
require.Empty(t, errData)
},
},
{
name: "Categories with productCount without headers - baseline behavior",
query: `query CategoriesWithProductCount($filters: ProductCountFilter) { categories { id name kind productCount(filters: $filters) } }`,
vars: `{"variables":{"filters":{"minPrice":100}}}`,
headers: nil,
validate: func(t *testing.T, data map[string]interface{}) {
categories, ok := data["categories"].([]interface{})
require.True(t, ok, "categories should be an array")
require.Len(t, categories, 4, "should return 4 categories")

// Verify default productCount values (no offset)
expectedCounts := []float64{0, 1, 2, 3}
for i, c := range categories {
category, ok := c.(map[string]interface{})
require.True(t, ok, "category should be an object")
require.NotEmpty(t, category["id"])
require.NotEmpty(t, category["name"])
require.Equal(t, expectedCounts[i], category["productCount"], "productCount should use default values without header")
}
},
validateError: func(t *testing.T, errData []graphqlError) {
require.Empty(t, errData)
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Parse the GraphQL schema
schemaDoc := grpctest.MustGraphQLSchema(t)

// Parse the GraphQL query
queryDoc, report := astparser.ParseGraphqlDocumentString(tc.query)
require.False(t, report.HasErrors(), "failed to parse query: %s", report.Error())

compiler, err := NewProtoCompiler(grpctest.MustProtoSchema(t), testMapping())
require.NoError(t, err)

// Create the datasource
ds, err := NewDataSource(conn, DataSourceConfig{
Operation: &queryDoc,
Definition: &schemaDoc,
SubgraphName: "Products",
Mapping: testMapping(),
Compiler: compiler,
})
require.NoError(t, err)

// Execute the query with headers
input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars)
output, err := ds.Load(context.Background(), tc.headers, []byte(input))
require.NoError(t, err)

// Parse the response
var resp graphqlResponse
err = json.Unmarshal(output, &resp)
require.NoError(t, err, "Failed to unmarshal response")

tc.validate(t, resp.Data)
tc.validateError(t, resp.Errors)
})
}
}

func Test_Datasource_Load_PreservesExistingContextMetadata(t *testing.T) {
conn, cleanup := setupTestGRPCServer(t)
t.Cleanup(cleanup)

type graphqlError struct {
Message string `json:"message"`
}
type graphqlResponse struct {
Data map[string]interface{} `json:"data"`
Errors []graphqlError `json:"errors,omitempty"`
}

// Parse the GraphQL schema
schemaDoc := grpctest.MustGraphQLSchema(t)

query := `query UserQuery($id: ID!) { user(id: $id) { id name } }`
vars := `{"variables":{"id":"test-user-123"}}`

// Parse the GraphQL query
queryDoc, report := astparser.ParseGraphqlDocumentString(query)
require.False(t, report.HasErrors(), "failed to parse query: %s", report.Error())

compiler, err := NewProtoCompiler(grpctest.MustProtoSchema(t), testMapping())
require.NoError(t, err)

// Create the datasource
ds, err := NewDataSource(conn, DataSourceConfig{
Operation: &queryDoc,
Definition: &schemaDoc,
SubgraphName: "Products",
Mapping: testMapping(),
Compiler: compiler,
})
require.NoError(t, err)

// Create a context with existing metadata
ctx := metadata.NewOutgoingContext(
context.Background(),
metadata.Pairs("x-existing-key", "existing-value"),
)

// Create HTTP headers to be forwarded
headers := make(http.Header)
headers.Set("X-User-ID", "header-user-456")

// Execute the query with both existing context metadata and new HTTP headers
input := fmt.Sprintf(`{"query":%q,"body":%s}`, query, vars)
output, err := ds.Load(ctx, headers, []byte(input))
require.NoError(t, err)

// Parse the response
var resp graphqlResponse
err = json.Unmarshal(output, &resp)
require.NoError(t, err, "Failed to unmarshal response")

// Verify no errors
require.Empty(t, resp.Errors, "Should not have GraphQL errors")

// Verify the response includes both the header-derived ID and the existing metadata value
user, ok := resp.Data["user"].(map[string]interface{})
require.True(t, ok, "user should be an object")
require.Equal(t, "header-user-456", user["id"], "user ID should come from HTTP header")
require.Equal(t, "User header-user-456 (existing: existing-value)", user["name"],
"user name should include both header-derived ID and existing context metadata")
}
Loading