Skip to content

Commit

Permalink
Send auth variable in custom jwt token. (#5220)
Browse files Browse the repository at this point in the history
* Send auth variable in custom jwt token.

* Verify custom claims using key.
  • Loading branch information
Arijit Das authored Apr 21, 2020
1 parent 46d6beb commit 6d56b0f
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 13 deletions.
100 changes: 100 additions & 0 deletions graphql/resolve/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Copyright 2020 Dgraph Labs, Inc. and Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package resolve

import (
"context"
"fmt"
"github.com/dgrijalva/jwt-go"
"github.com/pkg/errors"
"google.golang.org/grpc/metadata"
"net/http"
"time"
)

//TODO: Get the secret key dynamically.
const (
AuthJwtCtxKey = "authorizationJwt"
AuthHmacSecret = "Secretkey"
)

// AttachAuthorizationJwt adds any incoming JWT authorization data into the grpc context metadata.
func AttachAuthorizationJwt(ctx context.Context, r *http.Request) context.Context {
authorizationJwt := r.Header.Get("X-Dgraph-AuthorizationToken")
if authorizationJwt == "" {
return ctx
}

md, ok := metadata.FromIncomingContext(ctx)
if !ok {
md = metadata.New(nil)
}

md.Append(AuthJwtCtxKey, authorizationJwt)
ctx = metadata.NewIncomingContext(ctx, md)
return ctx
}

type CustomClaims struct {
AuthVariables map[string]interface{} `json:"https://dgraph.io/jwt/claims"`
jwt.StandardClaims
}

func ExtractAuthVariables(ctx context.Context) (map[string]interface{}, error) {
// Extract the jwt and unmarshal the jwt to get the auth variables.
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, nil
}

jwtToken := md.Get(AuthJwtCtxKey)
if len(jwtToken) == 0 {
return nil, nil
} else if len(jwtToken) > 1 {
return nil, fmt.Errorf("invalid jwt auth token")
}

return validateToken(jwtToken[0])
}

func validateToken(jwtStr string) (map[string]interface{}, error) {
token, err :=
jwt.ParseWithClaims(jwtStr, &CustomClaims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, errors.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(AuthHmacSecret), nil
})

if err != nil {
return nil, errors.Errorf("unable to parse jwt token:%v", err)
}

claims, ok := token.Claims.(*CustomClaims)
if !ok || !token.Valid {
return nil, errors.Errorf("claims in jwt token is not map claims")
}

// by default, the MapClaims.Valid will return true if the exp field is not set
// here we enforce the checking to make sure that the refresh token has not expired
now := time.Now().Unix()
if !claims.VerifyExpiresAt(now, true) {
return nil, errors.Errorf("Token is expired") // the same error msg that's used inside jwt-go
}

return claims.AuthVariables, nil
}
16 changes: 6 additions & 10 deletions graphql/resolve/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"io/ioutil"
"testing"
"time"

"github.com/dgraph-io/dgraph/graphql/dgraph"
"github.com/dgraph-io/dgraph/graphql/schema"
Expand Down Expand Up @@ -52,20 +53,15 @@ func TestAuthQueryRewriting(t *testing.T) {

testRewriter := NewQueryRewriter()

type MyCustomClaims struct {
Foo map[string]interface{} `json:"https://dgraph.io/jwt/claims"`
jwt.StandardClaims
}

// Create the Claims
claims := MyCustomClaims{
claims := CustomClaims{
map[string]interface{}{},
jwt.StandardClaims{
ExpiresAt: 15000,
ExpiresAt: time.Now().Add(time.Minute).Unix(),
Issuer: "test",
},
}
claims.Foo["User"] = "user1"
claims.AuthVariables["USER"] = "user1"

gqlSchema := test.LoadSchemaFromFile(t, "../e2e/auth/schema.graphql")
for _, tcase := range tests {
Expand All @@ -80,9 +76,9 @@ func TestAuthQueryRewriting(t *testing.T) {
gqlQuery := test.GetQuery(t, op)

ctx := context.Background()
claims.Foo["Role"] = tcase.Role
claims.AuthVariables["ROLE"] = tcase.Role
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
ss, err := token.SignedString([]byte("Secret"))
ss, err := token.SignedString([]byte(AuthHmacSecret))
require.NoError(t, err)

md := metadata.New(nil)
Expand Down
6 changes: 3 additions & 3 deletions graphql/resolve/query_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ func (qr *queryRewriter) Rewrite(
ctx context.Context,
gqlQuery schema.Query) (*gql.GraphQuery, error) {

// FIXME: should come from the JWT
authVariables := map[string]interface{}{
"USER": "user1",
authVariables, err := ExtractAuthVariables(ctx)
if err != nil {
return nil, err
}

auth := &authRewriter{
Expand Down
1 change: 1 addition & 0 deletions graphql/web/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
x.Panic(errors.New("graphqlHandler not initialised"))
}

ctx = resolve.AttachAuthorizationJwt(ctx, r)
ctx = x.AttachAccessJwt(ctx, r)

if ip, port, err := net.SplitHostPort(r.RemoteAddr); err == nil {
Expand Down

0 comments on commit 6d56b0f

Please sign in to comment.