Skip to content

Commit

Permalink
Clean up some code
Browse files Browse the repository at this point in the history
  • Loading branch information
pawanrawal committed Jul 20, 2020
1 parent 09da5b9 commit 3922649
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 39 deletions.
17 changes: 7 additions & 10 deletions graphql/authorization/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,21 +125,18 @@ func GetAuthMeta() AuthMeta {
return metainfo
}

// AttachAuthorizationJwt adds any incoming JWT authorization data into the grpc context metadata.
func AttachAuthorizationJwt(ctx context.Context, r *http.Request) context.Context {
// ExtractAuthJWT returns a context with the JWT attached as metadata.
func ExtractAuthJWT(r *http.Request) context.Context {
ctx := context.Background()
authorizationJwt := r.Header.Get(metainfo.Header)
if authorizationJwt == "" {
return ctx
}

md, ok := metadata.FromIncomingContext(ctx)
if !ok {
md = metadata.New(nil)
}

md.Append(string(AuthJwtCtxKey), authorizationJwt)
ctx = metadata.NewIncomingContext(ctx, md)
return ctx
md := metadata.New(map[string]string{
string(AuthJwtCtxKey): authorizationJwt,
})
return metadata.NewIncomingContext(ctx, md)
}

type CustomClaims struct {
Expand Down
5 changes: 3 additions & 2 deletions graphql/e2e/common/subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ type GraphQLSubscriptionClient struct {
}

// NewGraphQLSubscription returns graphql subscription client.
func NewGraphQLSubscription(url string, req *schema.Request, subscriptionPayload string) (*GraphQLSubscriptionClient, error) {
func NewGraphQLSubscription(url string, req *schema.Request,
initPayload string) (*GraphQLSubscriptionClient, error) {
header := http.Header{
"Sec-WebSocket-Protocol": []string{protocolGraphQLWS},
}
Expand All @@ -69,7 +70,7 @@ func NewGraphQLSubscription(url string, req *schema.Request, subscriptionPayload
// Initialize subscription.
init := operationMessage{
Type: initMsg,
Payload: []byte(subscriptionPayload),
Payload: []byte(initPayload),
}

// Send Intialization message to the graphql server.
Expand Down
6 changes: 2 additions & 4 deletions graphql/resolve/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,11 @@ func TestStringCustomClaim(t *testing.T) {

customClaims, err := authorization.ExtractCustomClaims(ctx)
require.NoError(t, err)
authVar := customClaims.AuthVariables
result := map[string]interface{}{
"ROLE": "ADMIN",
"USER": "50950b40-262f-4b26-88a7-cbbb780b2176",
}
require.Equal(t, authVar, result)
require.Equal(t, customClaims.AuthVariables, result)
}

func TestAudienceClaim(t *testing.T) {
Expand Down Expand Up @@ -231,12 +230,11 @@ func TestAudienceClaim(t *testing.T) {
return
}

authVar := customClaims.AuthVariables
result := map[string]interface{}{
"ROLE": "ADMIN",
"USER": "50950b40-262f-4b26-88a7-cbbb780b2176",
}
require.Equal(t, authVar, result)
require.Equal(t, customClaims.AuthVariables, result)
})
}
}
Expand Down
3 changes: 1 addition & 2 deletions graphql/resolve/mutation.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,8 @@ func authorizeNewNodes(
if err != nil {
return schema.GQLWrapf(err, "authorization failed")
}
authVariables := customClaims.AuthVariables
newRw := &authRewriter{
authVariables: authVariables,
authVariables: customClaims.AuthVariables,
varGen: NewVariableGenerator(),
selector: addAuthSelector,
hasAuthRules: true,
Expand Down
9 changes: 0 additions & 9 deletions graphql/resolve/query_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,6 @@ func (qr *queryRewriter) Rewrite(
}

authVariables, _ := ctx.Value(authorization.AuthVariables).(map[string]interface{})

if authVariables == nil {
customClaims, err := authorization.ExtractCustomClaims(ctx)
if err != nil {
return nil, err
}
authVariables = customClaims.AuthVariables
}

authRw := &authRewriter{
authVariables: authVariables,
varGen: NewVariableGenerator(),
Expand Down
3 changes: 2 additions & 1 deletion graphql/subscription/poller.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ func (p *Poller) AddSubscriber(
p.Lock()
defer p.Unlock()

ctx := context.WithValue(context.Background(), authorization.AuthVariables, customClaims.AuthVariables)
ctx := context.WithValue(context.Background(), authorization.AuthVariables,
customClaims.AuthVariables)
res := p.resolver.Resolve(ctx, req)
if len(res.Errors) != 0 {
return nil, res.Errors
Expand Down
29 changes: 18 additions & 11 deletions graphql/web/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ func (gs *graphqlSubscription) Subscribe(
name := authorization.GetHeader()
val, ok := payload[name].(string)
if ok {

md := metadata.New(map[string]string{
"authorizationJwt": val,
})
Expand Down Expand Up @@ -197,20 +196,28 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
x.Panic(errors.New("graphqlHandler not initialised"))
}

ctx = authorization.AttachAuthorizationJwt(ctx, r)
ctx = x.AttachAccessJwt(ctx, r)
// Add remote addr as peer info so that the remote address can be logged
// inside Server.Login
ctx = x.AttachRemoteIP(ctx, r)

var res *schema.Response
gqlReq, err := getRequest(ctx, r)

actx := authorization.ExtractAuthJWT(r)
cc, err := authorization.ExtractCustomClaims(actx)
if err != nil {
res = schema.ErrorResponse(err)
} else {
gqlReq.Header = r.Header
res = gh.resolver.Resolve(ctx, gqlReq)
// Lets attach the auth variables to the context so tha queries can access it.
ctx = context.WithValue(ctx, authorization.AuthVariables, cc.AuthVariables)

ctx = x.AttachAccessJwt(ctx, r)
// Add remote addr as peer info so that the remote address can be logged
// inside Server.Login
ctx = x.AttachRemoteIP(ctx, r)

gqlReq, err := getRequest(ctx, r)

if err != nil {
res = schema.ErrorResponse(err)
} else {
gqlReq.Header = r.Header
res = gh.resolver.Resolve(ctx, gqlReq)
}
}

write(w, res, strings.Contains(r.Header.Get("Accept-Encoding"), "gzip"))
Expand Down

0 comments on commit 3922649

Please sign in to comment.