From 6d56b0f6426e9daaaf4552b6ae36d565e7bf2ce3 Mon Sep 17 00:00:00 2001
From: Arijit Das <arijit@dgraph.io>
Date: Tue, 21 Apr 2020 21:23:45 +0530
Subject: [PATCH] Send auth variable in custom jwt token. (#5220)

* Send auth variable in custom jwt token.

* Verify custom claims using key.
---
 graphql/resolve/auth.go           | 100 ++++++++++++++++++++++++++++++
 graphql/resolve/auth_test.go      |  16 ++---
 graphql/resolve/query_rewriter.go |   6 +-
 graphql/web/http.go               |   1 +
 4 files changed, 110 insertions(+), 13 deletions(-)
 create mode 100644 graphql/resolve/auth.go

diff --git a/graphql/resolve/auth.go b/graphql/resolve/auth.go
new file mode 100644
index 00000000000..a8270771976
--- /dev/null
+++ b/graphql/resolve/auth.go
@@ -0,0 +1,100 @@
+/*
+ * Copyright 2020 Dgraph Labs, Inc. and Contributors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package resolve
+
+import (
+	"context"
+	"fmt"
+	"github.com/dgrijalva/jwt-go"
+	"github.com/pkg/errors"
+	"google.golang.org/grpc/metadata"
+	"net/http"
+	"time"
+)
+
+//TODO: Get the secret key dynamically.
+const (
+	AuthJwtCtxKey  = "authorizationJwt"
+	AuthHmacSecret = "Secretkey"
+)
+
+// AttachAuthorizationJwt adds any incoming JWT authorization data into the grpc context metadata.
+func AttachAuthorizationJwt(ctx context.Context, r *http.Request) context.Context {
+	authorizationJwt := r.Header.Get("X-Dgraph-AuthorizationToken")
+	if authorizationJwt == "" {
+		return ctx
+	}
+
+	md, ok := metadata.FromIncomingContext(ctx)
+	if !ok {
+		md = metadata.New(nil)
+	}
+
+	md.Append(AuthJwtCtxKey, authorizationJwt)
+	ctx = metadata.NewIncomingContext(ctx, md)
+	return ctx
+}
+
+type CustomClaims struct {
+	AuthVariables map[string]interface{} `json:"https://dgraph.io/jwt/claims"`
+	jwt.StandardClaims
+}
+
+func ExtractAuthVariables(ctx context.Context) (map[string]interface{}, error) {
+	// Extract the jwt and unmarshal the jwt to get the auth variables.
+	md, ok := metadata.FromIncomingContext(ctx)
+	if !ok {
+		return nil, nil
+	}
+
+	jwtToken := md.Get(AuthJwtCtxKey)
+	if len(jwtToken) == 0 {
+		return nil, nil
+	} else if len(jwtToken) > 1 {
+		return nil, fmt.Errorf("invalid jwt auth token")
+	}
+
+	return validateToken(jwtToken[0])
+}
+
+func validateToken(jwtStr string) (map[string]interface{}, error) {
+	token, err :=
+		jwt.ParseWithClaims(jwtStr, &CustomClaims{}, func(token *jwt.Token) (interface{}, error) {
+			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
+				return nil, errors.Errorf("unexpected signing method: %v", token.Header["alg"])
+			}
+			return []byte(AuthHmacSecret), nil
+		})
+
+	if err != nil {
+		return nil, errors.Errorf("unable to parse jwt token:%v", err)
+	}
+
+	claims, ok := token.Claims.(*CustomClaims)
+	if !ok || !token.Valid {
+		return nil, errors.Errorf("claims in jwt token is not map claims")
+	}
+
+	// by default, the MapClaims.Valid will return true if the exp field is not set
+	// here we enforce the checking to make sure that the refresh token has not expired
+	now := time.Now().Unix()
+	if !claims.VerifyExpiresAt(now, true) {
+		return nil, errors.Errorf("Token is expired") // the same error msg that's used inside jwt-go
+	}
+
+	return claims.AuthVariables, nil
+}
diff --git a/graphql/resolve/auth_test.go b/graphql/resolve/auth_test.go
index 405d34be662..ab5928512d7 100644
--- a/graphql/resolve/auth_test.go
+++ b/graphql/resolve/auth_test.go
@@ -20,6 +20,7 @@ import (
 	"context"
 	"io/ioutil"
 	"testing"
+	"time"
 
 	"github.com/dgraph-io/dgraph/graphql/dgraph"
 	"github.com/dgraph-io/dgraph/graphql/schema"
@@ -52,20 +53,15 @@ func TestAuthQueryRewriting(t *testing.T) {
 
 	testRewriter := NewQueryRewriter()
 
-	type MyCustomClaims struct {
-		Foo map[string]interface{} `json:"https://dgraph.io/jwt/claims"`
-		jwt.StandardClaims
-	}
-
 	// Create the Claims
-	claims := MyCustomClaims{
+	claims := CustomClaims{
 		map[string]interface{}{},
 		jwt.StandardClaims{
-			ExpiresAt: 15000,
+			ExpiresAt: time.Now().Add(time.Minute).Unix(),
 			Issuer:    "test",
 		},
 	}
-	claims.Foo["User"] = "user1"
+	claims.AuthVariables["USER"] = "user1"
 
 	gqlSchema := test.LoadSchemaFromFile(t, "../e2e/auth/schema.graphql")
 	for _, tcase := range tests {
@@ -80,9 +76,9 @@ func TestAuthQueryRewriting(t *testing.T) {
 			gqlQuery := test.GetQuery(t, op)
 
 			ctx := context.Background()
-			claims.Foo["Role"] = tcase.Role
+			claims.AuthVariables["ROLE"] = tcase.Role
 			token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-			ss, err := token.SignedString([]byte("Secret"))
+			ss, err := token.SignedString([]byte(AuthHmacSecret))
 			require.NoError(t, err)
 
 			md := metadata.New(nil)
diff --git a/graphql/resolve/query_rewriter.go b/graphql/resolve/query_rewriter.go
index dd34c81e3b3..903ba08d1a4 100644
--- a/graphql/resolve/query_rewriter.go
+++ b/graphql/resolve/query_rewriter.go
@@ -47,9 +47,9 @@ func (qr *queryRewriter) Rewrite(
 	ctx context.Context,
 	gqlQuery schema.Query) (*gql.GraphQuery, error) {
 
-	// FIXME: should come from the JWT
-	authVariables := map[string]interface{}{
-		"USER": "user1",
+	authVariables, err := ExtractAuthVariables(ctx)
+	if err != nil {
+		return nil, err
 	}
 
 	auth := &authRewriter{
diff --git a/graphql/web/http.go b/graphql/web/http.go
index 6d7997624ff..b1214c44095 100644
--- a/graphql/web/http.go
+++ b/graphql/web/http.go
@@ -152,6 +152,7 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		x.Panic(errors.New("graphqlHandler not initialised"))
 	}
 
+	ctx = resolve.AttachAuthorizationJwt(ctx, r)
 	ctx = x.AttachAccessJwt(ctx, r)
 
 	if ip, port, err := net.SplitHostPort(r.RemoteAddr); err == nil {