diff --git a/graphql/dgraph/graphquery.go b/graphql/dgraph/graphquery.go index 08eb3f22fba..9d541578b3d 100644 --- a/graphql/dgraph/graphquery.go +++ b/graphql/dgraph/graphquery.go @@ -69,6 +69,10 @@ func writeQuery(b *strings.Builder, query *gql.GraphQuery, prefix string, root b x.Check2(b.WriteRune(')')) } + if query.Cascade { + x.Check2(b.WriteString(" @cascade")) + } + switch { case len(query.Children) > 0: prefixAdd := "" diff --git a/graphql/e2e/auth/schema.graphql b/graphql/e2e/auth/schema.graphql index 53b6ae5b861..a5e7456cb43 100644 --- a/graphql/e2e/auth/schema.graphql +++ b/graphql/e2e/auth/schema.graphql @@ -1,101 +1,144 @@ -type User @auth( - query: { or : [ - { rule: "filter: {isPublic: true}" }, - { rule: "filter: {username: {eq: $USER }}"}, - ]}, - add: { rule: "$ROLE: { eq: ADD-BOT }" }, - update: { or: [ - { rule: "$ROLE: { eq: ADMIN }" }, - { rule: "(filter: { username: { eq: $USER } })"} - ]} - delete: { rule: "false" }) { +type User { +# @auth( +# add: { rule: "$ROLE: { eq: ADD-BOT }" }, +# update: { or: [ +# { rule: "$ROLE: { eq: ADMIN }" }, +# { rule: "(filter: { username: { eq: $USER } })"} +# ]} +# delete: { rule: "false" }) { username: String! @id - age: Int @auth(query: {rule: "filter: {username: {eq: $USER }}"}) + age: Int # @auth(query: {rule: "filter: {username: {eq: $USER }}"}) isPublic: Boolean - disabled: Boolean @auth(update: { rule: "$ROLE: { eq: ADMIN }" }) + disabled: Boolean # @auth(update: { rule: "$ROLE: { eq: ADMIN }" }) + tickets: [Ticket] @hasInverse(field: assignedTo) + secrets: [UserSecret] +} + +type UserSecret @auth( + query: { rule: """ + query($USER: String!) { + queryUserSecret(filter: { ownedBy: { eq: $USER } }) { + __typename + } + } + """} +){ + id: ID! + ownedBy: String @search(by: [hash]) } type Region { id: ID! name: String + global: Boolean @search users: [User] } type Movie @auth( + # You can query a movie if + # - it's not hidden + # AND + # - you are in a region it's available OR it's globally available query: { and: [ - {rule: """filter: {disabled: false}"""}, - {rule: """regionsAvailable { users ( filter: {username: {eq: $USER}})}"""}, - {not: {rule: """regionsAvailable { users ( filter: {disabled: {eq: true}})}"""}}, + { not: { rule: """ + query { + queryMovie(filter: { hidden: true }) { __typename } + } + """}}, + { or: [ + { rule: """ + query($USER: String!) { + queryMovie { + regionsAvailable { + users(filter: {username: {eq: $USER}}) { + __typename + } + } + } + }""" + }, + { rule: """ + query { + queryMovie { + regionsAvailable(filter: { global: true }) { + __typename + } + } + }""" + } + ]} ]} - - add: { rule: "$ROLE: { eq: ADMIN }" } - update: { rule: "$ROLE: { eq: ADMIN }" } - delete: { rule: "$ROLE: { eq: ADMIN }" } +# add: { rule: "$ROLE: { eq: ADMIN }" } +# update: { rule: "$ROLE: { eq: ADMIN }" } +# delete: { rule: "$ROLE: { eq: ADMIN }" } ) { content: String - disabled: Boolean + hidden: Boolean @search regionsAvailable: [Region] } -type Issue @auth( - query: { and: [ - {rule: "$ROLE: { eq: ADMIN }"}, - {rule: "owner(filter: { username: { eq: $USER } })"}, - ]} - add: { and: [ - {rule: "$ROLE: { eq: ADMIN }"}, - {rule: "owner(filter: { username: { eq: $USER } })"}, - ]} - update: { and: [ - {rule: "$ROLE: { eq: ADMIN }"}, - {rule: "owner(filter: { username: { eq: $USER } })"}, - ]} - delete: { and: [ - {rule: "$ROLE: { eq: ADMIN }"}, - {rule: "owner(filter: { username: { eq: $USER } })"}, - ]} -) { +type Issue { +# @auth( +# query: { and: [ +# {rule: "$ROLE: { eq: ADMIN }"}, +# {rule: "owner(filter: { username: { eq: $USER } })"}, +# ]} +# add: { and: [ +# {rule: "$ROLE: { eq: ADMIN }"}, +# {rule: "owner(filter: { username: { eq: $USER } })"}, +# ]} +# update: { and: [ +# {rule: "$ROLE: { eq: ADMIN }"}, +# {rule: "owner(filter: { username: { eq: $USER } })"}, +# ]} +# delete: { and: [ +# {rule: "$ROLE: { eq: ADMIN }"}, +# {rule: "owner(filter: { username: { eq: $USER } })"}, +# ]} +# ) { id: ID! msg: String owner: User! } -type Log @auth( - query: { rule: "$ROLE: { eq: ADMIN }" } - add: { rule: "$ROLE: { eq: ADMIN }" } - update: { rule: "$ROLE: { eq: ADMIN }" } - delete: { rule: "$ROLE: { eq: ADMIN }" } -) { +type Log { +# @auth( +# query: { rule: "$ROLE: { eq: ADMIN }" } +# add: { rule: "$ROLE: { eq: ADMIN }" } +# update: { rule: "$ROLE: { eq: ADMIN }" } +# delete: { rule: "$ROLE: { eq: ADMIN }" } +# ) { id: ID! logs: String } -type Project @auth( - query: { or: [ - { rule: """roles(filter: { permissions: { eq: VIEW } }) { - assignedTo(filter: { username: { eq: $USER } }) - }""" }, - { rule: "$ROLE: { eq: ADMIN }" } - ]} +type Project { +# @auth( +# query: { or: [ +# { rule: """roles(filter: { permissions: { eq: VIEW } }) { +# assignedTo(filter: { username: { eq: $USER } }) +# }""" }, +# { rule: "$ROLE: { eq: ADMIN }" } +# ]} - # Only admins can create projects - add: { rule: "$ROLE: { eq: ADMIN }" } +# # Only admins can create projects +# add: { rule: "$ROLE: { eq: ADMIN }" } - update: { rule: """roles(filter: { permissions: { eq: CREATE } }) { - assignedTo(filter: { username: { eq: $USER } }) - }""" } +# update: { rule: """roles(filter: { permissions: { eq: CREATE } }) { +# assignedTo(filter: { username: { eq: $USER } }) +# }""" } - delete: { rule: "false" } -) { +# delete: { rule: "false" } +# ) { projID: ID! name: String! roles: [Role] - columns: [Column] @hasInverse(field: inProject) @auth(add: {rule: "DENY"}) + columns: [Column] @hasInverse(field: inProject) # @auth(add: {rule: "DENY"}) } type Role { id: ID! - permissions: [Permission] + permission: Permission @search assignedTo: [User] } @@ -105,62 +148,70 @@ enum Permission { ADMIN } -type Column @auth( - query: { rule: """inProject { - role(filter: { permission: { eq: VIEW } } ) { - users(filter: { username: { eq: $USER } }) - } - }"""}, - add: { rule: """inProject { - role(filter: { permission: { eq: ADMIN } } ) { - users(filter: { username: { eq: $USER } }) - } - }"""}, - update: { rule: """inProject { - role(filter: { permission: { eq: EDIT } } ) { - users(filter: { username: { eq: $USER } }) - } - }"""}, - delete: { rule: "false" } -) { +type Column { +# @auth( +# query: { rule: """inProject { +# role(filter: { permission: { eq: VIEW } } ) { +# users(filter: { username: { eq: $USER } }) +# } +# }"""}, +# add: { rule: """inProject { +# role(filter: { permission: { eq: ADMIN } } ) { +# users(filter: { username: { eq: $USER } }) +# } +# }"""}, +# update: { rule: """inProject { +# role(filter: { permission: { eq: EDIT } } ) { +# users(filter: { username: { eq: $USER } }) +# } +# }"""}, +# delete: { rule: "false" } +# ) { colID: ID! - inProject: Project! @auth(update: { rule: "DENY" }) + inProject: Project! # @auth(update: { rule: "DENY" }) name: String! tickets: [Ticket] @hasInverse(field: onColumn) } type Ticket @auth( - query: { rule: """onColumn { - inProject { - role(filter: { permission: { eq: VIEW } } ) { - users(filter: { username: { eq: $USER } }) + query: { rule: """ + query($USER: String!) { + queryTicket { + onColumn{ + inProject { + roles(filter: { permission: { eq: VIEW } } ) { + assignedTo(filter: { username: { eq: $USER } }) { + __typename } } - }"""}, - add: { rule: """onColumn { - inProject { - role(filter: { permission: { eq: WRITE } } ) { - users(filter: { username: { eq: $USER } }) - } - } - }"""}, - update: { rule: """onColumn { - inProject { - role(filter: { permission: { eq: WRITE } } ) { - users(filter: { username: { eq: $USER } }) - } - } - }"""}, - delete: { rule: """onColumn { - inProject { - role(filter: { permission: { eq: WRITE } } ) { - users(filter: { username: { eq: $USER } }) - } - } - }"""} + } + } + } + }"""} + # add: { rule: """onColumn { + # inProject { + # role(filter: { permission: { eq: WRITE } } ) { + # users(filter: { username: { eq: $USER } }) + # } + # } + # }"""}, + # update: { rule: """onColumn { + # inProject { + # role(filter: { permission: { eq: WRITE } } ) { + # users(filter: { username: { eq: $USER } }) + # } + # } + # }"""}, + # delete: { rule: """onColumn { + # inProject { + # role(filter: { permission: { eq: WRITE } } ) { + # users(filter: { username: { eq: $USER } }) + # } + # } + # }"""} ){ id: ID! onColumn: Column! - title: String! + title: String! @search(by: [term]) assignedTo: [User!] } diff --git a/graphql/resolve/auth-schema.graphql b/graphql/resolve/auth-schema.graphql deleted file mode 100644 index 53b6ae5b861..00000000000 --- a/graphql/resolve/auth-schema.graphql +++ /dev/null @@ -1,166 +0,0 @@ -type User @auth( - query: { or : [ - { rule: "filter: {isPublic: true}" }, - { rule: "filter: {username: {eq: $USER }}"}, - ]}, - add: { rule: "$ROLE: { eq: ADD-BOT }" }, - update: { or: [ - { rule: "$ROLE: { eq: ADMIN }" }, - { rule: "(filter: { username: { eq: $USER } })"} - ]} - delete: { rule: "false" }) { - username: String! @id - age: Int @auth(query: {rule: "filter: {username: {eq: $USER }}"}) - isPublic: Boolean - disabled: Boolean @auth(update: { rule: "$ROLE: { eq: ADMIN }" }) -} - -type Region { - id: ID! - name: String - users: [User] -} - -type Movie @auth( - query: { and: [ - {rule: """filter: {disabled: false}"""}, - {rule: """regionsAvailable { users ( filter: {username: {eq: $USER}})}"""}, - {not: {rule: """regionsAvailable { users ( filter: {disabled: {eq: true}})}"""}}, - ]} - - add: { rule: "$ROLE: { eq: ADMIN }" } - update: { rule: "$ROLE: { eq: ADMIN }" } - delete: { rule: "$ROLE: { eq: ADMIN }" } -) { - content: String - disabled: Boolean - regionsAvailable: [Region] -} - -type Issue @auth( - query: { and: [ - {rule: "$ROLE: { eq: ADMIN }"}, - {rule: "owner(filter: { username: { eq: $USER } })"}, - ]} - add: { and: [ - {rule: "$ROLE: { eq: ADMIN }"}, - {rule: "owner(filter: { username: { eq: $USER } })"}, - ]} - update: { and: [ - {rule: "$ROLE: { eq: ADMIN }"}, - {rule: "owner(filter: { username: { eq: $USER } })"}, - ]} - delete: { and: [ - {rule: "$ROLE: { eq: ADMIN }"}, - {rule: "owner(filter: { username: { eq: $USER } })"}, - ]} -) { - id: ID! - msg: String - owner: User! -} - -type Log @auth( - query: { rule: "$ROLE: { eq: ADMIN }" } - add: { rule: "$ROLE: { eq: ADMIN }" } - update: { rule: "$ROLE: { eq: ADMIN }" } - delete: { rule: "$ROLE: { eq: ADMIN }" } -) { - id: ID! - logs: String -} - -type Project @auth( - query: { or: [ - { rule: """roles(filter: { permissions: { eq: VIEW } }) { - assignedTo(filter: { username: { eq: $USER } }) - }""" }, - { rule: "$ROLE: { eq: ADMIN }" } - ]} - - # Only admins can create projects - add: { rule: "$ROLE: { eq: ADMIN }" } - - update: { rule: """roles(filter: { permissions: { eq: CREATE } }) { - assignedTo(filter: { username: { eq: $USER } }) - }""" } - - delete: { rule: "false" } -) { - projID: ID! - name: String! - roles: [Role] - columns: [Column] @hasInverse(field: inProject) @auth(add: {rule: "DENY"}) -} - -type Role { - id: ID! - permissions: [Permission] - assignedTo: [User] -} - -enum Permission { - VIEW - EDIT - ADMIN -} - -type Column @auth( - query: { rule: """inProject { - role(filter: { permission: { eq: VIEW } } ) { - users(filter: { username: { eq: $USER } }) - } - }"""}, - add: { rule: """inProject { - role(filter: { permission: { eq: ADMIN } } ) { - users(filter: { username: { eq: $USER } }) - } - }"""}, - update: { rule: """inProject { - role(filter: { permission: { eq: EDIT } } ) { - users(filter: { username: { eq: $USER } }) - } - }"""}, - delete: { rule: "false" } -) { - colID: ID! - inProject: Project! @auth(update: { rule: "DENY" }) - name: String! - tickets: [Ticket] @hasInverse(field: onColumn) -} - -type Ticket @auth( - query: { rule: """onColumn { - inProject { - role(filter: { permission: { eq: VIEW } } ) { - users(filter: { username: { eq: $USER } }) - } - } - }"""}, - add: { rule: """onColumn { - inProject { - role(filter: { permission: { eq: WRITE } } ) { - users(filter: { username: { eq: $USER } }) - } - } - }"""}, - update: { rule: """onColumn { - inProject { - role(filter: { permission: { eq: WRITE } } ) { - users(filter: { username: { eq: $USER } }) - } - } - }"""}, - delete: { rule: """onColumn { - inProject { - role(filter: { permission: { eq: WRITE } } ) { - users(filter: { username: { eq: $USER } }) - } - } - }"""} -){ - id: ID! - onColumn: Column! - title: String! - assignedTo: [User!] -} diff --git a/graphql/resolve/auth.go b/graphql/resolve/auth.go new file mode 100644 index 00000000000..a8270771976 --- /dev/null +++ b/graphql/resolve/auth.go @@ -0,0 +1,100 @@ +/* + * Copyright 2020 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package resolve + +import ( + "context" + "fmt" + "github.com/dgrijalva/jwt-go" + "github.com/pkg/errors" + "google.golang.org/grpc/metadata" + "net/http" + "time" +) + +//TODO: Get the secret key dynamically. +const ( + AuthJwtCtxKey = "authorizationJwt" + AuthHmacSecret = "Secretkey" +) + +// AttachAuthorizationJwt adds any incoming JWT authorization data into the grpc context metadata. +func AttachAuthorizationJwt(ctx context.Context, r *http.Request) context.Context { + authorizationJwt := r.Header.Get("X-Dgraph-AuthorizationToken") + if authorizationJwt == "" { + return ctx + } + + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + md = metadata.New(nil) + } + + md.Append(AuthJwtCtxKey, authorizationJwt) + ctx = metadata.NewIncomingContext(ctx, md) + return ctx +} + +type CustomClaims struct { + AuthVariables map[string]interface{} `json:"https://dgraph.io/jwt/claims"` + jwt.StandardClaims +} + +func ExtractAuthVariables(ctx context.Context) (map[string]interface{}, error) { + // Extract the jwt and unmarshal the jwt to get the auth variables. + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, nil + } + + jwtToken := md.Get(AuthJwtCtxKey) + if len(jwtToken) == 0 { + return nil, nil + } else if len(jwtToken) > 1 { + return nil, fmt.Errorf("invalid jwt auth token") + } + + return validateToken(jwtToken[0]) +} + +func validateToken(jwtStr string) (map[string]interface{}, error) { + token, err := + jwt.ParseWithClaims(jwtStr, &CustomClaims{}, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, errors.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(AuthHmacSecret), nil + }) + + if err != nil { + return nil, errors.Errorf("unable to parse jwt token:%v", err) + } + + claims, ok := token.Claims.(*CustomClaims) + if !ok || !token.Valid { + return nil, errors.Errorf("claims in jwt token is not map claims") + } + + // by default, the MapClaims.Valid will return true if the exp field is not set + // here we enforce the checking to make sure that the refresh token has not expired + now := time.Now().Unix() + if !claims.VerifyExpiresAt(now, true) { + return nil, errors.Errorf("Token is expired") // the same error msg that's used inside jwt-go + } + + return claims.AuthVariables, nil +} diff --git a/graphql/resolve/auth_query_rewritter_test.go b/graphql/resolve/auth_query_rewritter_test.go deleted file mode 100644 index 7481cba6aec..00000000000 --- a/graphql/resolve/auth_query_rewritter_test.go +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Copyright 2019 Dgraph Labs, Inc. and Contributors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package resolve - -import ( - "context" - "io/ioutil" - "testing" - - "github.com/dgraph-io/dgraph/graphql/dgraph" - "github.com/dgraph-io/dgraph/graphql/schema" - "github.com/dgraph-io/dgraph/graphql/test" - "github.com/dgrijalva/jwt-go" - "github.com/stretchr/testify/require" - _ "github.com/vektah/gqlparser/v2/validator/rules" // make gql validator init() all rules - "google.golang.org/grpc/metadata" - "gopkg.in/yaml.v2" -) - -// Tests showing that the query rewriter produces the expected Dgraph queries - -type AuthQueryRewritingCase struct { - Name string - GQLQuery string - Variables map[string]interface{} - DGQuery string - User string - Role string -} - -func TestAuthQueryRewriting(t *testing.T) { - b, err := ioutil.ReadFile("auth_tests.yaml") - require.NoError(t, err, "Unable to read test file") - - var tests []AuthQueryRewritingCase - err = yaml.Unmarshal(b, &tests) - require.NoError(t, err, "Unable to unmarshal tests to yaml.") - - testRewriter := NewQueryRewriter() - - type MyCustomClaims struct { - Foo map[string]interface{} `json:"https://dgraph.io/jwt/claims"` - jwt.StandardClaims - } - - // Create the Claims - claims := MyCustomClaims{ - map[string]interface{}{}, - jwt.StandardClaims{ - ExpiresAt: 15000, - Issuer: "test", - }, - } - claims.Foo["User"] = "user1" - - gqlSchema := test.LoadSchemaFromFile(t, "auth-schema.graphql") - for _, tcase := range tests { - t.Run(tcase.Name, func(t *testing.T) { - - op, err := gqlSchema.Operation( - &schema.Request{ - Query: tcase.GQLQuery, - Variables: tcase.Variables, - }) - require.NoError(t, err) - gqlQuery := test.GetQuery(t, op) - - ctx := context.Background() - claims.Foo["Role"] = tcase.Role - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - ss, err := token.SignedString([]byte("Secret")) - require.NoError(t, err) - - md := metadata.New(nil) - md.Append("authorizationJwt", ss) - ctx = metadata.NewIncomingContext(ctx, md) - - dgQuery, err := testRewriter.Rewrite(ctx, gqlQuery) - require.Nil(t, err) - require.Equal(t, tcase.DGQuery, dgraph.AsString(dgQuery)) - }) - } -} diff --git a/graphql/resolve/auth_query_test.yaml b/graphql/resolve/auth_query_test.yaml new file mode 100644 index 00000000000..eca8a8d9c9d --- /dev/null +++ b/graphql/resolve/auth_query_test.yaml @@ -0,0 +1,189 @@ +- name: "Auth with top level filter : query, no filter" + gqlquery: | + query { + queryUserSecret { + id + ownedBy + } + } + dgquery: |- + query { + queryUserSecret(func: uid(UserSecret1)) @filter(uid(UserSecret2)) { + id : uid + ownedBy : UserSecret.ownedBy + dgraph.uid : uid + } + UserSecret1 as var(func: type(UserSecret)) + UserSecret2 as var(func: uid(UserSecret1)) @filter(eq(UserSecret.ownedBy, "user1")) @cascade + } + +- name: "Auth with top level filter : get" + gqlquery: | + query { + getUserSecret(id: "0x123") { + id + ownedBy + } + } + dgquery: |- + query { + getUserSecret(func: uid(UserSecret1)) @filter(uid(UserSecret2)) { + id : uid + ownedBy : UserSecret.ownedBy + dgraph.uid : uid + } + UserSecret1 as var(func: uid(0x123)) @filter(type(UserSecret)) + UserSecret2 as var(func: uid(UserSecret1)) @filter(eq(UserSecret.ownedBy, "user1")) @cascade + } + +- name: "Auth with top level filter : query and filter" + gqlquery: | + query { + queryUserSecret(filter: { ownedBy: { eq: "user2" }}) { + id + ownedBy + } + } + dgquery: |- + query { + queryUserSecret(func: uid(UserSecret1)) @filter(uid(UserSecret2)) { + id : uid + ownedBy : UserSecret.ownedBy + dgraph.uid : uid + } + UserSecret1 as var(func: type(UserSecret)) @filter(eq(UserSecret.ownedBy, "user2")) + UserSecret2 as var(func: uid(UserSecret1)) @filter(eq(UserSecret.ownedBy, "user1")) @cascade + } + +- name: "Auth with deep filter : query top-level" + gqlquery: | + query { + queryTicket { + id + title + } + } + dgquery: |- + query { + queryTicket(func: uid(Ticket1)) @filter(uid(Ticket2)) { + id : uid + title : Ticket.title + dgraph.uid : uid + } + Ticket1 as var(func: type(Ticket)) + Ticket2 as var(func: uid(Ticket1)) @cascade { + onColumn : Ticket.onColumn { + inProject : Column.inProject { + roles : Project.roles @filter(eq(Role.permission, "VIEW")) { + assignedTo : Role.assignedTo @filter(eq(User.username, "user1")) + dgraph.uid : uid + } + dgraph.uid : uid + } + dgraph.uid : uid + } + dgraph.uid : uid + } + } + +- name: "Auth with deep filter : query deep requires auth" + gqlquery: | + query { + queryUser { + username + tickets { + id + title + } + } + } + dgquery: |- + query { + queryUser(func: type(User)) { + username : User.username + tickets : User.tickets @filter(uid(Ticket1)) { + id : uid + title : Ticket.title + dgraph.uid : uid + } + dgraph.uid : uid + } + Ticket1 as var(func: type(Ticket)) @cascade { + onColumn : Ticket.onColumn { + inProject : Column.inProject { + roles : Project.roles @filter(eq(Role.permission, "VIEW")) { + assignedTo : Role.assignedTo @filter(eq(User.username, "user1")) + dgraph.uid : uid + } + dgraph.uid : uid + } + dgraph.uid : uid + } + dgraph.uid : uid + } + } + +- name: "Auth with deep filter and field filter : query deep requires auth" + gqlquery: | + query { + queryUser { + username + tickets(filter: { title: { anyofterms: "graphql" } }) { + id + title + } + } + } + dgquery: |- + query { + queryUser(func: type(User)) { + username : User.username + tickets : User.tickets @filter((anyofterms(Ticket.title, "graphql") AND uid(Ticket1))) { + id : uid + title : Ticket.title + dgraph.uid : uid + } + dgraph.uid : uid + } + Ticket1 as var(func: type(Ticket)) @cascade { + onColumn : Ticket.onColumn { + inProject : Column.inProject { + roles : Project.roles @filter(eq(Role.permission, "VIEW")) { + assignedTo : Role.assignedTo @filter(eq(User.username, "user1")) + dgraph.uid : uid + } + dgraph.uid : uid + } + dgraph.uid : uid + } + dgraph.uid : uid + } + } + +- name: "Auth with complex filter" + gqlquery: | + query { + queryMovie { + content + } + } + dgquery: |- + query { + queryMovie(func: uid(Movie1)) @filter((NOT (uid(Movie2)) AND (uid(Movie3) OR uid(Movie4)))) { + content : Movie.content + dgraph.uid : uid + } + Movie1 as var(func: type(Movie)) + Movie2 as var(func: uid(Movie1)) @filter(eq(Movie.hidden, true)) @cascade + Movie3 as var(func: uid(Movie1)) @cascade { + regionsAvailable : Movie.regionsAvailable { + users : Region.users @filter(eq(User.username, "user1")) + dgraph.uid : uid + } + dgraph.uid : uid + } + Movie4 as var(func: uid(Movie1)) @cascade { + regionsAvailable : Movie.regionsAvailable @filter(eq(Region.global, true)) + dgraph.uid : uid + } + } diff --git a/graphql/resolve/auth_test.go b/graphql/resolve/auth_test.go new file mode 100644 index 00000000000..f8f1a4b5216 --- /dev/null +++ b/graphql/resolve/auth_test.go @@ -0,0 +1,226 @@ +/* + * Copyright 2019 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package resolve + +import ( + "context" + "io/ioutil" + "testing" + "time" + + "github.com/dgraph-io/dgraph/graphql/dgraph" + "github.com/dgraph-io/dgraph/graphql/schema" + "github.com/dgraph-io/dgraph/graphql/test" + "github.com/dgrijalva/jwt-go" + "github.com/stretchr/testify/require" + _ "github.com/vektah/gqlparser/v2/validator/rules" // make gql validator init() all rules + "google.golang.org/grpc/metadata" + "gopkg.in/yaml.v2" +) + +type AuthQueryRewritingCase struct { + Name string + GQLQuery string + Variables map[string]interface{} + DGQuery string + User string + Role string +} + +// Tests showing that the query rewriter produces the expected Dgraph queries +// when it also needs to write in auth. +func TestAuthQueryRewriting(t *testing.T) { + b, err := ioutil.ReadFile("auth_query_test.yaml") + require.NoError(t, err, "Unable to read test file") + + var tests []AuthQueryRewritingCase + err = yaml.Unmarshal(b, &tests) + require.NoError(t, err, "Unable to unmarshal tests to yaml.") + + testRewriter := NewQueryRewriter() + + gqlSchema := test.LoadSchemaFromFile(t, "../e2e/auth/schema.graphql") + for _, tcase := range tests { + t.Run(tcase.Name, func(t *testing.T) { + + op, err := gqlSchema.Operation( + &schema.Request{ + Query: tcase.GQLQuery, + Variables: tcase.Variables, + }) + require.NoError(t, err) + gqlQuery := test.GetQuery(t, op) + + authVars := map[string]interface{}{ + "USER": "user1", + "ROLE": tcase.Role, + } + + ctx := addClaimsToContext(context.Background(), t, authVars) + + dgQuery, err := testRewriter.Rewrite(ctx, gqlQuery) + require.Nil(t, err) + require.Equal(t, tcase.DGQuery, dgraph.AsString(dgQuery)) + }) + } +} + +func addClaimsToContext( + ctx context.Context, + t *testing.T, + authVars map[string]interface{}) context.Context { + + claims := CustomClaims{ + authVars, + jwt.StandardClaims{ + ExpiresAt: time.Now().Add(time.Minute).Unix(), + Issuer: "test", + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + ss, err := token.SignedString([]byte(AuthHmacSecret)) + require.NoError(t, err) + + md := metadata.New(nil) + md.Append("authorizationJwt", ss) + return metadata.NewIncomingContext(ctx, md) +} + +// Tests that the queries that run after a mutation get auth correctly added in. +func TestAuthMutationQueryRewriting(t *testing.T) { + tests := map[string]struct { + gqlMut string + rewriter func() MutationRewriter + assigned map[string]string + result map[string]interface{} + dgQuery string + }{ + "Add Ticket": { + gqlMut: `mutation { + addTicket(input: [{title: "A ticket", onColumn: {colID: "0x1"}}]) { + ticket { + id + title + onColumn { + colID + name + } + } + } + }`, + rewriter: NewAddRewriter, + assigned: map[string]string{"Ticket1": "0x4"}, + dgQuery: `query { + ticket(func: uid(Ticket1)) @filter(uid(Ticket2)) { + id : uid + title : Ticket.title + onColumn : Ticket.onColumn { + colID : uid + name : Column.name + dgraph.uid : uid + } + dgraph.uid : uid + } + Ticket1 as var(func: uid(0x4)) + Ticket2 as var(func: uid(Ticket1)) @cascade { + onColumn : Ticket.onColumn { + inProject : Column.inProject { + roles : Project.roles @filter(eq(Role.permission, "VIEW")) { + assignedTo : Role.assignedTo @filter(eq(User.username, "user1")) + dgraph.uid : uid + } + dgraph.uid : uid + } + dgraph.uid : uid + } + dgraph.uid : uid + } +}`, + }, + "Update Ticket": { + gqlMut: `mutation { + updateTicket(input: {filter: {id: ["0x4"]}, set: {title: "Updated title"} }) { + ticket { + id + title + onColumn { + colID + name + } + } + } + }`, + rewriter: NewUpdateRewriter, + result: map[string]interface{}{ + "updateTicket": []interface{}{map[string]interface{}{"uid": "0x4"}}}, + dgQuery: `query { + ticket(func: uid(Ticket1)) @filter(uid(Ticket2)) { + id : uid + title : Ticket.title + onColumn : Ticket.onColumn { + colID : uid + name : Column.name + dgraph.uid : uid + } + dgraph.uid : uid + } + Ticket1 as var(func: uid(0x4)) + Ticket2 as var(func: uid(Ticket1)) @cascade { + onColumn : Ticket.onColumn { + inProject : Column.inProject { + roles : Project.roles @filter(eq(Role.permission, "VIEW")) { + assignedTo : Role.assignedTo @filter(eq(User.username, "user1")) + dgraph.uid : uid + } + dgraph.uid : uid + } + dgraph.uid : uid + } + dgraph.uid : uid + } +}`, + }, + } + + gqlSchema := test.LoadSchemaFromFile(t, "../e2e/auth/schema.graphql") + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + // -- Arrange -- + rewriter := tt.rewriter() + op, err := gqlSchema.Operation(&schema.Request{Query: tt.gqlMut}) + require.NoError(t, err) + gqlMutation := test.GetMutation(t, op) + authVars := map[string]interface{}{ + "USER": "user1", + } + ctx := addClaimsToContext(context.Background(), t, authVars) + _, err = rewriter.Rewrite(ctx, gqlMutation) + require.Nil(t, err) + + // -- Act -- + dgQuery, err := rewriter.FromMutationResult( + ctx, gqlMutation, tt.assigned, tt.result) + + // -- Assert -- + require.Nil(t, err) + require.Equal(t, tt.dgQuery, dgraph.AsString(dgQuery)) + }) + + } +} diff --git a/graphql/resolve/mutation_rewriter.go b/graphql/resolve/mutation_rewriter.go index 4ddac055382..36ff52564fd 100644 --- a/graphql/resolve/mutation_rewriter.go +++ b/graphql/resolve/mutation_rewriter.go @@ -346,7 +346,16 @@ func (mrw *AddRewriter) FromMutationResult( errs = schema.AsGQLErrors(errors.Errorf("no new node was created")) } - return rewriteAsQueryByIds(mutation.QueryField(), uids), errs + authVariables, err := ExtractAuthVariables(ctx) + if err != nil { + return nil, err + } + authRw := &authRewriter{ + authVariables: authVariables, + varGen: NewVariableGenerator(), + selector: queryAuthSelector, + } + return rewriteAsQueryByIds(mutation.QueryField(), uids, authRw), errs } // Rewrite rewrites set and remove update patches into GraphQL+- upsert mutations. @@ -479,7 +488,16 @@ func (urw *UpdateRewriter) FromMutationResult( } } - return rewriteAsQueryByIds(mutation.QueryField(), uids), nil + authVariables, err := ExtractAuthVariables(ctx) + if err != nil { + return nil, err + } + authRw := &authRewriter{ + authVariables: authVariables, + varGen: NewVariableGenerator(), + selector: queryAuthSelector, + } + return rewriteAsQueryByIds(mutation.QueryField(), uids, authRw), nil } func extractMutated(result map[string]interface{}, mutatedField string) []string { diff --git a/graphql/resolve/query_rewriter.go b/graphql/resolve/query_rewriter.go index 60e29c7e34a..dcf31b03917 100644 --- a/graphql/resolve/query_rewriter.go +++ b/graphql/resolve/query_rewriter.go @@ -30,15 +30,35 @@ import ( type queryRewriter struct{} +type authRewriter struct { + authVariables map[string]interface{} + isWritingAuth bool + selector func(t schema.Type) *schema.RuleNode + varGen *VariableGenerator + varName string +} + // NewQueryRewriter returns a new QueryRewriter. func NewQueryRewriter() QueryRewriter { return &queryRewriter{} } // Rewrite rewrites a GraphQL query into a Dgraph GraphQuery. -func (qr *queryRewriter) Rewrite(ctx context.Context, +func (qr *queryRewriter) Rewrite( + ctx context.Context, gqlQuery schema.Query) (*gql.GraphQuery, error) { + authVariables, err := ExtractAuthVariables(ctx) + if err != nil { + return nil, err + } + + authRw := &authRewriter{ + authVariables: authVariables, + varGen: NewVariableGenerator(), + selector: queryAuthSelector, + } + switch gqlQuery.QueryType() { case schema.GetQuery: @@ -56,27 +76,25 @@ func (qr *queryRewriter) Rewrite(ctx context.Context, return nil, err } - dgQuery := rewriteAsGet(gqlQuery, uid, xid) - addTypeFilter(dgQuery, gqlQuery.Type()) - + dgQuery := rewriteAsGet(gqlQuery, uid, xid, authRw) return dgQuery, nil case schema.FilterQuery: - return rewriteAsQuery(gqlQuery), nil + return rewriteAsQuery(gqlQuery, authRw), nil case schema.PasswordQuery: - return passwordQuery(gqlQuery) + return passwordQuery(gqlQuery, authRw) default: return nil, errors.Errorf("unimplemented query type %s", gqlQuery.QueryType()) } } -func passwordQuery(m schema.Query) (*gql.GraphQuery, error) { +func passwordQuery(m schema.Query, authRw *authRewriter) (*gql.GraphQuery, error) { xid, uid, err := m.IDArgValue() if err != nil { return nil, err } - dgQuery := rewriteAsGet(m, uid, xid) + dgQuery := rewriteAsGet(m, uid, xid, authRw) queriedType := m.Type() name := queriedType.PasswordField().Name() @@ -164,7 +182,7 @@ func addUID(dgQuery *gql.GraphQuery) { dgQuery.Children = append(dgQuery.Children, uidChild) } -func rewriteAsQueryByIds(field schema.Field, uids []uint64) *gql.GraphQuery { +func rewriteAsQueryByIds(field schema.Field, uids []uint64, authRw *authRewriter) *gql.GraphQuery { dgQuery := &gql.GraphQuery{ Attr: field.ResponseName(), Func: &gql.Function{ @@ -178,6 +196,14 @@ func rewriteAsQueryByIds(field schema.Field, uids []uint64) *gql.GraphQuery { } addArgumentsToField(dgQuery, field) + selectionAuth := addSelectionSetFrom(dgQuery, field, authRw) + addUID(dgQuery) + + dgQuery = authRw.addAuthQueries(field, dgQuery) + if len(selectionAuth) > 0 { + dgQuery = &gql.GraphQuery{Children: append([]*gql.GraphQuery{dgQuery}, selectionAuth...)} + } + return dgQuery } @@ -188,13 +214,30 @@ func addArgumentsToField(dgQuery *gql.GraphQuery, field schema.Field) { addFilter(dgQuery, field.Type(), filter) addOrder(dgQuery, field) addPagination(dgQuery, field) - addSelectionSetFrom(dgQuery, field) - addUID(dgQuery) } -func rewriteAsGet(field schema.Field, uid uint64, xid *string) *gql.GraphQuery { +func rewriteAsGet( + field schema.Field, + uid uint64, + xid *string, + auth *authRewriter) *gql.GraphQuery { + + var dgQuery *gql.GraphQuery + if xid == nil { - return rewriteAsQueryByIds(field, []uint64{uid}) + dgQuery = rewriteAsQueryByIds(field, []uint64{uid}, auth) + + // If the top level query is the named get, put the type filter there, otherwise + // auth has been written into the query, then there will be a blank top level + // and multiple children, of which the second is the actual get + if dgQuery.Attr != "" { + addTypeFilter(dgQuery, field.Type()) + } else { + addTypeFilter(dgQuery.Children[1], field.Type()) + } + + return dgQuery + } xidArgName := field.XIDArg() @@ -206,7 +249,6 @@ func rewriteAsGet(field schema.Field, uid uint64, xid *string) *gql.GraphQuery { }, } - var dgQuery *gql.GraphQuery if uid > 0 { dgQuery = &gql.GraphQuery{ Attr: field.ResponseName(), @@ -225,26 +267,201 @@ func rewriteAsGet(field schema.Field, uid uint64, xid *string) *gql.GraphQuery { Func: eqXidFunc, } } - addSelectionSetFrom(dgQuery, field) + selectionAuth := addSelectionSetFrom(dgQuery, field, auth) addUID(dgQuery) + addTypeFilter(dgQuery, field.Type()) + + dgQuery = auth.addAuthQueries(field, dgQuery) + if len(selectionAuth) > 0 { + dgQuery = &gql.GraphQuery{Children: append([]*gql.GraphQuery{dgQuery}, selectionAuth...)} + } + return dgQuery } -func rewriteAsQuery(field schema.Field) *gql.GraphQuery { +func rewriteAsQuery(field schema.Field, authRw *authRewriter) *gql.GraphQuery { dgQuery := &gql.GraphQuery{ Attr: field.ResponseName(), } - if ids := idFilter(field, field.Type().IDField()); ids != nil { + if authRw != nil && authRw.isWritingAuth && authRw.varName != "" { + // When rewriting auth rules, they always start like + // Todo2 as var(func: uid(Todo1)) @cascade { + // Where Todo1 is the variable generated from the filter of the field + // we are adding auth to. + // + // TODO: Currently this only applies at the top level. This means auth queries + // from the top level query/get are as efficient as the original query (because + // they start from the uid(Todo1) of the user query) ... however auth queries + // on deeper fields will start like `func: type(Todo)`, that's ok for building + // the feature and getting all the testing in place, but we should improve this so + // that the internal auth queries start from exactly the possible nodes that the + // internal field is considering. + authRw.addVariableUIDFunc(dgQuery) + } else if ids := idFilter(field, field.Type().IDField()); ids != nil { addUIDFunc(dgQuery, ids) } else { addTypeFunc(dgQuery, field.Type().DgraphName()) } addArgumentsToField(dgQuery, field) + selectionAuth := addSelectionSetFrom(dgQuery, field, authRw) + addUID(dgQuery) + + dgQuery = authRw.addAuthQueries(field, dgQuery) + + if len(selectionAuth) > 0 { + dgQuery = &gql.GraphQuery{Children: append([]*gql.GraphQuery{dgQuery}, selectionAuth...)} + } + return dgQuery } +// addAuthQueries takes a field and the GraphQuery that has so far been constructed for +// the field and builds any auth queries that are need to restrict the result to only +// the nodes authorized to be queried, returning a new graphQuery that does the +// original query and the auth. +func (authRw *authRewriter) addAuthQueries( + field schema.Field, + dgQuery *gql.GraphQuery) *gql.GraphQuery { + + // There's no need to recursively inject auth queries into other auth queries, so if + // we are already generating an auth query, there's nothing to add. + if authRw == nil || authRw.isWritingAuth { + return dgQuery + } + + authRw.varName = authRw.varGen.Next(field.Type(), "", "") + + fldAuthQueries, filter := authRw.rewriteAuthQueries(field) + if len(fldAuthQueries) == 0 { + return dgQuery + } + + // build a query like + // Todo1 as var(func: ... ) @filter(...) + // that has the filter from the user query in it. This is then used as + // the starting point for both the user query and the auth query. + // + // We already have the query, so just copy it and modify the original + varQry := &gql.GraphQuery{ + Var: authRw.varName, + Attr: "var", + Func: dgQuery.Func, + Filter: dgQuery.Filter, + } + + // The user query starts from the var query generated above and is filtered + // by the the filter generated from auth processing, so now we build + // queryTodo(func: uid(Todo1)) @filter(...auth-queries...) { ... } + dgQuery.Func = &gql.Function{ + Name: "uid", + Args: []gql.Arg{{Value: authRw.varName}}, + } + dgQuery.Filter = filter + + // The final query that includes the user's filter and auth processsing is thus like + // + // queryTodo(func: uid(Todo1)) @filter(uid(Todo2) AND uid(Todo3)) { ... } + // Todo1 as var(func: ... ) @filter(...) + // Todo2 as var(func: uid(Todo1)) @cascade { ...auth query 1... } + // Todo3 as var(func: uid(Todo1)) @cascade { ...auth query 2... } + return &gql.GraphQuery{Children: append([]*gql.GraphQuery{dgQuery, varQry}, fldAuthQueries...)} +} + +func (authRw *authRewriter) addVariableUIDFunc(q *gql.GraphQuery) { + q.Func = &gql.Function{ + Name: "uid", + Args: []gql.Arg{{Value: authRw.varName}}, + } +} + +func queryAuthSelector(t schema.Type) *schema.RuleNode { + auth := t.AuthRules() + if auth == nil || auth.Rules == nil { + return nil + } + + return auth.Rules.Query +} + +func (authRw *authRewriter) rewriteAuthQueries(f schema.Field) ([]*gql.GraphQuery, *gql.FilterTree) { + if authRw == nil || authRw.isWritingAuth { + return nil, nil + } + + return (&authRewriter{ + authVariables: authRw.authVariables, + varGen: authRw.varGen, + isWritingAuth: true, + varName: authRw.varName, + selector: authRw.selector, + }).rewriteRuleNode(f, authRw.selector(f.Type())) +} + +func (authRw *authRewriter) rewriteRuleNode( + field schema.Field, + rn *schema.RuleNode) ([]*gql.GraphQuery, *gql.FilterTree) { + + if field == nil || rn == nil { + return nil, nil + } + + nodeList := func( + field schema.Field, + rns []*schema.RuleNode) ([]*gql.GraphQuery, []*gql.FilterTree) { + + var qrys []*gql.GraphQuery + var filts []*gql.FilterTree + for _, orRn := range rns { + q, f := authRw.rewriteRuleNode(field, orRn) + qrys = append(qrys, q...) + filts = append(filts, f) + } + return qrys, filts + } + + switch { + case len(rn.And) > 0: + qrys, filts := nodeList(field, rn.And) + return qrys, &gql.FilterTree{ + Op: "and", + Child: filts, + } + case len(rn.Or) > 0: + qrys, filts := nodeList(field, rn.Or) + return qrys, &gql.FilterTree{ + Op: "or", + Child: filts, + } + case rn.Not != nil: + qrys, filter := authRw.rewriteRuleNode(field, rn.Not) + return qrys, &gql.FilterTree{ + Op: "not", + Child: []*gql.FilterTree{filter}, + } + case rn.Rule != nil: + // create a copy of the auth query that's specialized for the values from the JWT + qry := rn.Rule.AuthFor(field, authRw.authVariables) + + // build + // Todo2 as var(func: uid(Todo1)) @cascade { ...auth query 1... } + varName := authRw.varGen.Next(field.Type(), "", "") + r1 := rewriteAsQuery(qry, authRw) + r1.Var = varName + r1.Attr = "var" + r1.Cascade = true + + return []*gql.GraphQuery{r1}, &gql.FilterTree{ + Func: &gql.Function{ + Name: "uid", + Args: []gql.Arg{{Value: varName}}, + }, + } + } + return nil, nil +} + func addTypeFilter(q *gql.GraphQuery, typ schema.Type) { thisFilter := &gql.FilterTree{ Func: &gql.Function{ @@ -278,7 +495,15 @@ func addTypeFunc(q *gql.GraphQuery, typ string) { } -func addSelectionSetFrom(q *gql.GraphQuery, field schema.Field) { +// addSelectionSetFrom adds all the selections from field into q, and returns a list +// of extra queries needed to satisfy auth requirements +func addSelectionSetFrom( + q *gql.GraphQuery, + field schema.Field, + auth *authRewriter) []*gql.GraphQuery { + + var authQueries []*gql.GraphQuery + // Only add dgraph.type as a child if this field is an interface type and has some children. // dgraph.type would later be used in completeObject as different objects in the resulting // JSON would return different fields based on their concrete type. @@ -314,10 +539,24 @@ func addSelectionSetFrom(q *gql.GraphQuery, field schema.Field) { addOrder(child, f) addPagination(child, f) - addSelectionSetFrom(child, f) - + selectionAuth := addSelectionSetFrom(child, f, auth) q.Children = append(q.Children, child) + + fieldAuth, authFilter := auth.rewriteAuthQueries(f) + authQueries = append(authQueries, selectionAuth...) + authQueries = append(authQueries, fieldAuth...) + if len(fieldAuth) > 0 { + if child.Filter == nil { + child.Filter = authFilter + } else { + child.Filter = &gql.FilterTree{ + Op: "and", + Child: []*gql.FilterTree{child.Filter, authFilter}, + } + } + } } + return authQueries } func addOrder(q *gql.GraphQuery, field schema.Field) { diff --git a/graphql/resolve/query_test.yaml b/graphql/resolve/query_test.yaml index e96b5ec77e1..e5f942b26c2 100644 --- a/graphql/resolve/query_test.yaml +++ b/graphql/resolve/query_test.yaml @@ -1154,17 +1154,17 @@ name: "Password query" gqlquery: | query { - checkUserPassword(name: "0x1", pwd: "Password") { + checkUserPassword(name: "user1", pwd: "Password") { name } } dgquery: |- query { - checkUserPassword(func: eq(User.name, "0x1")) @filter((eq(val(pwd), 1))) { + checkUserPassword(func: eq(User.name, "user1")) @filter((eq(val(pwd), 1) AND type(User))) { name : User.name dgraph.uid : uid } - checkPwd(func: eq(User.name, "0x1")) { + checkPwd(func: eq(User.name, "user1")) @filter(type(User)) { pwd as checkpwd(User.pwd, "Password") } } diff --git a/graphql/schema/auth.go b/graphql/schema/auth.go index 578a2e62928..14dd5aa06b9 100644 --- a/graphql/schema/auth.go +++ b/graphql/schema/auth.go @@ -51,8 +51,8 @@ type AuthContainer struct { } type TypeAuth struct { - rules *AuthContainer - fields map[string]*AuthContainer + Rules *AuthContainer + Fields map[string]*AuthContainer } func authRules(s *ast.Schema) (map[string]*TypeAuth, error) { @@ -62,17 +62,17 @@ func authRules(s *ast.Schema) (map[string]*TypeAuth, error) { for _, typ := range s.Types { name := typeName(typ) - authRules[name] = &TypeAuth{fields: make(map[string]*AuthContainer)} + authRules[name] = &TypeAuth{Fields: make(map[string]*AuthContainer)} auth := typ.Directives.ForName(authDirective) if auth != nil { - authRules[name].rules, err = parseAuthDirective(s, typ, auth) + authRules[name].Rules, err = parseAuthDirective(s, typ, auth) errResult = AppendGQLErrs(errResult, err) } for _, field := range typ.Fields { auth := field.Directives.ForName(authDirective) if auth != nil { - authRules[name].fields[field.Name], err = parseAuthDirective(s, typ, auth) + authRules[name].Fields[field.Name], err = parseAuthDirective(s, typ, auth) errResult = AppendGQLErrs(errResult, err) } } diff --git a/graphql/schema/testdata/schemagen/output/authorization.graphql b/graphql/schema/testdata/schemagen/output/authorization.graphql index c5f070fbf15..f2c6e757f9f 100644 --- a/graphql/schema/testdata/schemagen/output/authorization.graphql +++ b/graphql/schema/testdata/schemagen/output/authorization.graphql @@ -269,3 +269,14 @@ type Mutation { updateUser(input: UpdateUserInput!): UpdateUserPayload deleteUser(filter: UserFilter!): DeleteUserPayload } + +####################### +# Generated Subscriptions +####################### + +type Subscription { + getTodo(id: ID!): Todo + queryTodo(filter: TodoFilter, order: TodoOrder, first: Int, offset: Int): [Todo] + getUser(username: String!): User + queryUser(filter: UserFilter, order: UserOrder, first: Int, offset: Int): [User] +} diff --git a/graphql/schema/wrappers.go b/graphql/schema/wrappers.go index fe1e9ed8b98..2b160dbe0a6 100644 --- a/graphql/schema/wrappers.go +++ b/graphql/schema/wrappers.go @@ -97,6 +97,7 @@ type Field interface { IncludeInterfaceField(types []interface{}) bool TypeName(dgraphTypes []interface{}) string GetObjectName() string + IsAuthQuery() bool } // A Mutation is a field (from the schema's Mutation type) from an Operation @@ -113,6 +114,7 @@ type Query interface { Field QueryType() QueryType Rename(newName string) + AuthFor(f Field, jwtVars map[string]interface{}) Query } // A Type is a GraphQL type like: Float, T, T! and [T!]!. If it's not a list, then @@ -132,6 +134,7 @@ type Type interface { Interfaces() []string EnsureNonNulls(map[string]interface{}, string) error FieldOriginatedFrom(fieldName string) string + AuthRules() *TypeAuth fmt.Stringer } @@ -149,7 +152,7 @@ type FieldDefinition interface { type astType struct { typ *ast.Type - inSchema *ast.Schema + inSchema *schema dgraphPredicate map[string]map[string]string } @@ -190,7 +193,7 @@ type field struct { type fieldDefinition struct { fieldDef *ast.FieldDefinition - inSchema *ast.Schema + inSchema *schema dgraphPredicate map[string]map[string]string } @@ -427,14 +430,14 @@ func dgraphMapping(sch *ast.Schema) map[string]map[string]string { return dgraphPredicate } -func mutatedTypeMapping(s *ast.Schema, +func mutatedTypeMapping(s *schema, dgraphPredicate map[string]map[string]string) map[string]*astType { - if s.Mutation == nil { + if s.schema.Mutation == nil { return nil } - m := make(map[string]*astType, len(s.Mutation.Fields)) - for _, field := range s.Mutation.Fields { + m := make(map[string]*astType, len(s.schema.Mutation.Fields)) + for _, field := range s.schema.Mutation.Fields { mutatedTypeName := "" switch { case strings.HasPrefix(field.Name, "add"): @@ -450,8 +453,8 @@ func mutatedTypeMapping(s *ast.Schema, // the type from the definition of an object. We use Update and not Add here because // Interfaces only have Update. var def *ast.Definition - if def = s.Types["Update"+mutatedTypeName+"Payload"]; def == nil { - def = s.Types["Add"+mutatedTypeName+"Payload"] + if def = s.schema.Types["Update"+mutatedTypeName+"Payload"]; def == nil { + def = s.schema.Types["Add"+mutatedTypeName+"Payload"] } if def == nil { @@ -490,13 +493,16 @@ func AsSchema(s *ast.Schema) (Schema, error) { } dgraphPredicate := dgraphMapping(s) - return &schema{ + + sch := &schema{ schema: s, dgraphPredicate: dgraphPredicate, - mutatedType: mutatedTypeMapping(s, dgraphPredicate), typeNameAst: typeMappings(s), authRules: authRules, - }, nil + } + sch.mutatedType = mutatedTypeMapping(sch, dgraphPredicate) + + return sch, nil } func responseName(f *ast.Field) string { @@ -532,6 +538,10 @@ func (f *field) SetArgTo(arg string, val interface{}) { } } +func (f *field) IsAuthQuery() bool { + return f.field.Arguments.ForName("dgraph.uid") != nil +} + func (f *field) ArgValue(name string) interface{} { if f.arguments == nil { // Compute and cache the map first time this function is called for a field. @@ -631,7 +641,7 @@ func (f *field) Type() Type { } return &astType{ typ: t, - inSchema: f.op.inSchema.schema, + inSchema: f.op.inSchema, dgraphPredicate: f.op.inSchema.dgraphPredicate, } } @@ -715,6 +725,30 @@ func (f *field) IncludeInterfaceField(dgraphTypes []interface{}) bool { return false } +func (q *query) IsAuthQuery() bool { + return (*field)(q).field.Arguments.ForName("dgraph.uid") != nil +} + +func (q *query) AuthFor(f Field, jwtVars map[string]interface{}) Query { + // copy the template, so that multiple queries can run rewriting for the rule. + var sch *schema + if fld, ok := f.(*field); ok { + sch = fld.op.inSchema + } else { + sch = f.(*query).op.inSchema + } + + return &query{ + field: (*field)(q).field, + op: &operation{op: q.op.op, + query: q.op.query, + doc: q.op.doc, + inSchema: sch, + vars: jwtVars, + }, + sel: q.sel} +} + func (q *query) Rename(newName string) { q.field.Name = newName } @@ -936,10 +970,18 @@ func (m *mutation) IncludeInterfaceField(dgraphTypes []interface{}) bool { return (*field)(m).IncludeInterfaceField(dgraphTypes) } +func (m *mutation) IsAuthQuery() bool { + return (*field)(m).field.Arguments.ForName("dgraph.uid") != nil +} + +func (t *astType) AuthRules() *TypeAuth { + return t.inSchema.authRules[t.Name()] +} + func (t *astType) Field(name string) FieldDefinition { return &fieldDefinition{ // this ForName lookup is a loop in the underlying schema :-( - fieldDef: t.inSchema.Types[t.Name()].Fields.ForName(name), + fieldDef: t.inSchema.schema.Types[t.Name()].Fields.ForName(name), inSchema: t.inSchema, dgraphPredicate: t.dgraphPredicate, } @@ -948,7 +990,7 @@ func (t *astType) Field(name string) FieldDefinition { func (t *astType) Fields() []FieldDefinition { var result []FieldDefinition - for _, fld := range t.inSchema.Types[t.Name()].Fields { + for _, fld := range t.inSchema.schema.Types[t.Name()].Fields { result = append(result, &fieldDefinition{ fieldDef: fld, @@ -998,7 +1040,7 @@ func (fd *fieldDefinition) Inverse() FieldDefinition { } // typ must exist if the schema passed GQL validation - typ := fd.inSchema.Types[fd.Type().Name()] + typ := fd.inSchema.schema.Types[fd.Type().Name()] // fld must exist if the schema passed our validation fld := typ.Fields.ForName(invFieldArg.Value.Raw) @@ -1030,7 +1072,7 @@ func (fd *fieldDefinition) ForwardEdge() FieldDefinition { fedge := strings.Trim(name, "<~>") // typ must exist if the schema passed GQL validation - typ := fd.inSchema.Types[fd.Type().Name()] + typ := fd.inSchema.schema.Types[fd.Type().Name()] var fld *ast.FieldDefinition // Have to range through all the fields and find the correct forward edge. This would be @@ -1064,7 +1106,7 @@ func (t *astType) Name() string { } func (t *astType) DgraphName() string { - typeDef := t.inSchema.Types[t.typ.Name()] + typeDef := t.inSchema.schema.Types[t.typ.Name()] name := typeName(typeDef) if name != "" { return name @@ -1118,7 +1160,7 @@ func (t *astType) String() string { } func (t *astType) IDField() FieldDefinition { - def := t.inSchema.Types[t.Name()] + def := t.inSchema.schema.Types[t.Name()] if def.Kind != ast.Object && def.Kind != ast.Interface { return nil } @@ -1136,7 +1178,7 @@ func (t *astType) IDField() FieldDefinition { } func (t *astType) PasswordField() FieldDefinition { - def := t.inSchema.Types[t.Name()] + def := t.inSchema.schema.Types[t.Name()] if def.Kind != ast.Object && def.Kind != ast.Interface { return nil } @@ -1153,7 +1195,7 @@ func (t *astType) PasswordField() FieldDefinition { } func (t *astType) XIDField() FieldDefinition { - def := t.inSchema.Types[t.Name()] + def := t.inSchema.schema.Types[t.Name()] if def.Kind != ast.Object && def.Kind != ast.Interface { return nil } @@ -1171,7 +1213,7 @@ func (t *astType) XIDField() FieldDefinition { } func (t *astType) Interfaces() []string { - interfaces := t.inSchema.Types[t.typ.Name()].Interfaces + interfaces := t.inSchema.schema.Types[t.typ.Name()].Interfaces if len(interfaces) == 0 { return nil } @@ -1180,7 +1222,7 @@ func (t *astType) Interfaces() []string { // overwritten using @dgraph(type: ...) names := make([]string, 0, len(interfaces)) for _, intr := range interfaces { - i := t.inSchema.Types[intr] + i := t.inSchema.schema.Types[intr] name := intr if n := typeName(i); n != "" { name = n @@ -1226,7 +1268,7 @@ func (t *astType) Interfaces() []string { // and then check ourselves that either there's an ID, or there's all the bits to // satisfy a valid post. func (t *astType) EnsureNonNulls(obj map[string]interface{}, exclusion string) error { - for _, fld := range t.inSchema.Types[t.Name()].Fields { + for _, fld := range t.inSchema.schema.Types[t.Name()].Fields { if fld.Type.NonNull && !isID(fld) && !(fld.Name == exclusion) { if val, ok := obj[fld.Name]; !ok || val == nil { return errors.Errorf( @@ -1242,13 +1284,13 @@ func (t *astType) EnsureNonNulls(obj map[string]interface{}, exclusion string) e // If the field wasn't inherited, but belonged to this type, this type's name is returned. // Otherwise, empty string is returned. func (t *astType) FieldOriginatedFrom(fieldName string) string { - for _, implements := range t.inSchema.Implements[t.Name()] { + for _, implements := range t.inSchema.schema.Implements[t.Name()] { if implements.Fields.ForName(fieldName) != nil { return implements.Name } } - if t.inSchema.Types[t.Name()].Fields.ForName(fieldName) != nil { + if t.inSchema.schema.Types[t.Name()].Fields.ForName(fieldName) != nil { return t.Name() } diff --git a/graphql/schema/wrappers_test.go b/graphql/schema/wrappers_test.go index c63dab93edf..d9306f3935e 100644 --- a/graphql/schema/wrappers_test.go +++ b/graphql/schema/wrappers_test.go @@ -309,7 +309,7 @@ func TestCheckNonNulls(t *testing.T) { typ := &astType{ typ: &ast.Type{NamedType: "T"}, - inSchema: (gqlSchema.(*schema)).schema, + inSchema: (gqlSchema.(*schema)), } for name, test := range tcases { diff --git a/graphql/test/test.go b/graphql/test/test.go index 4609ca23b85..e26dab728a3 100644 --- a/graphql/test/test.go +++ b/graphql/test/test.go @@ -35,16 +35,13 @@ import ( func LoadSchema(t *testing.T, gqlSchema string) schema.Schema { doc, gqlErr := parser.ParseSchemas(validator.Prelude, &ast.Source{Input: gqlSchema}) - require.Nil(t, gqlErr) - // ^^ We can't use NoError here because gqlErr is of type *gqlerror.Error, - // so passing into something that just expects an error, will always be a - // non-nil interface. + requireNoGQLErrors(t, gqlErr) gql, gqlErr := validator.ValidateSchemaDocument(doc) - require.Nil(t, gqlErr) + requireNoGQLErrors(t, gqlErr) schema, err := schema.AsSchema(gql) - require.Nil(t, err) + requireNoGQLErrors(t, err) return schema } @@ -60,7 +57,7 @@ func LoadSchemaFromFile(t *testing.T, gqlFile string) schema.Schema { func LoadSchemaFromString(t *testing.T, sch string) schema.Schema { handler, err := schema.NewHandler(string(sch)) - require.NoError(t, err, "input schema contained errors") + requireNoGQLErrors(t, err) return LoadSchema(t, handler.GQLSchema()) } @@ -111,3 +108,16 @@ func RequireJSONEqStr(t *testing.T, expected string, got interface{}) { require.JSONEq(t, expected, string(jsonGot)) } + +func requireNoGQLErrors(t *testing.T, err error) { + require.Nil(t, err, + "required no GraphQL errors, but received :\n%s", serializeOrError(err)) +} + +func serializeOrError(toSerialize interface{}) string { + byts, err := json.Marshal(toSerialize) + if err != nil { + return "unable to serialize because " + err.Error() + } + return string(byts) +} diff --git a/graphql/web/http.go b/graphql/web/http.go index 6d7997624ff..b1214c44095 100644 --- a/graphql/web/http.go +++ b/graphql/web/http.go @@ -152,6 +152,7 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { x.Panic(errors.New("graphqlHandler not initialised")) } + ctx = resolve.AttachAuthorizationJwt(ctx, r) ctx = x.AttachAccessJwt(ctx, r) if ip, port, err := net.SplitHostPort(r.RemoteAddr); err == nil {