Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(GraphQL): Add GraphQL schema validation Endpoint. #6250

Merged
merged 11 commits into from
Sep 1, 2020
22 changes: 22 additions & 0 deletions dgraph/cmd/alpha/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io/ioutil"
"log"
Expand Down Expand Up @@ -508,6 +509,27 @@ func setupServer(closer *y.Closer) {
adminSchemaHandler(w, r, adminServer)
})))

http.Handle("/admin/schema/validate", http.HandlerFunc(func(w http.ResponseWriter,
r *http.Request) {
schema := readRequest(w, r)
w.Header().Set("Content-Type", "application/json")

err := admin.SchemaValidate(string(schema))
if err == nil {
w.WriteHeader(http.StatusOK)
x.Check2(w.Write([]byte(`{"valid":true}`)))
return
}

w.WriteHeader(http.StatusBadRequest)
errs := strings.Split(strings.TrimSpace(err.Error()), "\n")
errJson, err := json.Marshal(errs)
if err != nil {
errJson = []byte(err.Error())
}
x.Check2(w.Write([]byte(fmt.Sprintf(`{"valid":false, "error" : %s}`, errJson))))
}))

http.Handle("/admin/shutdown", allowedMethodsHandler(allowedMethods{http.MethodGet: true},
adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
shutDownHandler(w, r, adminServer)
Expand Down
14 changes: 13 additions & 1 deletion graphql/admin/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,18 @@ var (
mainHealthStore = &GraphQLHealthStore{}
)

func SchemaValidate(sch string) error {
schHandler, err := schema.NewHandler(sch, true)
if err != nil {
return err
}

if _, err := schema.FromString(schHandler.GQLSchema()); err != nil {
return err
}
return nil
}

// GraphQLHealth is used to report the health status of a GraphQL server.
// It is required for kubernetes probing.
type GraphQLHealth struct {
Expand Down Expand Up @@ -602,7 +614,7 @@ func getCurrentGraphQLSchema() (*gqlSchema, error) {
}

func generateGQLSchema(sch *gqlSchema) (schema.Schema, error) {
schHandler, err := schema.NewHandler(sch.Schema)
schHandler, err := schema.NewHandler(sch.Schema, false)
if err != nil {
return nil, err
}
Expand Down
4 changes: 3 additions & 1 deletion graphql/admin/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ func resolveUpdateGQLSchema(ctx context.Context, m schema.Mutation) (*resolve.Re
return resolve.EmptyResult(m, err), false
}

schHandler, err := schema.NewHandler(input.Set.Schema)
// We just need to validate the schema. Schema is later set in `resetSchema()` when the schema
// is returned from badger.
schHandler, err := schema.NewHandler(input.Set.Schema, true)
if err != nil {
return resolve.EmptyResult(m, err), false
}
Expand Down
82 changes: 52 additions & 30 deletions graphql/authorization/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"net/http"
"regexp"
"strings"
"sync"

"github.com/vektah/gqlparser/v2/gqlerror"

Expand All @@ -46,7 +47,7 @@ const (
)

var (
metainfo AuthMeta
authMeta = &AuthMeta{}
)

type AuthMeta struct {
Expand All @@ -56,6 +57,7 @@ type AuthMeta struct {
Namespace string
Algo string
Audience []string
sync.RWMutex
}

// Validate required fields.
Expand Down Expand Up @@ -83,17 +85,17 @@ func (a *AuthMeta) validate() error {
return nil
}

func Parse(schema string) (AuthMeta, error) {
func Parse(schema string) (*AuthMeta, error) {
var meta AuthMeta
authInfoIdx := strings.LastIndex(schema, AuthMetaHeader)
if authInfoIdx == -1 {
return meta, nil
return nil, nil
}
authInfo := schema[authInfoIdx:]

err := json.Unmarshal([]byte(authInfo[len(AuthMetaHeader):]), &meta)
if err == nil {
return meta, meta.validate()
return &meta, meta.validate()
}

fmt.Println("Falling back to parsing `Dgraph.Authorization` in old format." +
Expand All @@ -112,60 +114,77 @@ func Parse(schema string) (AuthMeta, error) {
authMetaRegex, err :=
regexp.Compile(`^#[\s]([^\s]+)[\s]+([^\s]+)[\s]+([^\s]+)[\s]+([^\s]+)[\s]+"([^\"]+)"`)
if err != nil {
return meta, gqlerror.Errorf("JWT parsing failed: %v", err)
return nil, gqlerror.Errorf("JWT parsing failed: %v", err)
}

idx := authMetaRegex.FindAllStringSubmatchIndex(authInfo, -1)
if len(idx) != 1 || len(idx[0]) != 12 ||
!strings.HasPrefix(authInfo, authInfo[idx[0][0]:idx[0][1]]) {
return meta, gqlerror.Errorf("Invalid `Dgraph.Authorization` format: %s", authInfo)
return nil, gqlerror.Errorf("Invalid `Dgraph.Authorization` format: %s", authInfo)
}

meta.Header = authInfo[idx[0][4]:idx[0][5]]
meta.Namespace = authInfo[idx[0][6]:idx[0][7]]
meta.Algo = authInfo[idx[0][8]:idx[0][9]]
meta.VerificationKey = authInfo[idx[0][10]:idx[0][11]]
if meta.Algo == HMAC256 {
return meta, nil
return &meta, nil
}
if meta.Algo != RSA256 {
return meta, errors.Errorf(
return nil, errors.Errorf(
"invalid jwt algorithm: found %s, but supported options are HS256 or RS256", meta.Algo)
}
return meta, nil
return &meta, nil
}

func ParseAuthMeta(schema string) error {
var err error
metainfo, err = Parse(schema)
func ParseAuthMeta(schema string) (*AuthMeta, error) {
metaInfo, err := Parse(schema)
if err != nil {
return err
return nil, err
}

if metainfo.Algo != RSA256 {
return err
if metaInfo.Algo != RSA256 {
return metaInfo, nil
}

// The jwt library internally uses `bytes.IndexByte(data, '\n')` to fetch new line and fails
// if we have newline "\n" as ASCII value {92,110} instead of the actual ASCII value of 10.
// To fix this we replace "\n" with new line's ASCII value.
bytekey := bytes.ReplaceAll([]byte(metainfo.VerificationKey), []byte{92, 110}, []byte{10})
bytekey := bytes.ReplaceAll([]byte(metaInfo.VerificationKey), []byte{92, 110}, []byte{10})

metainfo.RSAPublicKey, err = jwt.ParseRSAPublicKeyFromPEM(bytekey)
return err
if metaInfo.RSAPublicKey, err = jwt.ParseRSAPublicKeyFromPEM(bytekey); err != nil {
return nil, err
}
return metaInfo, nil
}

func GetHeader() string {
return metainfo.Header
authMeta.RLock()
defer authMeta.RUnlock()
return authMeta.Header
}

func GetAuthMeta() AuthMeta {
return metainfo
func GetAuthMeta() *AuthMeta {
authMeta.RLock()
defer authMeta.RUnlock()
return authMeta
}

func SetAuthMeta(m *AuthMeta) {
authMeta.Lock()
defer authMeta.Unlock()

authMeta.VerificationKey = m.VerificationKey
authMeta.RSAPublicKey = m.RSAPublicKey
authMeta.Header = m.Header
authMeta.Namespace = m.Namespace
authMeta.Algo = m.Algo
authMeta.Audience = m.Audience
}

// AttachAuthorizationJwt adds any incoming JWT authorization data into the grpc context metadata.
func AttachAuthorizationJwt(ctx context.Context, r *http.Request) context.Context {
authorizationJwt := r.Header.Get(metainfo.Header)
authorizationJwt := r.Header.Get(authMeta.Header)
if authorizationJwt == "" {
return ctx
}
Expand Down Expand Up @@ -197,7 +216,7 @@ func (c *CustomClaims) UnmarshalJSON(data []byte) error {
}

// Unmarshal the auth variables for a particular namespace.
if authValue, ok := result[metainfo.Namespace]; ok {
if authValue, ok := result[authMeta.Namespace]; ok {
if authJson, ok := authValue.(string); ok {
if err := json.Unmarshal([]byte(authJson), &c.AuthVariables); err != nil {
return err
Expand All @@ -216,13 +235,13 @@ func (c *CustomClaims) validateAudience() error {
}

// If there is an audience claim, but no value provided, fail
if metainfo.Audience == nil {
if authMeta.Audience == nil {
return fmt.Errorf("audience value was expected but not provided")
}

var match = false
for _, audStr := range c.Audience {
for _, expectedAudStr := range metainfo.Audience {
for _, expectedAudStr := range authMeta.Audience {
if subtle.ConstantTimeCompare([]byte(audStr), []byte(expectedAudStr)) == 1 {
match = true
break
Expand Down Expand Up @@ -252,7 +271,10 @@ func ExtractCustomClaims(ctx context.Context) (*CustomClaims, error) {
}

func validateJWTCustomClaims(jwtStr string) (*CustomClaims, error) {
if metainfo.Algo == "" {
authMeta.RLock()
defer authMeta.RUnlock()

if authMeta.Algo == "" {
return nil, fmt.Errorf(
"jwt token cannot be validated because verification algorithm is not set")
}
Expand All @@ -263,17 +285,17 @@ func validateJWTCustomClaims(jwtStr string) (*CustomClaims, error) {
token, err :=
jwt.ParseWithClaims(jwtStr, &CustomClaims{}, func(token *jwt.Token) (interface{}, error) {
algo, _ := token.Header["alg"].(string)
if algo != metainfo.Algo {
if algo != authMeta.Algo {
return nil, errors.Errorf("unexpected signing method: Expected %s Found %s",
metainfo.Algo, algo)
authMeta.Algo, algo)
}
if algo == HMAC256 {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); ok {
return []byte(metainfo.VerificationKey), nil
return []byte(authMeta.VerificationKey), nil
}
} else if algo == RSA256 {
if _, ok := token.Method.(*jwt.SigningMethodRSA); ok {
return metainfo.RSAPublicKey, nil
return authMeta.RSAPublicKey, nil
}
}
return nil, errors.Errorf("couldn't parse signing method from token header: %s", algo)
Expand Down
1 change: 1 addition & 0 deletions graphql/bench/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func getAuthMeta(schema string) *testutil.AuthMeta {
if err != nil {
panic(err)
}
authorization.SetAuthMeta(authMeta)

return &testutil.AuthMeta{
PublicKey: authMeta.VerificationKey,
Expand Down
6 changes: 1 addition & 5 deletions graphql/e2e/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1256,11 +1256,7 @@ func TestMain(m *testing.M) {
panic(err)
}

authMeta, err := authorization.Parse(string(authSchema))
if err != nil {
panic(err)
}

authMeta := testutil.SetAuthMeta(string(authSchema))
metaInfo = &testutil.AuthMeta{
PublicKey: authMeta.VerificationKey,
Namespace: authMeta.Namespace,
Expand Down
3 changes: 2 additions & 1 deletion graphql/e2e/common/fragment.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ func fragmentInMutation(t *testing.T) {
gqlResponse := addStarshipParams.ExecuteAsPost(t, graphqlURL)
RequireNoGQLErrors(t, gqlResponse)

addStarshipExpected := `{"addStarship":{
addStarshipExpected := `
{"addStarship":{
"starship":[{
"name":"Millennium Falcon",
"length":2
Expand Down
Loading