From 5f81aac3d4c5bc0e1ef114c3f288caf177e99d20 Mon Sep 17 00:00:00 2001 From: Jatin Dev <64803093+JatinDevDG@users.noreply.github.com> Date: Mon, 20 Jul 2020 21:05:00 +0530 Subject: [PATCH] feat(GraphQL): Adds auth for subscriptions (#5984) * removed log info * added missing functions in auth.go * resolved merge issues * added authvariable code for subscriptions * removed a print statement * removed authvariable function and changed it's occurences * resolved review comments * replaced gophers library to our own * added test * Fix build error by using latest version of dgraph-io/graphql-transport-ws * Fix the audience unit tests * Fix formatting * fix for subscription withou auth * fixed subscription error * added auth test * resolved pawan's comments. * latest code. * Delete p,w and zw from checked in code * Use jwt.At function so that time.IsZero returns correctly * modified test * Complete the test * Add another test cases to check subscription ends after JWT expires * added test for multiple subscriptions * added test for multiple subscriptions with different jw data * clean code * modified subscriptions test * Fix the test * addresses review comments. * edited few comments * Remove subscriptionID from subscriber * Run go mod tidy to fix the go.mod and go.sum files * Clean up some code * Revert "Clean up some code" This reverts commit 3922649e3f2c224100186acfd1a6f4bfc0e8c4d9. * Fix failures in custom logic test * fixed deepsource error * fixed auth test Co-authored-by: Pawan Rawal Co-authored-by: JatinDevDG --- go.mod | 5 +- go.sum | 8 +- graphql/authorization/auth.go | 39 +- graphql/e2e/auth/auth_test.go | 3 +- graphql/e2e/common/query.go | 10 +- graphql/e2e/common/subscription.go | 5 +- graphql/e2e/custom_logic/custom_logic_test.go | 22 +- graphql/e2e/subscription/subscription_test.go | 635 +++++++++++++++++- graphql/resolve/auth_test.go | 8 +- graphql/resolve/mutation.go | 3 +- graphql/resolve/mutation_rewriter.go | 26 +- graphql/resolve/query_rewriter.go | 11 +- graphql/subscription/poller.go | 93 ++- graphql/web/http.go | 49 +- testutil/graphql.go | 7 +- 15 files changed, 825 insertions(+), 99 deletions(-) diff --git a/go.mod b/go.mod index bc259071435..d5b7fee9e35 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd github.com/dgraph-io/badger/v2 v2.0.1-rc1.0.20200715005050-3ffaf3cd1d4a github.com/dgraph-io/dgo/v200 v200.0.0-20200401175452-e463f9234453 + github.com/dgraph-io/graphql-transport-ws v0.0.0-20200715131837-c0460019ead2 github.com/dgraph-io/ristretto v0.0.3-0.20200630154024-f66de99634de github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1 @@ -36,7 +37,7 @@ require ( github.com/google/uuid v1.0.0 github.com/gorilla/websocket v1.4.1 github.com/graph-gophers/graphql-go v0.0.0-20200309224638-dae41bde9ef9 - github.com/graph-gophers/graphql-transport-ws v0.0.0-20190611222414-40c048432299 + github.com/graph-gophers/graphql-transport-ws v0.0.0-20190611222414-40c048432299 // indirect github.com/hashicorp/vault/api v1.0.4 github.com/minio/minio-go/v6 v6.0.55 github.com/mitchellh/panicwrap v1.0.0 @@ -62,7 +63,7 @@ require ( golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e golang.org/x/sync v0.0.0-20190423024810-112230192c58 - golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f + golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9 golang.org/x/text v0.3.2 google.golang.org/genproto v0.0.0-20190516172635-bb713bdc0e52 // indirect google.golang.org/grpc v1.23.0 diff --git a/go.sum b/go.sum index 3dab6944453..90db9b31f61 100644 --- a/go.sum +++ b/go.sum @@ -78,12 +78,12 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgraph-io/badger v1.6.0 h1:DshxFxZWXUcO0xX476VJC07Xsr6ZCBVRHKZ93Oh7Evo= github.com/dgraph-io/badger v1.6.0/go.mod h1:zwt7syl517jmP8s94KqSxTlM6IMsdhYy6psNgSztDR4= -github.com/dgraph-io/badger/v2 v2.0.1-rc1.0.20200711090415-0dfb8b45d4a4 h1:M85aROJxDOXmp8mZQN6x04EraKDGOF2RL3ox/6LZdjo= -github.com/dgraph-io/badger/v2 v2.0.1-rc1.0.20200711090415-0dfb8b45d4a4/go.mod h1:26P/7fbL4kUZVEVKLAKXkBXKOydDmM2p1e+NhhnBCAE= github.com/dgraph-io/badger/v2 v2.0.1-rc1.0.20200715005050-3ffaf3cd1d4a h1:k8A4B5IEFzH33Px1N7K0FpMszEFTZGQb7HCRBslMB+s= github.com/dgraph-io/badger/v2 v2.0.1-rc1.0.20200715005050-3ffaf3cd1d4a/go.mod h1:26P/7fbL4kUZVEVKLAKXkBXKOydDmM2p1e+NhhnBCAE= github.com/dgraph-io/dgo/v200 v200.0.0-20200401175452-e463f9234453 h1:DTgOrw91nMIukDm/WEvdobPLl0LgeDd/JE66+24jBks= github.com/dgraph-io/dgo/v200 v200.0.0-20200401175452-e463f9234453/go.mod h1:Co+FwJrnndSrPORO8Gdn20dR7FPTfmXr0W/su0Ve/Ig= +github.com/dgraph-io/graphql-transport-ws v0.0.0-20200715131837-c0460019ead2 h1:NSl3XXyON9bgmBJSAvr5FPrgILAovtoTs7FwdtaZZq0= +github.com/dgraph-io/graphql-transport-ws v0.0.0-20200715131837-c0460019ead2/go.mod h1:7z3c/5w0sMYYZF5bHsrh8IH4fKwG5O5Y70cPH1ZLLRQ= github.com/dgraph-io/ristretto v0.0.3-0.20200630154024-f66de99634de h1:t0UHb5vdojIDUqktM6+xJAfScFBsVpXZmqC9dsgJmeA= github.com/dgraph-io/ristretto v0.0.3-0.20200630154024-f66de99634de/go.mod h1:KPxhHT9ZxKefz+PCeOGsrHpl1qZ7i70dGTu2u+Ahh6E= github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= @@ -507,8 +507,8 @@ golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f h1:gWF768j/LaZugp8dyS4UwsslYCYz9XgFxvlgsn0n9H8= -golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9 h1:YTzHMGlqJu67/uEo1lBv0n3wBXhXNeUbB1XfN2vmTm0= +golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20181227161524-e6919f6577db/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= diff --git a/graphql/authorization/auth.go b/graphql/authorization/auth.go index 06a69dc460b..ea7a3fed2a0 100644 --- a/graphql/authorization/auth.go +++ b/graphql/authorization/auth.go @@ -35,9 +35,11 @@ import ( ) type ctxKey string +type authVariablekey string const ( AuthJwtCtxKey = ctxKey("authorizationJwt") + AuthVariables = authVariablekey("authVariable") RSA256 = "RS256" HMAC256 = "HS256" AuthMetaHeader = "# Dgraph.Authorization " @@ -207,22 +209,6 @@ func (c *CustomClaims) UnmarshalJSON(data []byte) error { return nil } -func ExtractAuthVariables(ctx context.Context) (map[string]interface{}, error) { - // Extract the jwt and unmarshal the jwt to get the auth variables. - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - return nil, nil - } - - jwtToken := md.Get(string(AuthJwtCtxKey)) - if len(jwtToken) == 0 { - return nil, nil - } else if len(jwtToken) > 1 { - return nil, fmt.Errorf("invalid jwt auth token") - } - return validateToken(jwtToken[0]) -} - func (c *CustomClaims) validateAudience() error { // If there's no audience claim, ignore if c.Audience == nil || len(c.Audience) == 0 { @@ -249,7 +235,23 @@ func (c *CustomClaims) validateAudience() error { return nil } -func validateToken(jwtStr string) (map[string]interface{}, error) { +func ExtractCustomClaims(ctx context.Context) (*CustomClaims, error) { + // return CustomClaims containing jwt and authvariables. + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return &CustomClaims{}, nil + } + + jwtToken := md.Get(string(AuthJwtCtxKey)) + if len(jwtToken) == 0 { + return &CustomClaims{}, nil + } else if len(jwtToken) > 1 { + return nil, fmt.Errorf("invalid jwt auth token") + } + return validateJWTCustomClaims(jwtToken[0]) +} + +func validateJWTCustomClaims(jwtStr string) (*CustomClaims, error) { if metainfo.Algo == "" { return nil, fmt.Errorf( "jwt token cannot be validated because verification algorithm is not set") @@ -289,6 +291,5 @@ func validateToken(jwtStr string) (map[string]interface{}, error) { if err := claims.validateAudience(); err != nil { return nil, err } - - return claims.AuthVariables, nil + return claims, nil } diff --git a/graphql/e2e/auth/auth_test.go b/graphql/e2e/auth/auth_test.go index 4b4815166f1..cc21dcde262 100644 --- a/graphql/e2e/auth/auth_test.go +++ b/graphql/e2e/auth/auth_test.go @@ -25,6 +25,7 @@ import ( "os" "strings" "testing" + "time" "github.com/dgraph-io/dgraph/graphql/authorization" "github.com/dgraph-io/dgraph/graphql/e2e/common" @@ -177,7 +178,7 @@ func getJWT(t *testing.T, user, role string) http.Header { metaInfo.AuthVars["ROLE"] = role } - jwtToken, err := metaInfo.GetSignedToken("./sample_private_key.pem") + jwtToken, err := metaInfo.GetSignedToken("./sample_private_key.pem", 300*time.Second) require.NoError(t, err) h := make(http.Header) diff --git a/graphql/e2e/common/query.go b/graphql/e2e/common/query.go index bef7a84777a..19d625f2ebf 100644 --- a/graphql/e2e/common/query.go +++ b/graphql/e2e/common/query.go @@ -1284,9 +1284,9 @@ func querynestedOnlyTypename(t *testing.T) { }, { "__typename": "Post" - }, + }, { - + "__typename": "Post" } ] @@ -1311,8 +1311,8 @@ func onlytypenameForInterface(t *testing.T) { eq: [EMPIRE] } }) { - - + + ... on Human { __typename } @@ -1326,7 +1326,7 @@ func onlytypenameForInterface(t *testing.T) { expected := `{ "queryCharacter": [ { - "__typename": "Human" + "__typename": "Human" }, { "__typename": "Droid" diff --git a/graphql/e2e/common/subscription.go b/graphql/e2e/common/subscription.go index 6434fcbd349..c95ca5a8dbd 100644 --- a/graphql/e2e/common/subscription.go +++ b/graphql/e2e/common/subscription.go @@ -58,7 +58,7 @@ type GraphQLSubscriptionClient struct { } // NewGraphQLSubscription returns graphql subscription client. -func NewGraphQLSubscription(url string, req *schema.Request) (*GraphQLSubscriptionClient, error) { +func NewGraphQLSubscription(url string, req *schema.Request, subscriptionPayload string) (*GraphQLSubscriptionClient, error) { header := http.Header{ "Sec-WebSocket-Protocol": []string{protocolGraphQLWS}, } @@ -66,11 +66,10 @@ func NewGraphQLSubscription(url string, req *schema.Request) (*GraphQLSubscripti if err != nil { return nil, err } - // Initialize subscription. init := operationMessage{ Type: initMsg, - Payload: []byte(`{}`), + Payload: []byte(subscriptionPayload), } // Send Intialization message to the graphql server. diff --git a/graphql/e2e/custom_logic/custom_logic_test.go b/graphql/e2e/custom_logic/custom_logic_test.go index 1bd49a78b37..0e2c855681d 100644 --- a/graphql/e2e/custom_logic/custom_logic_test.go +++ b/graphql/e2e/custom_logic/custom_logic_test.go @@ -168,7 +168,7 @@ func TestCustomQueryShouldForwardHeaders(t *testing.T) { secretHeaders: ["Github-Api-Token"] }) } - + # Dgraph.Secret Github-Api-Token "random-fake-token" # Dgraph.Secret app "should-be-overriden" ` @@ -247,7 +247,7 @@ func TestSchemaIntrospectionForCustomQueryShouldForwardHeaders(t *testing.T) { code: String! name: String! } - + type Query { myCustom(yo: CountryInput!): [Country!]! @custom( @@ -331,7 +331,7 @@ func TestCustomFieldsInSubscription(t *testing.T) { name } }`, - }) + }, `{}`) require.NoError(t, err) _, err = client.RecvMsg() require.Contains(t, err.Error(), "Custom field `name` is not supported in graphql subscription") @@ -370,7 +370,7 @@ func TestSubscriptionInNestedCustomField(t *testing.T) { } } }`, - }) + }, `{}`) require.NoError(t, err) _, err = client.RecvMsg() require.Contains(t, err.Error(), "Custom field `anotherName` is not supported in graphql subscription") @@ -1257,7 +1257,7 @@ func TestCustomLogicGraphql(t *testing.T) { code: String name: String } - + type Query { getCountry1(id: ID!): Country! @custom( @@ -1292,7 +1292,7 @@ func TestCustomLogicGraphqlWithArgumentsOnFields(t *testing.T) { code(size: Int!): String name: String } - + type Query { getCountry2(id: ID!): Country! @custom( @@ -1378,7 +1378,7 @@ func TestCustomLogicGraphQLValidArrayResponse(t *testing.T) { code: String name: String } - + type Query { getCountries(id: ID!): [Country] @custom( @@ -2071,7 +2071,7 @@ func TestCustomGraphqlMissingRequiredArgument(t *testing.T) { code: String! name: String! } - + type Mutation { addCountry1(input: CountryInput!): Country! @custom(http: { url: "http://mock:8888/setCountry", @@ -2134,7 +2134,7 @@ func TestCustomGraphqlMutation1(t *testing.T) { code: String! name: String! } - + type Mutation { addCountry1(input: CountryInput!): Country! @custom(http: { url: "http://mock:8888/setCountry" @@ -2224,7 +2224,7 @@ func TestCustomGraphqlMutation2(t *testing.T) { code: String! name: String! } - + type Mutation { updateCountries(name: String, std: Int): [Country!]! @custom(http: { url: "http://mock:8888/updateCountries", @@ -2279,7 +2279,7 @@ func TestForValidInputArgument(t *testing.T) { code: String! name: String! } - + type Query { myCustom(yo: CountryInput!): [Country!]! @custom( diff --git a/graphql/e2e/subscription/subscription_test.go b/graphql/e2e/subscription/subscription_test.go index 5a05331b426..6fc60592e1e 100644 --- a/graphql/e2e/subscription/subscription_test.go +++ b/graphql/e2e/subscription/subscription_test.go @@ -18,11 +18,13 @@ package subscription_test import ( "encoding/json" + "fmt" "testing" "time" "github.com/dgraph-io/dgraph/graphql/e2e/common" "github.com/dgraph-io/dgraph/graphql/schema" + "github.com/dgraph-io/dgraph/testutil" "github.com/stretchr/testify/require" ) @@ -30,6 +32,7 @@ const ( graphQLEndpoint = "http://localhost:8180/graphql" subscriptionEndpoint = "ws://localhost:8180/graphql" adminEndpoint = "http://localhost:8180/admin" + groupOnegRPC = "localhost:9180" sch = ` type Product @withSubscription { productID: ID! @@ -50,9 +53,30 @@ const ( rating: Int @search } ` + schAuth = ` + type Todo @withSubscription @auth( + query: { rule: """ + query ($USER: String!) { + queryTodo(filter: { owner: { eq: $USER } } ) { + __typename + } + }""" + } + ){ + id: ID! + text: String! @search(by: [term]) + owner: String! @search(by: [hash]) + } +# Dgraph.Authorization {"VerificationKey":"secret","Header":"Authorization","Namespace":"https://dgraph.io","Algo":"HS256"} +` ) func TestSubscription(t *testing.T) { + + dg, err := testutil.DgraphClient(groupOnegRPC) + require.NoError(t, err) + testutil.DropAll(t, dg) + add := &common.GraphQLParams{ Query: `mutation updateGQLSchema($sch: String!) { updateGQLSchema(input: { set: { schema: $sch }}) { @@ -87,7 +111,7 @@ func TestSubscription(t *testing.T) { name } }`, - }) + }, `{}`) require.Nil(t, err) res, err := subscriptionClient.RecvMsg() @@ -155,3 +179,612 @@ func TestSubscription(t *testing.T) { require.Nil(t, res) } + +func TestSubscriptionAuth(t *testing.T) { + dg, err := testutil.DgraphClient(groupOnegRPC) + require.NoError(t, err) + testutil.DropAll(t, dg) + + add := &common.GraphQLParams{ + Query: `mutation updateGQLSchema($sch: String!) { + updateGQLSchema(input: { set: { schema: $sch }}) { + gqlSchema { + schema + } + } + }`, + Variables: map[string]interface{}{"sch": schAuth}, + } + addResult := add.ExecuteAsPost(t, adminEndpoint) + require.Nil(t, addResult.Errors) + time.Sleep(time.Second * 2) + + metaInfo := &testutil.AuthMeta{ + PublicKey: "secret", + Namespace: "https://dgraph.io", + Algo: "HS256", + Header: "Authorization", + } + metaInfo.AuthVars = map[string]interface{}{ + "USER": "jatin", + "ROLE": "USER", + } + + add = &common.GraphQLParams{ + Query: `mutation{ + addTodo(input: [ + {text : "GraphQL is exciting!!", + owner : "jatin"} + ]) + { + todo{ + text + owner + } + } + }`, + } + + addResult = add.ExecuteAsPost(t, graphQLEndpoint) + require.Nil(t, addResult.Errors) + + jwtToken, err := metaInfo.GetSignedToken("secret", 100*time.Second) + require.NoError(t, err) + + payload := fmt.Sprintf(`{"Authorization": "%s"}`, jwtToken) + subscriptionClient, err := common.NewGraphQLSubscription(subscriptionEndpoint, &schema.Request{ + Query: `subscription{ + queryTodo{ + owner + text + } + }`, + }, payload) + require.Nil(t, err) + + res, err := subscriptionClient.RecvMsg() + require.NoError(t, err) + + var resp common.GraphQLResponse + err = json.Unmarshal(res, &resp) + require.NoError(t, err) + + require.Nil(t, resp.Errors) + require.JSONEq(t, `{"queryTodo":[{"owner":"jatin","text":"GraphQL is exciting!!"}]}`, + string(resp.Data)) + + // Add a TODO for alice which should not be visible in the update because JWT belongs to + // Jatin + add = &common.GraphQLParams{ + Query: `mutation{ + addTodo(input: [ + {text : "Dgraph is awesome!!", + owner : "alice"} + ]) + { + todo { + text + owner + } + } + }`, + } + addResult = add.ExecuteAsPost(t, graphQLEndpoint) + require.Nil(t, addResult.Errors) + + // Add another TODO for jatin which we should get in the latest update. + add = &common.GraphQLParams{ + Query: `mutation{ + addTodo(input: [ + {text : "Dgraph is awesome!!", + owner : "jatin"} + ]) + { + todo { + text + owner + } + } + }`, + } + + addResult = add.ExecuteAsPost(t, graphQLEndpoint) + require.Nil(t, addResult.Errors) + res, err = subscriptionClient.RecvMsg() + require.NoError(t, err) + + err = json.Unmarshal(res, &resp) + require.NoError(t, err) + require.Nil(t, resp.Errors) + require.JSONEq(t, `{"queryTodo": [ + { + "owner": "jatin", + "text": "GraphQL is exciting!!" + }, + { + "owner" : "jatin", + "text" : "Dgraph is awesome!!" + }]}`, string(resp.Data)) + + // Change schema to terminate subscription.. + add = &common.GraphQLParams{ + Query: `mutation updateGQLSchema($sch: String!) { + updateGQLSchema(input: { set: { schema: $sch }}) { + gqlSchema { + schema + } + } + }`, + Variables: map[string]interface{}{"sch": sch}, + } + addResult = add.ExecuteAsPost(t, adminEndpoint) + require.Nil(t, addResult.Errors) + + res, err = subscriptionClient.RecvMsg() + require.NoError(t, err) + require.Nil(t, res) +} + +func TestSubscriptionWithAuthShouldExpireWithJWT(t *testing.T) { + dg, err := testutil.DgraphClient(groupOnegRPC) + require.NoError(t, err) + testutil.DropAll(t, dg) + + add := &common.GraphQLParams{ + Query: `mutation updateGQLSchema($sch: String!) { + updateGQLSchema(input: { set: { schema: $sch }}) { + gqlSchema { + schema + } + } + }`, + Variables: map[string]interface{}{"sch": schAuth}, + } + addResult := add.ExecuteAsPost(t, adminEndpoint) + require.Nil(t, addResult.Errors) + time.Sleep(time.Second * 2) + + metaInfo := &testutil.AuthMeta{ + PublicKey: "secret", + Namespace: "https://dgraph.io", + Algo: "HS256", + Header: "Authorization", + } + metaInfo.AuthVars = map[string]interface{}{ + "USER": "bob", + "ROLE": "USER", + } + + add = &common.GraphQLParams{ + Query: `mutation{ + addTodo(input: [ + {text : "GraphQL is exciting!!", + owner : "bob"} + ]) + { + todo{ + text + owner + } + } + }`, + } + + addResult = add.ExecuteAsPost(t, graphQLEndpoint) + require.Nil(t, addResult.Errors) + + jwtToken, err := metaInfo.GetSignedToken("secret", 10*time.Second) + require.NoError(t, err) + + payload := fmt.Sprintf(`{"Authorization": "%s"}`, jwtToken) + subscriptionClient, err := common.NewGraphQLSubscription(subscriptionEndpoint, + &schema.Request{ + Query: `subscription{ + queryTodo{ + owner + text + } + }`, + }, payload) + require.Nil(t, err) + + res, err := subscriptionClient.RecvMsg() + require.NoError(t, err) + + var resp common.GraphQLResponse + err = json.Unmarshal(res, &resp) + require.NoError(t, err) + + require.Nil(t, resp.Errors) + require.JSONEq(t, `{"queryTodo":[{"owner":"bob","text":"GraphQL is exciting!!"}]}`, + string(resp.Data)) + + // Wait for JWT to expire. + time.Sleep(10 * time.Second) + + // Add another TODO for bob but this should not be visible as the subscription should have + // ended. + add = &common.GraphQLParams{ + Query: `mutation{ + addTodo(input: [ + {text : "Dgraph is exciting!!", + owner : "bob"} + ]) + { + todo{ + text + owner + } + } + }`, + } + + addResult = add.ExecuteAsPost(t, graphQLEndpoint) + require.Nil(t, addResult.Errors) + + res, err = subscriptionClient.RecvMsg() + require.NoError(t, err) + require.Nil(t, res) +} + +func TestSubscriptionAuth_SameQueryAndClaimsButDifferentExpiry_ShouldExpireIndependently(t *testing.T) { + dg, err := testutil.DgraphClient(groupOnegRPC) + require.NoError(t, err) + testutil.DropAll(t, dg) + + add := &common.GraphQLParams{ + Query: `mutation updateGQLSchema($sch: String!) { + updateGQLSchema(input: { set: { schema: $sch }}) { + gqlSchema { + schema + } + } + }`, + Variables: map[string]interface{}{"sch": schAuth}, + } + addResult := add.ExecuteAsPost(t, adminEndpoint) + require.Nil(t, addResult.Errors) + time.Sleep(time.Second * 2) + + metaInfo := &testutil.AuthMeta{ + PublicKey: "secret", + Namespace: "https://dgraph.io", + Algo: "HS256", + Header: "Authorization", + } + metaInfo.AuthVars = map[string]interface{}{ + "USER": "jatin", + "ROLE": "USER", + } + + add = &common.GraphQLParams{ + Query: `mutation{ + addTodo(input: [ + {text : "GraphQL is exciting!!", + owner : "jatin"} + ]) + { + todo{ + text + owner + } + } + }`, + } + + addResult = add.ExecuteAsPost(t, graphQLEndpoint) + require.Nil(t, addResult.Errors) + + jwtToken, err := metaInfo.GetSignedToken("secret", 10*time.Second) + require.NoError(t, err) + + // first subscription + payload := fmt.Sprintf(`{"Authorization": "%s"}`, jwtToken) + subscriptionClient, err := common.NewGraphQLSubscription(subscriptionEndpoint, &schema.Request{ + Query: `subscription{ + queryTodo{ + owner + text + } + }`, + }, payload) + require.Nil(t, err) + + res, err := subscriptionClient.RecvMsg() + require.NoError(t, err) + + var resp common.GraphQLResponse + err = json.Unmarshal(res, &resp) + require.NoError(t, err) + require.Nil(t, resp.Errors) + require.JSONEq(t, `{"queryTodo":[{"owner":"jatin","text":"GraphQL is exciting!!"}]}`, + string(resp.Data)) + + // 2nd subscription + jwtToken, err = metaInfo.GetSignedToken("secret", 20*time.Second) + require.NoError(t, err) + payload = fmt.Sprintf(`{"Authorization": "%s"}`, jwtToken) + subscriptionClient1, err := common.NewGraphQLSubscription(subscriptionEndpoint, &schema.Request{ + Query: `subscription{ + queryTodo{ + owner + text + } + }`, + }, payload) + require.Nil(t, err) + + res, err = subscriptionClient1.RecvMsg() + require.NoError(t, err) + + err = json.Unmarshal(res, &resp) + require.NoError(t, err) + + require.Nil(t, resp.Errors) + require.JSONEq(t, `{"queryTodo":[{"owner":"jatin","text":"GraphQL is exciting!!"}]}`, + string(resp.Data)) + + // Wait for JWT to expire for first subscription. + time.Sleep(10 * time.Second) + + // Add another TODO for jatin for which 1st subscription shouldn't get updates. + add = &common.GraphQLParams{ + Query: `mutation{ + addTodo(input: [ + {text : "Dgraph is awesome!!", + owner : "jatin"} + ]) + { + todo { + text + owner + } + } + }`, + } + addResult = add.ExecuteAsPost(t, graphQLEndpoint) + require.Nil(t, addResult.Errors) + + res, err = subscriptionClient.RecvMsg() + require.NoError(t, err) + require.Nil(t, res) // 1st subscription should get the empty response as subscription has expired. + + res, err = subscriptionClient1.RecvMsg() + require.NoError(t, err) + err = json.Unmarshal(res, &resp) + require.NoError(t, err) + // 2nd one still running and should get the update + require.JSONEq(t, `{"queryTodo": [ + { + "owner": "jatin", + "text": "GraphQL is exciting!!" + }, + { + "owner" : "jatin", + "text" : "Dgraph is awesome!!" + }]}`, string(resp.Data)) + + // add extra delay for 2nd subscription to timeout + time.Sleep(10 * time.Second) + // Add another TODO for jatin for which 2nd subscription shouldn't get updates. + add = &common.GraphQLParams{ + Query: `mutation{ + addTodo(input: [ + {text : "Graph Database is the future!!", + owner : "jatin"} + ]) + { + todo { + text + owner + } + } + }`, + } + + res, err = subscriptionClient1.RecvMsg() + require.NoError(t, err) + require.Nil(t, res) // 2nd subscription should get the empty response as subscription has expired. +} + +func TestSubscriptionAuth_SameQueryDifferentClaimsAndExpiry_ShouldExpireIndependently(t *testing.T) { + dg, err := testutil.DgraphClient(groupOnegRPC) + require.NoError(t, err) + testutil.DropAll(t, dg) + + add := &common.GraphQLParams{ + Query: `mutation updateGQLSchema($sch: String!) { + updateGQLSchema(input: { set: { schema: $sch }}) { + gqlSchema { + schema + } + } + }`, + Variables: map[string]interface{}{"sch": schAuth}, + } + addResult := add.ExecuteAsPost(t, adminEndpoint) + require.Nil(t, addResult.Errors) + time.Sleep(time.Second * 2) + + metaInfo := &testutil.AuthMeta{ + PublicKey: "secret", + Namespace: "https://dgraph.io", + Algo: "HS256", + Header: "Authorization", + } + metaInfo.AuthVars = map[string]interface{}{ + "USER": "jatin", + "ROLE": "USER", + } + // for user jatin + add = &common.GraphQLParams{ + Query: `mutation{ + addTodo(input: [ + {text : "GraphQL is exciting!!", + owner : "jatin"} + ]) + { + todo{ + text + owner + } + } + }`, + } + + addResult = add.ExecuteAsPost(t, graphQLEndpoint) + require.Nil(t, addResult.Errors) + + jwtToken, err := metaInfo.GetSignedToken("secret", 10*time.Second) + require.NoError(t, err) + + // first subscription + payload := fmt.Sprintf(`{"Authorization": "%s"}`, jwtToken) + subscriptionClient, err := common.NewGraphQLSubscription(subscriptionEndpoint, &schema.Request{ + Query: `subscription{ + queryTodo{ + owner + text + } + }`, + }, payload) + require.Nil(t, err) + + res, err := subscriptionClient.RecvMsg() + require.NoError(t, err) + + var resp common.GraphQLResponse + err = json.Unmarshal(res, &resp) + require.NoError(t, err) + + require.Nil(t, resp.Errors) + require.JSONEq(t, `{"queryTodo":[{"owner":"jatin","text":"GraphQL is exciting!!"}]}`, + string(resp.Data)) + + // for user pawan + add = &common.GraphQLParams{ + Query: `mutation{ + addTodo(input: [ + {text : "GraphQL is exciting!!", + owner : "pawan"} + ]) + { + todo { + text + owner + } + } + }`, + } + + addResult = add.ExecuteAsPost(t, graphQLEndpoint) + require.Nil(t, addResult.Errors) + + // 2nd subscription + metaInfo.AuthVars["USER"] = "pawan" + jwtToken, err = metaInfo.GetSignedToken("secret", 20*time.Second) + require.NoError(t, err) + payload = fmt.Sprintf(`{"Authorization": "%s"}`, jwtToken) + subscriptionClient1, err := common.NewGraphQLSubscription(subscriptionEndpoint, &schema.Request{ + Query: `subscription{ + queryTodo{ + owner + text + } + }`, + }, payload) + require.Nil(t, err) + + res, err = subscriptionClient1.RecvMsg() + require.NoError(t, err) + + err = json.Unmarshal(res, &resp) + require.NoError(t, err) + + require.Nil(t, resp.Errors) + require.JSONEq(t, `{"queryTodo":[{"owner":"pawan","text":"GraphQL is exciting!!"}]}`, + string(resp.Data)) + + // Wait for JWT to expire for 1st subscription. + time.Sleep(10 * time.Second) + + // Add another TODO for jatin for which 1st subscription shouldn't get updates. + add = &common.GraphQLParams{ + Query: `mutation{ + addTodo(input: [ + {text : "Dgraph is awesome!!", + owner : "jatin"} + ]) + { + todo { + text + owner + } + } + }`, + } + addResult = add.ExecuteAsPost(t, graphQLEndpoint) + require.Nil(t, addResult.Errors) + require.NoError(t, err) + // 1st subscription should get the empty response as subscription has expired + res, err = subscriptionClient.RecvMsg() + require.NoError(t, err) + require.Nil(t, res) + + // Add another TODO for pawan which we should get in the latest update of 2nd subscription. + add = &common.GraphQLParams{ + Query: `mutation{ + addTodo(input: [ + {text : "Dgraph is awesome!!", + owner : "pawan"} + ]) + { + todo { + text + owner + } + } + }`, + } + addResult = add.ExecuteAsPost(t, graphQLEndpoint) + require.Nil(t, addResult.Errors) + + res, err = subscriptionClient1.RecvMsg() + require.NoError(t, err) + err = json.Unmarshal(res, &resp) + require.NoError(t, err) + // 2nd one still running and should get the update + require.JSONEq(t, `{"queryTodo": [ + { + "owner": "pawan", + "text": "GraphQL is exciting!!" + }, + { + "owner" : "pawan", + "text" : "Dgraph is awesome!!" + }]}`, string(resp.Data)) + + // add delay for 2nd subscription to timeout + // Wait for JWT to expire. + time.Sleep(10 * time.Second) + // Add another TODO for pawan for which 2nd subscription shouldn't get updates. + add = &common.GraphQLParams{ + Query: `mutation{ + addTodo(input: [ + {text : "Graph Database is the future!!", + owner : "pawan"} + ]) + { + todo { + text + owner + } + } + }`, + } + + // 2nd subscription should get the empty response as subscriptio has expired + res, err = subscriptionClient1.RecvMsg() + require.NoError(t, err) + require.Nil(t, res) +} diff --git a/graphql/resolve/auth_test.go b/graphql/resolve/auth_test.go index 358cf6f0b76..294606d742c 100644 --- a/graphql/resolve/auth_test.go +++ b/graphql/resolve/auth_test.go @@ -169,9 +169,9 @@ func TestStringCustomClaim(t *testing.T) { md := metadata.New(map[string]string{"authorizationJwt": token}) ctx := metadata.NewIncomingContext(context.Background(), md) - authVar, err := authorization.ExtractAuthVariables(ctx) + customClaims, err := authorization.ExtractCustomClaims(ctx) require.NoError(t, err) - + authVar := customClaims.AuthVariables result := map[string]interface{}{ "ROLE": "ADMIN", "USER": "50950b40-262f-4b26-88a7-cbbb780b2176", @@ -225,13 +225,13 @@ func TestAudienceClaim(t *testing.T) { md := metadata.New(map[string]string{"authorizationJwt": tcase.token}) ctx := metadata.NewIncomingContext(context.Background(), md) - authVar, err := authorization.ExtractAuthVariables(ctx) + customClaims, err := authorization.ExtractCustomClaims(ctx) require.Equal(t, tcase.err, err) - if err != nil { return } + authVar := customClaims.AuthVariables result := map[string]interface{}{ "ROLE": "ADMIN", "USER": "50950b40-262f-4b26-88a7-cbbb780b2176", diff --git a/graphql/resolve/mutation.go b/graphql/resolve/mutation.go index e17927c220b..ae888bdc53a 100644 --- a/graphql/resolve/mutation.go +++ b/graphql/resolve/mutation.go @@ -375,10 +375,11 @@ func authorizeNewNodes( queryExecutor DgraphExecutor, txn *dgoapi.TxnContext) error { - authVariables, err := authorization.ExtractAuthVariables(ctx) + customClaims, err := authorization.ExtractCustomClaims(ctx) if err != nil { return schema.GQLWrapf(err, "authorization failed") } + authVariables := customClaims.AuthVariables newRw := &authRewriter{ authVariables: authVariables, varGen: NewVariableGenerator(), diff --git a/graphql/resolve/mutation_rewriter.go b/graphql/resolve/mutation_rewriter.go index c9399810347..275a0b4db73 100644 --- a/graphql/resolve/mutation_rewriter.go +++ b/graphql/resolve/mutation_rewriter.go @@ -357,12 +357,13 @@ func (mrw *AddRewriter) FromMutationResult( errs = schema.AsGQLErrors(errors.Errorf("no new node was created")) } - authVariables, err := authorization.ExtractAuthVariables(ctx) + customClaims, err := authorization.ExtractCustomClaims(ctx) if err != nil { return nil, err } + authRw := &authRewriter{ - authVariables: authVariables, + authVariables: customClaims.AuthVariables, varGen: NewVariableGenerator(), selector: queryAuthSelector, parentVarName: mutation.MutatedType().Name() + "Root", @@ -410,12 +411,13 @@ func (urw *UpdateRewriter) Rewrite( varGen := NewVariableGenerator() - authVariables, err := authorization.ExtractAuthVariables(ctx) + customClaims, err := authorization.ExtractCustomClaims(ctx) if err != nil { return nil, err } + authRw := &authRewriter{ - authVariables: authVariables, + authVariables: customClaims.AuthVariables, varGen: varGen, selector: updateAuthSelector, parentVarName: m.MutatedType().Name() + "Root", @@ -556,12 +558,13 @@ func (urw *UpdateRewriter) FromMutationResult( } } - authVariables, err := authorization.ExtractAuthVariables(ctx) + customClaims, err := authorization.ExtractCustomClaims(ctx) if err != nil { return nil, err } + authRw := &authRewriter{ - authVariables: authVariables, + authVariables: customClaims.AuthVariables, varGen: NewVariableGenerator(), selector: queryAuthSelector, parentVarName: mutation.MutatedType().Name() + "Root", @@ -679,13 +682,13 @@ func (drw *deleteRewriter) Rewrite( varGen := NewVariableGenerator() - authVariables, err := authorization.ExtractAuthVariables(ctx) + customClaims, err := authorization.ExtractCustomClaims(ctx) if err != nil { return nil, err } authRw := &authRewriter{ - authVariables: authVariables, + authVariables: customClaims.AuthVariables, varGen: varGen, selector: deleteAuthSelector, parentVarName: m.MutatedType().Name() + "Root", @@ -743,7 +746,7 @@ func (drw *deleteRewriter) Rewrite( // is later added to delete mutation result. if queryField := m.QueryField(); queryField.SelectionSet() != nil { queryAuthRw := &authRewriter{ - authVariables: authVariables, + authVariables: customClaims.AuthVariables, varGen: varGen, selector: queryAuthSelector, filterByUid: true, @@ -1495,14 +1498,15 @@ func addDelete( // then we need update permission on Author1 // grab the auth for Author1 - authVariables, err := authorization.ExtractAuthVariables(ctx) + customClaims, err := authorization.ExtractCustomClaims(ctx) if err != nil { frag.check = checkQueryResult("auth.failed", nil, schema.GQLWrapf(err, "authorization failed")) return } + newRw := &authRewriter{ - authVariables: authVariables, + authVariables: customClaims.AuthVariables, varGen: varGen, varName: targetVar, selector: updateAuthSelector, diff --git a/graphql/resolve/query_rewriter.go b/graphql/resolve/query_rewriter.go index 850266c76e6..3e17592f998 100644 --- a/graphql/resolve/query_rewriter.go +++ b/graphql/resolve/query_rewriter.go @@ -76,9 +76,14 @@ func (qr *queryRewriter) Rewrite( return &gql.GraphQuery{Attr: gqlQuery.ResponseName() + "()"}, nil } - authVariables, err := authorization.ExtractAuthVariables(ctx) - if err != nil { - return nil, err + authVariables, _ := ctx.Value(authorization.AuthVariables).(map[string]interface{}) + + if authVariables == nil { + customClaims, err := authorization.ExtractCustomClaims(ctx) + if err != nil { + return nil, err + } + authVariables = customClaims.AuthVariables } authRw := &authRewriter{ diff --git a/graphql/subscription/poller.go b/graphql/subscription/poller.go index 52167dcea60..3e806a39dc9 100644 --- a/graphql/subscription/poller.go +++ b/graphql/subscription/poller.go @@ -24,6 +24,7 @@ import ( "sync/atomic" "time" + "github.com/dgraph-io/dgraph/graphql/authorization" "github.com/dgraph-io/dgraph/graphql/resolve" "github.com/dgraph-io/dgraph/graphql/schema" "github.com/dgraph-io/dgraph/x" @@ -35,7 +36,7 @@ import ( type Poller struct { sync.Mutex resolver *resolve.RequestResolver - pollRegistry map[uint64]map[uint64]chan interface{} + pollRegistry map[uint64]map[uint64]subscriber subscriptionID uint64 globalEpoch *uint64 } @@ -44,7 +45,7 @@ type Poller struct { func NewPoller(globalEpoch *uint64, resolver *resolve.RequestResolver) *Poller { return &Poller{ resolver: resolver, - pollRegistry: make(map[uint64]map[uint64]chan interface{}), + pollRegistry: make(map[uint64]map[uint64]subscriber), globalEpoch: globalEpoch, } } @@ -56,11 +57,17 @@ type SubscriberResponse struct { UpdateCh chan interface{} } +type subscriber struct { + expiry time.Time + updateCh chan interface{} +} + // AddSubscriber tries to add subscription into the existing polling goroutine if it exists. // If it doesn't exist, then it creates a new polling goroutine for the given request. -func (p *Poller) AddSubscriber(req *schema.Request) (*SubscriberResponse, error) { - localEpoch := atomic.LoadUint64(p.globalEpoch) +func (p *Poller) AddSubscriber( + req *schema.Request, customClaims *authorization.CustomClaims) (*SubscriberResponse, error) { + localEpoch := atomic.LoadUint64(p.globalEpoch) err := p.resolver.ValidateSubscription(req) if err != nil { return nil, err @@ -68,12 +75,23 @@ func (p *Poller) AddSubscriber(req *schema.Request) (*SubscriberResponse, error) buf, err := json.Marshal(req) x.Check(err) + var bucketID uint64 + if customClaims.AuthVariables != nil { - bucketID := farm.Fingerprint64(buf) + // TODO - Add custom marshal function that marshal's the json in sorted order. + authvariables, err := json.Marshal(customClaims.AuthVariables) + if err != nil { + return nil, err + } + bucketID = farm.Fingerprint64(append(buf, authvariables...)) + } else { + bucketID = farm.Fingerprint64(buf) + } p.Lock() defer p.Unlock() - res := p.resolver.Resolve(context.TODO(), req) + ctx := context.WithValue(context.Background(), authorization.AuthVariables, customClaims.AuthVariables) + res := p.resolver.Resolve(ctx, req) if len(res.Errors) != 0 { return nil, res.Errors } @@ -88,10 +106,12 @@ func (p *Poller) AddSubscriber(req *schema.Request) (*SubscriberResponse, error) p.subscriptionID++ subscriptions, ok := p.pollRegistry[bucketID] if !ok { - subscriptions = make(map[uint64]chan interface{}) + subscriptions = make(map[uint64]subscriber) } glog.Infof("Subscription polling is started for the ID %d", subscriptionID) - subscriptions[subscriptionID] = updateCh + + subscriptions[subscriptionID] = subscriber{ + expiry: customClaims.StandardClaims.ExpiresAt.Time, updateCh: updateCh} p.pollRegistry[bucketID] = subscriptions if len(subscriptions) != 1 { @@ -101,32 +121,34 @@ func (p *Poller) AddSubscriber(req *schema.Request) (*SubscriberResponse, error) return &SubscriberResponse{ BucketID: bucketID, SubscriptionID: subscriptionID, - UpdateCh: updateCh, + UpdateCh: subscriptions[subscriptionID].updateCh, }, nil } // There is no goroutine running to check updates for this query. So, run one to publish // the updates. pollR := &pollRequest{ - bucketID: bucketID, - prevHash: prevHash, - graphqlReq: req, - localEpoch: localEpoch, + bucketID: bucketID, + prevHash: prevHash, + graphqlReq: req, + authVariables: customClaims.AuthVariables, + localEpoch: localEpoch, } go p.poll(pollR) return &SubscriberResponse{ BucketID: bucketID, SubscriptionID: subscriptionID, - UpdateCh: updateCh, + UpdateCh: subscriptions[subscriptionID].updateCh, }, nil } type pollRequest struct { - prevHash uint64 - graphqlReq *schema.Request - bucketID uint64 - localEpoch uint64 + prevHash uint64 + graphqlReq *schema.Request + bucketID uint64 + localEpoch uint64 + authVariables map[string]interface{} } func (p *Poller) poll(req *pollRequest) { @@ -142,10 +164,10 @@ func (p *Poller) poll(req *pollRequest) { // We'll terminate all the subscription for this bucket. So, that all client can // reconnect and listen for new schema. p.terminateSubscriptions(req.bucketID) - return } - res := resolver.Resolve(context.TODO(), req.graphqlReq) + ctx := context.WithValue(context.Background(), authorization.AuthVariables, req.authVariables) + res := resolver.Resolve(ctx, req.graphqlReq) currentHash := farm.Fingerprint64(res.Data.Bytes()) @@ -162,6 +184,12 @@ func (p *Poller) poll(req *pollRequest) { p.Unlock() return } + for subscriberID, subscriber := range subscribers { + if !subscriber.expiry.IsZero() && time.Now().After(subscriber.expiry) { + p.terminateSubscription(req.bucketID, subscriberID) + } + + } p.Unlock() continue } @@ -175,8 +203,15 @@ func (p *Poller) poll(req *pollRequest) { p.Unlock() return } - for _, updateCh := range subscribers { - updateCh <- res.Output() + + for subscriberID, subscriber := range subscribers { + if !subscriber.expiry.IsZero() && time.Now().After(subscriber.expiry) { + p.terminateSubscription(req.bucketID, subscriberID) + } + + } + for _, subscriber := range subscribers { + subscriber.updateCh <- res.Output() } p.Unlock() } @@ -197,25 +232,29 @@ func (p *Poller) terminateSubscriptions(bucketID uint64) { if !ok { return } - for _, updateCh := range subscriptions { + for _, subscriber := range subscriptions { // Closing the channel will close the graphQL websocket connection as well. - close(updateCh) + close(subscriber.updateCh) } delete(p.pollRegistry, bucketID) } -// TerminateSubscription will terminate the polling subscription. func (p *Poller) TerminateSubscription(bucketID, subscriptionID uint64) { p.Lock() defer p.Unlock() + p.terminateSubscription(bucketID, subscriptionID) +} + +func (p *Poller) terminateSubscription(bucketID, subscriptionID uint64) { subscriptions, ok := p.pollRegistry[bucketID] if !ok { return } - updateCh, ok := subscriptions[subscriptionID] + subscriber, ok := subscriptions[subscriptionID] if ok { glog.Infof("Terminating subscription for the subscription ID %d", subscriptionID) - close(updateCh) + close(subscriber.updateCh) + } delete(subscriptions, subscriptionID) p.pollRegistry[bucketID] = subscriptions diff --git a/graphql/web/http.go b/graphql/web/http.go index 9c7e112f47b..b53f76b6d6c 100644 --- a/graphql/web/http.go +++ b/graphql/web/http.go @@ -21,6 +21,10 @@ import ( "context" "encoding/json" "strconv" + "time" + + "github.com/dgrijalva/jwt-go/v4" + "google.golang.org/grpc/metadata" "io" "io/ioutil" @@ -34,15 +38,19 @@ import ( "github.com/dgraph-io/dgraph/graphql/schema" "github.com/dgraph-io/dgraph/graphql/subscription" "github.com/dgraph-io/dgraph/x" + "github.com/dgraph-io/graphql-transport-ws/graphqlws" "github.com/golang/glog" - "github.com/graph-gophers/graphql-transport-ws/graphqlws" "github.com/pkg/errors" "go.opencensus.io/trace" ) -const touchedUidsHeader = "Graphql-TouchedUids" +type Headerkey string + +const ( + touchedUidsHeader = "Graphql-TouchedUids" +) -// An IServeGraphQL can serve a GraphQL endpoint (currently only on http) +// An IServeGraphQL can serve a GraphQL endpoint (currently only ons http) type IServeGraphQL interface { // After ServeGQL is called, this IServeGraphQL serves the new resolvers. @@ -116,12 +124,45 @@ func (gs *graphqlSubscription) Subscribe( operationName string, variableValues map[string]interface{}) (payloads <-chan interface{}, err error) { + + // library (graphql-transport-ws) passes the headers which are part of the INIT payload to us in the context. + // And we are extracting the Auth JWT from those and passing them along. + + header, _ := ctx.Value("Header").(json.RawMessage) + customClaims := &authorization.CustomClaims{ + StandardClaims: jwt.StandardClaims{ + ExpiresAt: jwt.At(time.Time{}), + }, + } + if len(header) > 0 { + payload := make(map[string]interface{}) + if err := json.Unmarshal(header, &payload); err != nil { + return nil, err + } + + name := authorization.GetHeader() + val, ok := payload[name].(string) + if ok { + + md := metadata.New(map[string]string{ + "authorizationJwt": val, + }) + ctx = metadata.NewIncomingContext(ctx, md) + + customClaims, err = authorization.ExtractCustomClaims(ctx) + if err != nil { + return nil, err + } + } + } + req := &schema.Request{ OperationName: operationName, Query: document, Variables: variableValues, } - res, err := gs.graphqlHandler.poller.AddSubscriber(req) + + res, err := gs.graphqlHandler.poller.AddSubscriber(req, customClaims) if err != nil { return nil, err } diff --git a/testutil/graphql.go b/testutil/graphql.go index 4024f2faf83..1c17260046c 100644 --- a/testutil/graphql.go +++ b/testutil/graphql.go @@ -152,12 +152,13 @@ type AuthMeta struct { AuthVars map[string]interface{} } -func (a *AuthMeta) GetSignedToken(privateKeyFile string) (string, error) { +func (a *AuthMeta) GetSignedToken(privateKeyFile string, + expireAfter time.Duration) (string, error) { claims := clientCustomClaims{ a.Namespace, a.AuthVars, jwt.StandardClaims{ - ExpiresAt: jwt.At(time.Now().Add(time.Minute * 5)), + ExpiresAt: jwt.At(time.Now().Add(expireAfter)), Issuer: "test", }, } @@ -188,7 +189,7 @@ func (a *AuthMeta) GetSignedToken(privateKeyFile string) (string, error) { } func (a *AuthMeta) AddClaimsToContext(ctx context.Context) (context.Context, error) { - token, err := a.GetSignedToken("../e2e/auth/sample_private_key.pem") + token, err := a.GetSignedToken("../e2e/auth/sample_private_key.pem", 5*time.Minute) if err != nil { return ctx, err }