Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(GraphQL): Add GraphQL schema validation Endpoint. #6250

Merged
merged 11 commits into from
Sep 1, 2020
23 changes: 23 additions & 0 deletions dgraph/cmd/alpha/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,29 @@ func setupServer(closer *y.Closer) {
adminSchemaHandler(w, r, adminServer)
})))

http.Handle("/admin/schema/validate", http.HandlerFunc(func(w http.ResponseWriter,
r *http.Request) {
schema := readRequest(w, r)
if schema == nil {
w.WriteHeader(http.StatusOK)
return
}

w.Header().Set("Content-Type", "application/json")

err := admin.SchemaValidate(string(schema))
if err != nil {
w.WriteHeader(http.StatusBadRequest)
// There could be multiple errors, so replace the newline with whitespace since the newline is an
// invalid JSON character.
errStr := strings.ReplaceAll(err.Error(), "\n", " ")
x.Check2(w.Write([]byte(fmt.Sprintf(`{"status":"invalid", "error" : "%s"}`, errStr))))
return
}
w.WriteHeader(http.StatusOK)
x.Check2(w.Write([]byte(`{"status":"valid"}`)))
}))

http.Handle("/admin/shutdown", allowedMethodsHandler(allowedMethods{http.MethodGet: true},
adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
shutDownHandler(w, r, adminServer)
Expand Down
14 changes: 13 additions & 1 deletion graphql/admin/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,18 @@ var (
mainHealthStore = &GraphQLHealthStore{}
)

func SchemaValidate(sch string) error {
schHandler, err := schema.NewHandler(sch, true)
if err != nil {
return err
}

if _, err := schema.FromString(schHandler.GQLSchema()); err != nil {
return err
}
return nil
}

// GraphQLHealth is used to report the health status of a GraphQL server.
// It is required for kubernetes probing.
type GraphQLHealth struct {
Expand Down Expand Up @@ -583,7 +595,7 @@ func getCurrentGraphQLSchema() (*gqlSchema, error) {
}

func generateGQLSchema(sch *gqlSchema) (schema.Schema, error) {
schHandler, err := schema.NewHandler(sch.Schema)
schHandler, err := schema.NewHandler(sch.Schema, false)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion graphql/admin/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func resolveUpdateGQLSchema(ctx context.Context, m schema.Mutation) (*resolve.Re
return resolve.EmptyResult(m, err), false
}

schHandler, err := schema.NewHandler(input.Set.Schema)
schHandler, err := schema.NewHandler(input.Set.Schema, false)
if err != nil {
return resolve.EmptyResult(m, err), false
}
Expand Down
64 changes: 43 additions & 21 deletions graphql/authorization/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"net/http"
"regexp"
"strings"
"sync"

"github.com/vektah/gqlparser/v2/gqlerror"

Expand All @@ -46,7 +47,7 @@ const (
)

var (
metainfo AuthMeta
authMeta AuthMeta
)

type AuthMeta struct {
Expand All @@ -56,6 +57,7 @@ type AuthMeta struct {
Namespace string
Algo string
Audience []string
sync.RWMutex
}

// Validate required fields.
Expand Down Expand Up @@ -135,37 +137,54 @@ func Parse(schema string) (AuthMeta, error) {
return meta, nil
}

func ParseAuthMeta(schema string) error {
var err error
metainfo, err = Parse(schema)
func ParseAuthMeta(schema string) (*AuthMeta, error) {
metaInfo, err := Parse(schema)
if err != nil {
return err
return nil, err
}

if metainfo.Algo != RSA256 {
return err
if metaInfo.Algo != RSA256 {
return nil, err
}

// The jwt library internally uses `bytes.IndexByte(data, '\n')` to fetch new line and fails
// if we have newline "\n" as ASCII value {92,110} instead of the actual ASCII value of 10.
// To fix this we replace "\n" with new line's ASCII value.
bytekey := bytes.ReplaceAll([]byte(metainfo.VerificationKey), []byte{92, 110}, []byte{10})
bytekey := bytes.ReplaceAll([]byte(metaInfo.VerificationKey), []byte{92, 110}, []byte{10})

metainfo.RSAPublicKey, err = jwt.ParseRSAPublicKeyFromPEM(bytekey)
return err
if metaInfo.RSAPublicKey, err = jwt.ParseRSAPublicKeyFromPEM(bytekey); err != nil {
return nil, err
}
return &metaInfo, nil
}

func GetHeader() string {
return metainfo.Header
authMeta.RLock()
defer authMeta.RUnlock()
return authMeta.Header
}

func GetAuthMeta() AuthMeta {
return metainfo
authMeta.RLock()
defer authMeta.RUnlock()
return authMeta
}

func SetAuthMeta(m AuthMeta) {
authMeta.Lock()
defer authMeta.Unlock()

authMeta.VerificationKey = m.VerificationKey
authMeta.RSAPublicKey = m.RSAPublicKey
authMeta.Header = m.Header
authMeta.Namespace = m.Namespace
authMeta.Algo = m.Algo
authMeta.Audience = m.Audience
}

// AttachAuthorizationJwt adds any incoming JWT authorization data into the grpc context metadata.
func AttachAuthorizationJwt(ctx context.Context, r *http.Request) context.Context {
authorizationJwt := r.Header.Get(metainfo.Header)
authorizationJwt := r.Header.Get(authMeta.Header)
if authorizationJwt == "" {
return ctx
}
Expand Down Expand Up @@ -197,7 +216,7 @@ func (c *CustomClaims) UnmarshalJSON(data []byte) error {
}

// Unmarshal the auth variables for a particular namespace.
if authValue, ok := result[metainfo.Namespace]; ok {
if authValue, ok := result[authMeta.Namespace]; ok {
if authJson, ok := authValue.(string); ok {
if err := json.Unmarshal([]byte(authJson), &c.AuthVariables); err != nil {
return err
Expand All @@ -216,13 +235,13 @@ func (c *CustomClaims) validateAudience() error {
}

// If there is an audience claim, but no value provided, fail
if metainfo.Audience == nil {
if authMeta.Audience == nil {
return fmt.Errorf("audience value was expected but not provided")
}

var match = false
for _, audStr := range c.Audience {
for _, expectedAudStr := range metainfo.Audience {
for _, expectedAudStr := range authMeta.Audience {
if subtle.ConstantTimeCompare([]byte(audStr), []byte(expectedAudStr)) == 1 {
match = true
break
Expand Down Expand Up @@ -252,7 +271,10 @@ func ExtractCustomClaims(ctx context.Context) (*CustomClaims, error) {
}

func validateJWTCustomClaims(jwtStr string) (*CustomClaims, error) {
if metainfo.Algo == "" {
authMeta.RLock()
defer authMeta.RUnlock()

if authMeta.Algo == "" {
return nil, fmt.Errorf(
"jwt token cannot be validated because verification algorithm is not set")
}
Expand All @@ -263,17 +285,17 @@ func validateJWTCustomClaims(jwtStr string) (*CustomClaims, error) {
token, err :=
jwt.ParseWithClaims(jwtStr, &CustomClaims{}, func(token *jwt.Token) (interface{}, error) {
algo, _ := token.Header["alg"].(string)
if algo != metainfo.Algo {
if algo != authMeta.Algo {
return nil, errors.Errorf("unexpected signing method: Expected %s Found %s",
metainfo.Algo, algo)
authMeta.Algo, algo)
}
if algo == HMAC256 {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); ok {
return []byte(metainfo.VerificationKey), nil
return []byte(authMeta.VerificationKey), nil
}
} else if algo == RSA256 {
if _, ok := token.Method.(*jwt.SigningMethodRSA); ok {
return metainfo.RSAPublicKey, nil
return authMeta.RSAPublicKey, nil
}
}
return nil, errors.Errorf("couldn't parse signing method from token header: %s", algo)
Expand Down
1 change: 1 addition & 0 deletions graphql/bench/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func getAuthMeta(schema string) *testutil.AuthMeta {
if err != nil {
panic(err)
}
authorization.SetAuthMeta(authMeta)

return &testutil.AuthMeta{
PublicKey: authMeta.VerificationKey,
Expand Down
1 change: 1 addition & 0 deletions graphql/e2e/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1260,6 +1260,7 @@ func TestMain(m *testing.M) {
if err != nil {
panic(err)
}
authorization.SetAuthMeta(authMeta)

metaInfo = &testutil.AuthMeta{
PublicKey: authMeta.VerificationKey,
Expand Down
16 changes: 8 additions & 8 deletions graphql/e2e/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ import (
)

const (
graphqlURL = "http://localhost:8180/graphql"
graphqlAdminURL = "http://localhost:8180/admin"
GraphqlURL = "http://localhost:8180/graphql"
GraphqlAdminURL = "http://localhost:8180/admin"
AlphagRPC = "localhost:9180"

adminDgraphHealthURL = "http://localhost:8280/health?all"
Expand Down Expand Up @@ -173,7 +173,7 @@ type student struct {
}

func BootstrapServer(schema, data []byte) {
err := checkGraphQLStarted(graphqlAdminURL)
err := checkGraphQLStarted(GraphqlAdminURL)
if err != nil {
x.Panic(errors.Errorf(
"Waited for GraphQL test server to become available, but it never did.\n"+
Expand All @@ -195,7 +195,7 @@ func BootstrapServer(schema, data []byte) {
}
client := dgo.NewDgraphClient(api.NewDgraphClient(d))

err = addSchema(graphqlAdminURL, string(schema))
err = addSchema(GraphqlAdminURL, string(schema))
if err != nil {
x.Panic(err)
}
Expand Down Expand Up @@ -371,7 +371,7 @@ func gzipCompressionHeader(t *testing.T) {
}`,
}

req, err := queryCountry.createGQLPost(graphqlURL)
req, err := queryCountry.createGQLPost(GraphqlURL)
require.NoError(t, err)

req.Header.Set("Content-Encoding", "gzip")
Expand All @@ -398,7 +398,7 @@ func gzipCompressionNoHeader(t *testing.T) {
gzipEncoding: true,
}

req, err := queryCountry.createGQLPost(graphqlURL)
req, err := queryCountry.createGQLPost(GraphqlURL)
require.NoError(t, err)

req.Header.Del("Content-Encoding")
Expand All @@ -424,7 +424,7 @@ func getQueryEmptyVariable(t *testing.T) {
}
}`,
}
req, err := queryCountry.createGQLGet(graphqlURL)
req, err := queryCountry.createGQLGet(GraphqlURL)
require.NoError(t, err)

q := req.URL.Query()
Expand Down Expand Up @@ -634,7 +634,7 @@ func allCountriesAdded() ([]*country, error) {
return nil, errors.Wrap(err, "unable to build GraphQL query")
}

req, err := http.NewRequest("POST", graphqlURL, bytes.NewBuffer(body))
req, err := http.NewRequest("POST", GraphqlURL, bytes.NewBuffer(body))
if err != nil {
return nil, errors.Wrap(err, "unable to build GraphQL request")
}
Expand Down
6 changes: 3 additions & 3 deletions graphql/e2e/common/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func graphQLCompletionOn(t *testing.T) {
}

// Check that the error is valid
gqlResponse := queryCountry.ExecuteAsPost(t, graphqlURL)
gqlResponse := queryCountry.ExecuteAsPost(t, GraphqlURL)
require.NotNil(t, gqlResponse.Errors)
require.Equal(t, 1, len(gqlResponse.Errors))
require.Contains(t, gqlResponse.Errors[0].Error(),
Expand Down Expand Up @@ -166,7 +166,7 @@ func deepMutationErrors(t *testing.T) {
},
}

gqlResponse := executeRequest(t, graphqlURL, updateCountryParams)
gqlResponse := executeRequest(t, GraphqlURL, updateCountryParams)
require.NotNil(t, gqlResponse.Errors)
require.Equal(t, 1, len(gqlResponse.Errors))
require.EqualError(t, gqlResponse.Errors[0], tcase.exp)
Expand All @@ -192,7 +192,7 @@ func requestValidationErrors(t *testing.T) {
Query: tcase.GQLRequest,
Variables: tcase.variables,
}
gqlResponse := test.ExecuteAsPost(t, graphqlURL)
gqlResponse := test.ExecuteAsPost(t, GraphqlURL)

require.Nil(t, gqlResponse.Data)
if diff := cmp.Diff(tcase.Errors, gqlResponse.Errors); diff != "" {
Expand Down
8 changes: 4 additions & 4 deletions graphql/e2e/common/fragment.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func fragmentInMutation(t *testing.T) {
}},
}

gqlResponse := addStarshipParams.ExecuteAsPost(t, graphqlURL)
gqlResponse := addStarshipParams.ExecuteAsPost(t, GraphqlURL)
RequireNoGQLErrors(t, gqlResponse)

addStarshipExpected := fmt.Sprintf(`{"addStarship":{
Expand Down Expand Up @@ -83,7 +83,7 @@ func fragmentInQuery(t *testing.T) {
},
}

gqlResponse := queryStarshipParams.ExecuteAsPost(t, graphqlURL)
gqlResponse := queryStarshipParams.ExecuteAsPost(t, GraphqlURL)
RequireNoGQLErrors(t, gqlResponse)

queryStarshipExpected := fmt.Sprintf(`
Expand Down Expand Up @@ -152,7 +152,7 @@ func fragmentInQueryOnInterface(t *testing.T) {
`,
}

gqlResponse := queryCharacterParams.ExecuteAsPost(t, graphqlURL)
gqlResponse := queryCharacterParams.ExecuteAsPost(t, GraphqlURL)
RequireNoGQLErrors(t, gqlResponse)

queryCharacterExpected := fmt.Sprintf(`
Expand Down Expand Up @@ -227,7 +227,7 @@ func fragmentInQueryOnObject(t *testing.T) {
`,
}

gqlResponse := queryHumanParams.ExecuteAsPost(t, graphqlURL)
gqlResponse := queryHumanParams.ExecuteAsPost(t, GraphqlURL)
RequireNoGQLErrors(t, gqlResponse)

queryCharacterExpected := fmt.Sprintf(`
Expand Down
Loading