Skip to content

Commit

Permalink
feat(GraphQL): Add GraphQL schema validation Endpoint. (#6250)
Browse files Browse the repository at this point in the history
* Add GraphQL schema validation Endpoint.
  • Loading branch information
Arijit Das authored Sep 1, 2020
1 parent 0f3bfb1 commit df1c7c9
Show file tree
Hide file tree
Showing 15 changed files with 258 additions and 80 deletions.
17 changes: 17 additions & 0 deletions dgraph/cmd/alpha/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,23 @@ func setupServer(closer *y.Closer) {
adminSchemaHandler(w, r, adminServer)
})))

http.Handle("/admin/schema/validate", http.HandlerFunc(func(w http.ResponseWriter,
r *http.Request) {
schema := readRequest(w, r)
w.Header().Set("Content-Type", "application/json")

err := admin.SchemaValidate(string(schema))
if err == nil {
w.WriteHeader(http.StatusOK)
x.SetStatus(w, "success", "Schema is valid")
return
}

w.WriteHeader(http.StatusBadRequest)
errs := strings.Split(strings.TrimSpace(err.Error()), "\n")
x.SetStatusWithErrors(w, x.ErrorInvalidRequest, errs)
}))

http.Handle("/admin/shutdown", allowedMethodsHandler(allowedMethods{http.MethodGet: true},
adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
shutDownHandler(w, r, adminServer)
Expand Down
12 changes: 11 additions & 1 deletion graphql/admin/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,16 @@ var (
mainHealthStore = &GraphQLHealthStore{}
)

func SchemaValidate(sch string) error {
schHandler, err := schema.NewHandler(sch, true)
if err != nil {
return err
}

_, err = schema.FromString(schHandler.GQLSchema())
return err
}

// GraphQLHealth is used to report the health status of a GraphQL server.
// It is required for kubernetes probing.
type GraphQLHealth struct {
Expand Down Expand Up @@ -602,7 +612,7 @@ func getCurrentGraphQLSchema() (*gqlSchema, error) {
}

func generateGQLSchema(sch *gqlSchema) (schema.Schema, error) {
schHandler, err := schema.NewHandler(sch.Schema)
schHandler, err := schema.NewHandler(sch.Schema, false)
if err != nil {
return nil, err
}
Expand Down
4 changes: 3 additions & 1 deletion graphql/admin/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ func resolveUpdateGQLSchema(ctx context.Context, m schema.Mutation) (*resolve.Re
return resolve.EmptyResult(m, err), false
}

schHandler, err := schema.NewHandler(input.Set.Schema)
// We just need to validate the schema. Schema is later set in `resetSchema()` when the schema
// is returned from badger.
schHandler, err := schema.NewHandler(input.Set.Schema, true)
if err != nil {
return resolve.EmptyResult(m, err), false
}
Expand Down
82 changes: 52 additions & 30 deletions graphql/authorization/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"net/http"
"regexp"
"strings"
"sync"

"github.com/vektah/gqlparser/v2/gqlerror"

Expand All @@ -46,7 +47,7 @@ const (
)

var (
metainfo AuthMeta
authMeta = &AuthMeta{}
)

type AuthMeta struct {
Expand All @@ -56,6 +57,7 @@ type AuthMeta struct {
Namespace string
Algo string
Audience []string
sync.RWMutex
}

// Validate required fields.
Expand Down Expand Up @@ -83,17 +85,17 @@ func (a *AuthMeta) validate() error {
return nil
}

func Parse(schema string) (AuthMeta, error) {
func Parse(schema string) (*AuthMeta, error) {
var meta AuthMeta
authInfoIdx := strings.LastIndex(schema, AuthMetaHeader)
if authInfoIdx == -1 {
return meta, nil
return nil, nil
}
authInfo := schema[authInfoIdx:]

err := json.Unmarshal([]byte(authInfo[len(AuthMetaHeader):]), &meta)
if err == nil {
return meta, meta.validate()
return &meta, meta.validate()
}

fmt.Println("Falling back to parsing `Dgraph.Authorization` in old format." +
Expand All @@ -112,60 +114,77 @@ func Parse(schema string) (AuthMeta, error) {
authMetaRegex, err :=
regexp.Compile(`^#[\s]([^\s]+)[\s]+([^\s]+)[\s]+([^\s]+)[\s]+([^\s]+)[\s]+"([^\"]+)"`)
if err != nil {
return meta, gqlerror.Errorf("JWT parsing failed: %v", err)
return nil, gqlerror.Errorf("JWT parsing failed: %v", err)
}

idx := authMetaRegex.FindAllStringSubmatchIndex(authInfo, -1)
if len(idx) != 1 || len(idx[0]) != 12 ||
!strings.HasPrefix(authInfo, authInfo[idx[0][0]:idx[0][1]]) {
return meta, gqlerror.Errorf("Invalid `Dgraph.Authorization` format: %s", authInfo)
return nil, gqlerror.Errorf("Invalid `Dgraph.Authorization` format: %s", authInfo)
}

meta.Header = authInfo[idx[0][4]:idx[0][5]]
meta.Namespace = authInfo[idx[0][6]:idx[0][7]]
meta.Algo = authInfo[idx[0][8]:idx[0][9]]
meta.VerificationKey = authInfo[idx[0][10]:idx[0][11]]
if meta.Algo == HMAC256 {
return meta, nil
return &meta, nil
}
if meta.Algo != RSA256 {
return meta, errors.Errorf(
return nil, errors.Errorf(
"invalid jwt algorithm: found %s, but supported options are HS256 or RS256", meta.Algo)
}
return meta, nil
return &meta, nil
}

func ParseAuthMeta(schema string) error {
var err error
metainfo, err = Parse(schema)
func ParseAuthMeta(schema string) (*AuthMeta, error) {
metaInfo, err := Parse(schema)
if err != nil {
return err
return nil, err
}

if metainfo.Algo != RSA256 {
return err
if metaInfo.Algo != RSA256 {
return metaInfo, nil
}

// The jwt library internally uses `bytes.IndexByte(data, '\n')` to fetch new line and fails
// if we have newline "\n" as ASCII value {92,110} instead of the actual ASCII value of 10.
// To fix this we replace "\n" with new line's ASCII value.
bytekey := bytes.ReplaceAll([]byte(metainfo.VerificationKey), []byte{92, 110}, []byte{10})
bytekey := bytes.ReplaceAll([]byte(metaInfo.VerificationKey), []byte{92, 110}, []byte{10})

metainfo.RSAPublicKey, err = jwt.ParseRSAPublicKeyFromPEM(bytekey)
return err
if metaInfo.RSAPublicKey, err = jwt.ParseRSAPublicKeyFromPEM(bytekey); err != nil {
return nil, err
}
return metaInfo, nil
}

func GetHeader() string {
return metainfo.Header
authMeta.RLock()
defer authMeta.RUnlock()
return authMeta.Header
}

func GetAuthMeta() AuthMeta {
return metainfo
func GetAuthMeta() *AuthMeta {
authMeta.RLock()
defer authMeta.RUnlock()
return authMeta
}

func SetAuthMeta(m *AuthMeta) {
authMeta.Lock()
defer authMeta.Unlock()

authMeta.VerificationKey = m.VerificationKey
authMeta.RSAPublicKey = m.RSAPublicKey
authMeta.Header = m.Header
authMeta.Namespace = m.Namespace
authMeta.Algo = m.Algo
authMeta.Audience = m.Audience
}

// AttachAuthorizationJwt adds any incoming JWT authorization data into the grpc context metadata.
func AttachAuthorizationJwt(ctx context.Context, r *http.Request) context.Context {
authorizationJwt := r.Header.Get(metainfo.Header)
authorizationJwt := r.Header.Get(authMeta.Header)
if authorizationJwt == "" {
return ctx
}
Expand Down Expand Up @@ -197,7 +216,7 @@ func (c *CustomClaims) UnmarshalJSON(data []byte) error {
}

// Unmarshal the auth variables for a particular namespace.
if authValue, ok := result[metainfo.Namespace]; ok {
if authValue, ok := result[authMeta.Namespace]; ok {
if authJson, ok := authValue.(string); ok {
if err := json.Unmarshal([]byte(authJson), &c.AuthVariables); err != nil {
return err
Expand All @@ -216,13 +235,13 @@ func (c *CustomClaims) validateAudience() error {
}

// If there is an audience claim, but no value provided, fail
if metainfo.Audience == nil {
if authMeta.Audience == nil {
return fmt.Errorf("audience value was expected but not provided")
}

var match = false
for _, audStr := range c.Audience {
for _, expectedAudStr := range metainfo.Audience {
for _, expectedAudStr := range authMeta.Audience {
if subtle.ConstantTimeCompare([]byte(audStr), []byte(expectedAudStr)) == 1 {
match = true
break
Expand Down Expand Up @@ -252,7 +271,10 @@ func ExtractCustomClaims(ctx context.Context) (*CustomClaims, error) {
}

func validateJWTCustomClaims(jwtStr string) (*CustomClaims, error) {
if metainfo.Algo == "" {
authMeta.RLock()
defer authMeta.RUnlock()

if authMeta.Algo == "" {
return nil, fmt.Errorf(
"jwt token cannot be validated because verification algorithm is not set")
}
Expand All @@ -263,17 +285,17 @@ func validateJWTCustomClaims(jwtStr string) (*CustomClaims, error) {
token, err :=
jwt.ParseWithClaims(jwtStr, &CustomClaims{}, func(token *jwt.Token) (interface{}, error) {
algo, _ := token.Header["alg"].(string)
if algo != metainfo.Algo {
if algo != authMeta.Algo {
return nil, errors.Errorf("unexpected signing method: Expected %s Found %s",
metainfo.Algo, algo)
authMeta.Algo, algo)
}
if algo == HMAC256 {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); ok {
return []byte(metainfo.VerificationKey), nil
return []byte(authMeta.VerificationKey), nil
}
} else if algo == RSA256 {
if _, ok := token.Method.(*jwt.SigningMethodRSA); ok {
return metainfo.RSAPublicKey, nil
return authMeta.RSAPublicKey, nil
}
}
return nil, errors.Errorf("couldn't parse signing method from token header: %s", algo)
Expand Down
1 change: 1 addition & 0 deletions graphql/bench/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func getAuthMeta(schema string) *testutil.AuthMeta {
if err != nil {
panic(err)
}
authorization.SetAuthMeta(authMeta)

return &testutil.AuthMeta{
PublicKey: authMeta.VerificationKey,
Expand Down
6 changes: 1 addition & 5 deletions graphql/e2e/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1256,11 +1256,7 @@ func TestMain(m *testing.M) {
panic(err)
}

authMeta, err := authorization.Parse(string(authSchema))
if err != nil {
panic(err)
}

authMeta := testutil.SetAuthMeta(string(authSchema))
metaInfo = &testutil.AuthMeta{
PublicKey: authMeta.VerificationKey,
Namespace: authMeta.Namespace,
Expand Down
3 changes: 2 additions & 1 deletion graphql/e2e/common/fragment.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ func fragmentInMutation(t *testing.T) {
gqlResponse := addStarshipParams.ExecuteAsPost(t, graphqlURL)
RequireNoGQLErrors(t, gqlResponse)

addStarshipExpected := `{"addStarship":{
addStarshipExpected := `
{"addStarship":{
"starship":[{
"name":"Millennium Falcon",
"length":2
Expand Down
Loading

0 comments on commit df1c7c9

Please sign in to comment.