Skip to content

Commit

Permalink
fix: fixing audit logs for websocket connections (#8048)
Browse files Browse the repository at this point in the history
* fix: fixing audit logs for websocket connections
  • Loading branch information
aman-bansal authored Sep 24, 2021
1 parent 3103f0e commit 9792506
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 7 deletions.
1 change: 1 addition & 0 deletions ee/audit/audit_ee.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ const (
PoorManAuth = "PoorManAuth"
Grpc = "Grpc"
Http = "Http"
WebSocket = "Websocket"
)

var auditor = &auditLogger{}
Expand Down
6 changes: 6 additions & 0 deletions ee/audit/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import (
"net/http"

"google.golang.org/grpc"

"github.com/dgraph-io/dgraph/graphql/schema"
)

func AuditRequestGRPC(ctx context.Context, req interface{},
Expand All @@ -35,3 +37,7 @@ func AuditRequestHttp(next http.Handler) http.Handler {
next.ServeHTTP(w, r)
})
}

func AuditWebSockets(ctx context.Context, req *schema.Request) {
return
}
49 changes: 49 additions & 0 deletions ee/audit/interceptor_ee.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/gorilla/websocket"
"io"
"io/ioutil"
"net"
"net/http"
"regexp"
"strconv"
Expand Down Expand Up @@ -92,6 +94,19 @@ func AuditRequestHttp(next http.Handler) http.Handler {
next.ServeHTTP(w, r)
return
}

// Websocket connection in graphQl happens differently. We only get access tokens and
// metadata in payload later once the connection is upgraded to correct protocol.
// Doc: https://github.com/apollographql/subscriptions-transport-ws/blob/v0.9.4/PROTOCOL.md
//
// Auditing for websocket connections will be handled by graphql/admin/http.go:154#Subscribe
for _, subprotocol := range websocket.Subprotocols(r) {
if subprotocol == "graphql-ws" {
next.ServeHTTP(w, r)
return
}
}

rw := NewResponseWriter(w)
var buf bytes.Buffer
tee := io.TeeReader(r.Body, &buf)
Expand All @@ -102,6 +117,40 @@ func AuditRequestHttp(next http.Handler) http.Handler {
})
}

func AuditWebSockets(ctx context.Context, req *schema.Request) {
if atomic.LoadUint32(&auditEnabled) == 0 {
return
}

namespace := uint64(0)
var user string
if token := req.Header.Get("X-Dgraph-AccessToken"); token != "" {
user = getUser(token, false)
namespace, _ = x.ExtractNamespaceFromJwt(token)
} else if token := req.Header.Get("X-Dgraph-AuthToken"); token != "" {
user = getUser(token, true)
} else {
user = getUser("", false)
}

ip := ""
if peerInfo, ok := peer.FromContext(ctx); ok {
ip, _, _ = net.SplitHostPort(peerInfo.Addr.String())
}

auditor.Audit(&AuditEvent{
User: user,
Namespace: namespace,
ServerHost: x.WorkerConfig.MyAddr,
ClientHost: ip,
Endpoint: "/graphql",
ReqType: WebSocket,
Req: truncate(req.Query, maxReqLength),
Status: http.StatusText(http.StatusOK),
QueryParams: nil,
})
}

func auditGrpc(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo) {
clientHost := ""
if p, ok := peer.FromContext(ctx); ok {
Expand Down
3 changes: 3 additions & 0 deletions graphql/admin/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"compress/gzip"
"context"
"encoding/json"
"github.com/dgraph-io/dgraph/ee/audit"
"io"
"io/ioutil"
"mime"
Expand Down Expand Up @@ -194,6 +195,8 @@ func (gs *graphqlSubscription) Subscribe(
Variables: variableValues,
Header: reqHeader,
}

audit.AuditWebSockets(ctx, req)
namespace := x.ExtractNamespaceHTTP(&http.Request{Header: reqHeader})
glog.Infof("namespace: %d. Got GraphQL request over websocket.", namespace)
// first load the schema, then do anything else
Expand Down
17 changes: 10 additions & 7 deletions x/jwt_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,22 @@ func ExtractUserName(jwtToken string) (string, error) {
return userId, nil
}

func ExtractJWTNamespace(ctx context.Context) (uint64, error) {
jwtString, err := ExtractJwt(ctx)
if err != nil {
return 0, err
}
claims, err := ParseJWT(jwtString)
func ExtractNamespaceFromJwt(jwtToken string) (uint64, error) {
claims, err := ParseJWT(jwtToken)
if err != nil {
return 0, err
}

namespace, ok := claims["namespace"].(float64)
if !ok {
return 0, errors.Errorf("namespace in claims is not valid:%v", namespace)
}
return uint64(namespace), nil
}

func ExtractJWTNamespace(ctx context.Context) (uint64, error) {
jwtString, err := ExtractJwt(ctx)
if err != nil {
return 0, err
}
return ExtractNamespaceFromJwt(jwtString)
}

0 comments on commit 9792506

Please sign in to comment.