diff --git a/contrib/scripts/functions.sh b/contrib/scripts/functions.sh index 2283f342ba2..51577821105 100755 --- a/contrib/scripts/functions.sh +++ b/contrib/scripts/functions.sh @@ -14,9 +14,8 @@ function restartCluster { basedir=$GOPATH/src/github.com/dgraph-io/dgraph pushd $basedir/dgraph >/dev/null go build . && go install . && md5sum dgraph $GOPATH/bin/dgraph - docker ps --filter label="cluster=test" --format "{{.Names}}" \ - | xargs -r docker stop | sed 's/^/Stopped /' - docker-compose -f $compose_file -p dgraph up --force-recreate --remove-orphans --detach + docker ps -a --filter label="cluster=test" --format "{{.Names}}" | xargs docker rm -f + docker-compose -f $compose_file up --force-recreate --remove-orphans --detach popd >/dev/null $basedir/contrib/wait-for-it.sh -t 60 localhost:6080 || exit 1 diff --git a/dgraph/cmd/alpha/run.go b/dgraph/cmd/alpha/run.go index ab05843f149..a2576b8b8d9 100644 --- a/dgraph/cmd/alpha/run.go +++ b/dgraph/cmd/alpha/run.go @@ -33,6 +33,8 @@ import ( "syscall" "time" + "github.com/dgraph-io/badger/y" + "github.com/dgraph-io/dgo/protos/api" "github.com/dgraph-io/dgraph/edgraph" "github.com/dgraph-io/dgraph/posting" @@ -126,11 +128,14 @@ they form a Raft group and provide synchronous replication. "If set, all Alter requests to Dgraph would need to have this token."+ " The token can be passed as follows: For HTTP requests, in X-Dgraph-AuthToken header."+ " For Grpc, in auth-token key in the context.") + flag.String("hmac_secret_file", "", "The file storing the HMAC secret"+ " that is used for signing the JWT. Enterprise feature.") - flag.Duration("access_jwt_ttl", 6*time.Hour, "The TTL for the access jwt. "+ + flag.Duration("acl_access_ttl", 6*time.Hour, "The TTL for the access jwt. "+ + "Enterprise feature.") + flag.Duration("acl_refresh_ttl", 30*24*time.Hour, "The TTL for the refresh jwt. "+ "Enterprise feature.") - flag.Duration("refresh_jwt_ttl", 30*24*time.Hour, "The TTL for the refresh jwt. "+ + flag.Duration("acl_cache_ttl", 30*time.Second, "The interval to refresh the acl cache. "+ "Enterprise feature.") flag.Float64P("lru_mb", "l", -1, "Estimated memory the LRU cache can take. "+ @@ -408,14 +413,25 @@ func run() { secretFile := Alpha.Conf.GetString("hmac_secret_file") if secretFile != "" { + if !Alpha.Conf.GetBool("enterprise_features") { + glog.Errorf("You must enable Dgraph enterprise features with the " + + "--enterprise_features option in order to use ACL.") + os.Exit(1) + } + hmacSecret, err := ioutil.ReadFile(secretFile) if err != nil { glog.Fatalf("Unable to read HMAC secret from file: %v", secretFile) } + if len(hmacSecret) < 32 { + glog.Errorf("The HMAC secret file should contain at least 256 bits (32 ascii chars)") + os.Exit(1) + } opts.HmacSecret = hmacSecret - opts.AccessJwtTtl = Alpha.Conf.GetDuration("access_jwt_ttl") - opts.RefreshJwtTtl = Alpha.Conf.GetDuration("refresh_jwt_ttl") + opts.AccessJwtTtl = Alpha.Conf.GetDuration("acl_access_ttl") + opts.RefreshJwtTtl = Alpha.Conf.GetDuration("acl_refresh_ttl") + opts.AclRefreshInterval = Alpha.Conf.GetDuration("acl_cache_ttl") glog.Info("HMAC secret loaded successfully.") } @@ -516,9 +532,18 @@ func run() { _ = numShutDownSig // Setup external communication. - go worker.StartRaftNodes(edgraph.State.WALstore, bindall) + aclCloser := y.NewCloser(1) + go func() { + worker.StartRaftNodes(edgraph.State.WALstore, bindall) + // initialization of the admin account can only be done after raft nodes are running + // and health check passes + edgraph.ResetAcl() + edgraph.RefreshAcls(aclCloser) + }() + setupServer() glog.Infoln("GRPC and HTTP stopped.") + aclCloser.SignalAndWait() worker.BlockingStop() glog.Infoln("Server shutdown. Bye!") } diff --git a/edgraph/access.go b/edgraph/access.go index f347f489294..b5317704f8b 100644 --- a/edgraph/access.go +++ b/edgraph/access.go @@ -21,6 +21,7 @@ package edgraph import ( "context" + "github.com/dgraph-io/badger/y" "github.com/dgraph-io/dgo/protos/api" "github.com/dgraph-io/dgraph/x" "github.com/golang/glog" @@ -32,3 +33,26 @@ func (s *Server) Login(ctx context.Context, glog.Warningf("Login failed: %s", x.ErrNotSupported) return &api.Response{}, x.ErrNotSupported } + +func ResetAcl() { + // do nothing +} + +func RefreshAcls(closer *y.Closer) { + // do nothing + <-closer.HasBeenClosed() + closer.Done() +} + +func authorizeAlter(ctx context.Context, op *api.Operation) error { + return nil +} + +func authorizeMutation(ctx context.Context, mu *api.Mutation) error { + return nil +} + +func authorizeQuery(ctx context.Context, req *api.Request) error { + // always allow access + return nil +} diff --git a/edgraph/access_ee.go b/edgraph/access_ee.go index c4a21a21a41..337957415b2 100644 --- a/edgraph/access_ee.go +++ b/edgraph/access_ee.go @@ -14,18 +14,24 @@ package edgraph import ( "context" - "encoding/json" "fmt" - "strconv" + "sync" "time" + "github.com/dgraph-io/badger/y" + "github.com/dgraph-io/dgo/protos/api" "github.com/dgraph-io/dgraph/ee/acl" - "github.com/dgrijalva/jwt-go" + "github.com/dgraph-io/dgraph/gql" + "github.com/dgraph-io/dgraph/schema" + "github.com/dgraph-io/dgraph/x" + jwt "github.com/dgrijalva/jwt-go" "github.com/golang/glog" - "google.golang.org/grpc/peer" - otrace "go.opencensus.io/trace" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" ) func (s *Server) Login(ctx context.Context, @@ -37,31 +43,31 @@ func (s *Server) Login(ctx context.Context, var addr string if ip, ok := peer.FromContext(ctx); ok { addr = ip.Addr.String() - glog.Infof("Login request from: %s", addr) + glog.Infof("login request from: %s", addr) span.Annotate([]otrace.Attribute{ otrace.StringAttribute("client_ip", addr), }, "client ip for login") } - user, err := s.authenticate(ctx, request) + user, err := s.authenticateLogin(ctx, request) if err != nil { - errMsg := fmt.Sprintf("Authentication from address %s failed: %v", addr, err) + errMsg := fmt.Sprintf("authentication from address %s failed: %v", addr, err) glog.Errorf(errMsg) return nil, fmt.Errorf(errMsg) } resp := &api.Response{} - accessJwt, err := getAccessJwt(request.Userid, user.Groups) + accessJwt, err := getAccessJwt(user.UserID, user.Groups) if err != nil { - errMsg := fmt.Sprintf("Unable to get access jwt (userid=%s,addr=%s):%v", - request.Userid, addr, err) + errMsg := fmt.Sprintf("unable to get access jwt (userid=%s,addr=%s):%v", + user.UserID, addr, err) glog.Errorf(errMsg) return nil, fmt.Errorf(errMsg) } - refreshJwt, err := getRefreshJwt(request.Userid) + refreshJwt, err := getRefreshJwt(user.UserID) if err != nil { - errMsg := fmt.Sprintf("Unable to get refresh jwt (userid=%s,addr=%s):%v", - request.Userid, addr, err) + errMsg := fmt.Sprintf("unable to get refresh jwt (userid=%s,addr=%s):%v", + user.UserID, addr, err) glog.Errorf(errMsg) return nil, fmt.Errorf(errMsg) } @@ -73,8 +79,8 @@ func (s *Server) Login(ctx context.Context, jwtBytes, err := loginJwt.Marshal() if err != nil { - errMsg := fmt.Sprintf("Unable to marshal jwt (userid=%s,addr=%s):%v", - request.Userid, addr, err) + errMsg := fmt.Sprintf("unable to marshal jwt (userid=%s,addr=%s):%v", + user.UserID, addr, err) glog.Errorf(errMsg) return nil, fmt.Errorf(errMsg) } @@ -82,81 +88,110 @@ func (s *Server) Login(ctx context.Context, return resp, nil } -func (s *Server) authenticate(ctx context.Context, request *api.LoginRequest) (*acl.User, error) { +// authenticateLogin authenticates the login request using either the refresh token if present, or +// the pair. If authentication passes, it queries the user's uid and associated +// groups from DB and returns the user object +func (s *Server) authenticateLogin(ctx context.Context, request *api.LoginRequest) (*acl.User, + error) { if err := validateLoginRequest(request); err != nil { - return nil, fmt.Errorf("Invalid login request: %v", err) + return nil, fmt.Errorf("invalid login request: %v", err) } var user *acl.User if len(request.RefreshToken) > 0 { - userId, err := authenticateRefreshToken(request.RefreshToken) + userData, err := validateToken(request.RefreshToken) if err != nil { - return nil, fmt.Errorf("Unable to authenticate the refresh token %v: %v", + return nil, fmt.Errorf("unable to authenticate the refresh token %v: %v", request.RefreshToken, err) } - user, err = s.queryUser(ctx, userId, "") + userId := userData[0] + user, err = authorizeUser(ctx, userId, "") if err != nil { - return nil, fmt.Errorf("Error while querying user with id: %v", - request.Userid) + return nil, fmt.Errorf("error while querying user with id %v: %v", userId, err) } if user == nil { - return nil, fmt.Errorf("User not found for id %v", request.Userid) - } - } else { - var err error - user, err = s.queryUser(ctx, request.Userid, request.Password) - if err != nil { - return nil, fmt.Errorf("Error while querying user with id: %v", - request.Userid) + return nil, fmt.Errorf("unable to authenticate through refresh token: "+ + "user not found for id %v", userId) } - if user == nil { - return nil, fmt.Errorf("User not found for id %v", request.Userid) - } - if !user.PasswordMatch { - return nil, fmt.Errorf("Password mismatch for user: %v", request.Userid) - } + glog.Infof("authenticated user %s through refresh token", userId) + return user, nil } + // authorize the user using password + var err error + user, err = authorizeUser(ctx, request.Userid, request.Password) + if err != nil { + return nil, fmt.Errorf("error while querying user with id %v: %v", + request.Userid, err) + } + + if user == nil { + return nil, fmt.Errorf("unable to authenticate through password: "+ + "user not found for id %v", request.Userid) + } + if !user.PasswordMatch { + return nil, fmt.Errorf("password mismatch for user: %v", request.Userid) + } return user, nil } -func authenticateRefreshToken(refreshToken string) (string, error) { - token, err := jwt.Parse(refreshToken, func(token *jwt.Token) (interface{}, error) { +// validateToken verifies the signature and expiration of the jwt, and if validation passes, +// returns a slice of strings, where the first element is the extracted userId +// and the rest are groupIds encoded in the jwt. +func validateToken(jwtStr string) ([]string, error) { + token, err := jwt.Parse(jwtStr, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return Config.HmacSecret, nil }) if err != nil { - return "", fmt.Errorf("Unable to parse refresh token:%v", err) + return nil, fmt.Errorf("unable to parse jwt token:%v", err) } claims, ok := token.Claims.(jwt.MapClaims) if !ok || !token.Valid { - return "", fmt.Errorf("Claims in refresh token is not map claims:%v", refreshToken) + return nil, fmt.Errorf("claims in jwt token is not map claims") } // by default, the MapClaims.Valid will return true if the exp field is not set // here we enforce the checking to make sure that the refresh token has not expired now := time.Now().Unix() if !claims.VerifyExpiresAt(now, true) { - return "", fmt.Errorf("Refresh token has expired: %v", refreshToken) + return nil, fmt.Errorf("Token is expired") // the same error msg that's used inside jwt-go } userId, ok := claims["userid"].(string) if !ok { - return "", fmt.Errorf("User ID in claims is not a string:%v", userId) + return nil, fmt.Errorf("userid in claims is not a string:%v", userId) + } + + groups, ok := claims["groups"].([]interface{}) + var groupIds []string + if ok { + groupIds = make([]string, 0, len(groups)) + for _, group := range groups { + groupId, ok := group.(string) + if !ok { + // This shouldn't happen. So, no need to make the client try to refresh the tokens. + return nil, fmt.Errorf("unable to convert group to string:%v", group) + } + + groupIds = append(groupIds, groupId) + } } - return userId, nil + return append([]string{userId}, groupIds...), nil } +// validateLoginRequest validates that the login request has either the refresh token or the +// pair func validateLoginRequest(request *api.LoginRequest) error { if request == nil { - return fmt.Errorf("The request should not be nil") + return fmt.Errorf("the request should not be nil") } // we will use the refresh token for authentication if it's set if len(request.RefreshToken) > 0 { @@ -165,41 +200,42 @@ func validateLoginRequest(request *api.LoginRequest) error { // otherwise make sure both userid and password are set if len(request.Userid) == 0 { - return fmt.Errorf("The userid should not be empty") + return fmt.Errorf("the userid should not be empty") } if len(request.Password) == 0 { - return fmt.Errorf("The password should not be empty") + return fmt.Errorf("the password should not be empty") } return nil } +// getAccessJwt constructs an access jwt with the given user id, groupIds, +// and expiration TTL specified by Config.AccessJwtTtl func getAccessJwt(userId string, groups []acl.Group) (string, error) { token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ "userid": userId, "groups": acl.GetGroupIDs(groups), // set the jwt exp according to the ttl - "exp": json.Number( - strconv.FormatInt(time.Now().Add(Config.AccessJwtTtl).Unix(), 10)), + "exp": time.Now().Add(Config.AccessJwtTtl).Unix(), }) jwtString, err := token.SignedString(Config.HmacSecret) if err != nil { - return "", fmt.Errorf("Unable to encode jwt to string: %v", err) + return "", fmt.Errorf("unable to encode jwt to string: %v", err) } return jwtString, nil } +// getRefreshJwt constructs a refresh jwt with the given user id, and expiration ttl specified by +// Config.RefreshJwtTtl func getRefreshJwt(userId string) (string, error) { token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ "userid": userId, - // set the jwt exp according to the ttl - "exp": json.Number( - strconv.FormatInt(time.Now().Add(Config.RefreshJwtTtl).Unix(), 10)), + "exp": time.Now().Add(Config.RefreshJwtTtl).Unix(), }) jwtString, err := token.SignedString(Config.HmacSecret) if err != nil { - return "", fmt.Errorf("Unable to encode jwt to string: %v", err) + return "", fmt.Errorf("unable to encode jwt to string: %v", err) } return jwtString, nil } @@ -208,6 +244,7 @@ const queryUser = ` query search($userid: string, $password: string){ user(func: eq(dgraph.xid, $userid)) { uid + dgraph.xid password_match: checkpwd(dgraph.password, $password) dgraph.user.group { uid @@ -216,8 +253,10 @@ const queryUser = ` } }` -func (s *Server) queryUser(ctx context.Context, userid string, password string) (user *acl.User, - err error) { +// authorizeUser queries the user with the given user id, and returns the associated uid, +// acl groups, and whether the password stored in DB matches the supplied password +func authorizeUser(ctx context.Context, userid string, password string) (*acl.User, + error) { queryVars := map[string]string{ "$userid": userid, "$password": password, @@ -227,14 +266,339 @@ func (s *Server) queryUser(ctx context.Context, userid string, password string) Vars: queryVars, } - queryResp, err := s.Query(ctx, &queryRequest) + queryResp, err := (&Server{}).doQuery(ctx, &queryRequest) if err != nil { glog.Errorf("Error while query user with id %s: %v", userid, err) return nil, err } - user, err = acl.UnmarshalUser(queryResp, "user") + user, err := acl.UnmarshalUser(queryResp, "user") if err != nil { return nil, err } return user, nil } + +func RefreshAcls(closer *y.Closer) { + defer closer.Done() + if len(Config.HmacSecret) == 0 { + // the acl feature is not turned on + return + } + + ticker := time.NewTicker(Config.AclRefreshInterval) + defer ticker.Stop() + + // retrieve the full data set of ACLs from the corresponding alpha server, and update the + // aclCache + retrieveAcls := func() error { + glog.V(1).Infof("Refreshing ACLs") + queryRequest := api.Request{ + Query: queryAcls, + } + + ctx := context.Background() + var err error + queryResp, err := (&Server{}).doQuery(ctx, &queryRequest) + if err != nil { + return fmt.Errorf("unable to retrieve acls: %v", err) + } + groups, err := acl.UnmarshalGroups(queryResp.GetJson(), "allAcls") + if err != nil { + return err + } + + storedEntries := 0 + for _, group := range groups { + // convert the serialized acl into a map for easy lookups + group.MappedAcls, err = acl.UnmarshalAcl([]byte(group.Acls)) + if err != nil { + glog.Errorf("Error while unmarshalling ACLs for group %v:%v", group, err) + continue + } + + storedEntries++ + aclCache.Store(group.GroupID, &group) + } + glog.V(1).Infof("Updated the ACL cache with %d entries", storedEntries) + return nil + } + + for { + select { + case <-closer.HasBeenClosed(): + return + case <-ticker.C: + if err := retrieveAcls(); err != nil { + glog.Errorf("Error while retrieving acls:%v", err) + } + } + } +} + +const queryAcls = ` +{ + allAcls(func: has(dgraph.group.acl)) { + dgraph.xid + dgraph.group.acl + } +} +` + +// the acl cache mapping group names to the corresponding group acls +var aclCache sync.Map + +// clear the aclCache and upsert the Groot account. +func ResetAcl() { + if len(Config.HmacSecret) == 0 { + // the acl feature is not turned on + return + } + + upsertGroot := func(ctx context.Context) error { + queryVars := map[string]string{ + "$userid": x.GrootId, + "$password": "", + } + queryRequest := api.Request{ + Query: queryUser, + Vars: queryVars, + } + + queryResp, err := (&Server{}).doQuery(ctx, &queryRequest) + if err != nil { + return fmt.Errorf("error while querying user with id %s: %v", x.GrootId, err) + } + startTs := queryResp.GetTxn().StartTs + + rootUser, err := acl.UnmarshalUser(queryResp, "user") + if err != nil { + return fmt.Errorf("error while unmarshaling the root user: %v", err) + } + if rootUser != nil { + // the user already exists, no need to create + return nil + } + + // Insert Groot. + createUserNQuads := []*api.NQuad{ + { + Subject: "_:newuser", + Predicate: "dgraph.xid", + ObjectValue: &api.Value{Val: &api.Value_StrVal{StrVal: x.GrootId}}, + }, + { + Subject: "_:newuser", + Predicate: "dgraph.password", + ObjectValue: &api.Value{Val: &api.Value_StrVal{StrVal: "password"}}, + }} + + mu := &api.Mutation{ + StartTs: startTs, + CommitNow: true, + Set: createUserNQuads, + } + + if _, err := (&Server{}).doMutate(context.Background(), mu); err != nil { + return err + } + return nil + } + + aclCache = sync.Map{} + for { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + if err := upsertGroot(ctx); err != nil { + glog.Infof("Unable to upsert the groot account. Error: %v", err) + time.Sleep(100 * time.Millisecond) + } else { + return + } + } +} + +// extract the userId, groupIds from the accessJwt in the context +func extractUserAndGroups(ctx context.Context) ([]string, error) { + // extract the jwt and unmarshal the jwt to get the list of groups + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, fmt.Errorf("no metadata available") + } + accessJwt := md.Get("accessJwt") + if len(accessJwt) == 0 { + return nil, fmt.Errorf("no accessJwt available") + } + + return validateToken(accessJwt[0]) +} + +//authorizeAlter parses the Schema in the operation and authorizes the operation using the aclCache +func authorizeAlter(ctx context.Context, op *api.Operation) error { + if len(Config.HmacSecret) == 0 { + // the user has not turned on the acl feature + return nil + } + + userData, err := extractUserAndGroups(ctx) + if err != nil { + return status.Error(codes.Unauthenticated, err.Error()) + } + if isGroot(userData) { + return nil + } + + // if we get here, we know the user is not Groot. + if op.DropAll { + return fmt.Errorf("only Groot is allowed to drop all data, current user is %s", userData[0]) + } + + groupIds := userData[1:] + if len(op.DropAttr) > 0 { + // check that we have the modify permission on the predicate + if err := authorizePredicate(groupIds, op.DropAttr, acl.Modify); err != nil { + return status.Error(codes.PermissionDenied, + fmt.Sprintf("unauthorized to alter the predicate:%v", err)) + } + return nil + } + + updates, err := schema.Parse(op.Schema) + if err != nil { + return err + } + for _, update := range updates { + if err := authorizePredicate(groupIds, update.Predicate, acl.Modify); err != nil { + return status.Error(codes.PermissionDenied, + fmt.Sprintf("unauthorized to alter the predicate: %v", err)) + } + } + return nil +} + +// parsePredsFromMutation returns a union set of all the predicate names in the input nquads +func parsePredsFromMutation(nquads []*api.NQuad) map[string]struct{} { + preds := make(map[string]struct{}) + for _, nquad := range nquads { + preds[nquad.Predicate] = struct{}{} + } + return preds +} + +// authorizeMutation authorizes the mutation using the aclCache +func authorizeMutation(ctx context.Context, mu *api.Mutation) error { + if len(Config.HmacSecret) == 0 { + // the user has not turned on the acl feature + return nil + } + + userData, err := extractUserAndGroups(ctx) + if err != nil { + return status.Error(codes.Unauthenticated, err.Error()) + } + if isGroot(userData) { + // Groot has access to everything. + return nil + } + + gmu, err := parseMutationObject(mu) + if err != nil { + return err + } + + groupIds := userData[1:] + for pred := range parsePredsFromMutation(gmu.Set) { + if err := authorizePredicate(groupIds, pred, acl.Write); err != nil { + return status.Error(codes.PermissionDenied, + fmt.Sprintf("unauthorized to mutate the predicate: %v", err)) + } + } + return nil +} + +func parsePredsFromQuery(gqls []*gql.GraphQuery) map[string]struct{} { + preds := make(map[string]struct{}) + for _, gq := range gqls { + + if gq.Func != nil { + preds[gq.Func.Attr] = struct{}{} + } + + if len(gq.Attr) > 0 { + preds[gq.Attr] = struct{}{} + } + + for childPred := range parsePredsFromQuery(gq.Children) { + preds[childPred] = struct{}{} + } + } + return preds +} + +func isGroot(userData []string) bool { + if len(userData) == 0 { + return false + } + + return userData[0] == x.GrootId +} + +//authorizeQuery authorizes the query using the aclCache +func authorizeQuery(ctx context.Context, req *api.Request) error { + if len(Config.HmacSecret) == 0 { + // the user has not turned on the acl feature + return nil + } + + userData, err := extractUserAndGroups(ctx) + if err != nil { + return status.Error(codes.Unauthenticated, err.Error()) + } + if isGroot(userData) { + return nil + } + + parsedReq, err := gql.Parse(gql.Request{ + Str: req.Query, + Variables: req.Vars, + }) + if err != nil { + return err + } + + groupIds := userData[1:] + for pred := range parsePredsFromQuery(parsedReq.Query) { + if err := authorizePredicate(groupIds, pred, acl.Read); err != nil { + return status.Error(codes.PermissionDenied, + fmt.Sprintf("unauthorized to query the predicate: %v", err)) + } + } + return nil +} + +func authorizePredicate(groups []string, predicate string, operation *acl.Operation) error { + for _, group := range groups { + if err := hasAccess(group, predicate, operation); err == nil { + return nil + } + } + return fmt.Errorf("unauthorized to do %s on predicate %s", operation.Name, predicate) +} + +// hasAccess checks the aclCache and returns whether the specified group is authorized to perform +// the operation on the given predicate +func hasAccess(groupId string, predicate string, operation *acl.Operation) error { + entry, found := aclCache.Load(groupId) + if !found { + return fmt.Errorf("acl not found for group %v", groupId) + } + aclGroup := entry.(*acl.Group) + perm, found := aclGroup.MappedAcls[predicate] + allowed := found && (perm&operation.Code) != 0 + glog.V(1).Infof("Authorizing group %v on predicate %v for %s, allowed %v", groupId, + predicate, operation.Name, allowed) + if !allowed { + return fmt.Errorf("group %s not allowed to do %s on predicate %s", + groupId, operation.Name, predicate) + } + return nil +} diff --git a/edgraph/config.go b/edgraph/config.go index a25232657ce..9144d53c394 100644 --- a/edgraph/config.go +++ b/edgraph/config.go @@ -33,18 +33,18 @@ const ( ) type Options struct { - PostingDir string - BadgerTables string - BadgerVlog string - WALDir string - MutationsMode int - AuthToken string - + PostingDir string + BadgerTables string + BadgerVlog string + WALDir string + MutationsMode int + AuthToken string AllottedMemory float64 - HmacSecret []byte - AccessJwtTtl time.Duration - RefreshJwtTtl time.Duration + HmacSecret []byte + AccessJwtTtl time.Duration + RefreshJwtTtl time.Duration + AclRefreshInterval time.Duration } var Config Options diff --git a/edgraph/server.go b/edgraph/server.go index 40cb5ac9218..c8b78f5d253 100644 --- a/edgraph/server.go +++ b/edgraph/server.go @@ -41,7 +41,6 @@ import ( "github.com/dgraph-io/dgraph/types/facets" "github.com/dgraph-io/dgraph/worker" "github.com/dgraph-io/dgraph/x" - "github.com/golang/glog" otrace "go.opencensus.io/trace" "golang.org/x/net/context" @@ -283,6 +282,7 @@ func (s *Server) Alter(ctx context.Context, op *api.Operation) (*api.Payload, er if err := x.HealthCheck(); err != nil { return empty, err } + if !isMutationAllowed(ctx) { return nil, x.Errorf("No mutations allowed by server.") } @@ -290,6 +290,12 @@ func (s *Server) Alter(ctx context.Context, op *api.Operation) (*api.Payload, er glog.Warningf("Alter denied with error: %v\n", err) return nil, err } + + if err := authorizeAlter(ctx, op); err != nil { + glog.Warningf("Alter denied with error: %v\n", err) + return nil, err + } + // All checks done. defer glog.Infof("ALTER op: %+v done", op) @@ -299,6 +305,9 @@ func (s *Server) Alter(ctx context.Context, op *api.Operation) (*api.Payload, er if op.DropAll { m.DropAll = true _, err := query.ApplyMutations(ctx, m) + + // recreate the admin account after a drop all operation + ResetAcl() return empty, err } if len(op.DropAttr) > 0 { @@ -317,6 +326,7 @@ func (s *Server) Alter(ctx context.Context, op *api.Operation) (*api.Payload, er _, err = query.ApplyMutations(ctx, m) return empty, err } + updates, err := schema.Parse(op.Schema) if err != nil { return empty, err @@ -333,6 +343,14 @@ func annotateStartTs(span *otrace.Span, ts uint64) { } func (s *Server) Mutate(ctx context.Context, mu *api.Mutation) (resp *api.Assigned, err error) { + if err := authorizeMutation(ctx, mu); err != nil { + return nil, err + } + + return s.doMutate(ctx, mu) +} + +func (s *Server) doMutate(ctx context.Context, mu *api.Mutation) (resp *api.Assigned, err error) { ctx, span := otrace.StartSpan(ctx, "Server.Mutate") defer span.End() @@ -371,6 +389,7 @@ func (s *Server) Mutate(ctx context.Context, mu *api.Mutation) (resp *api.Assign parseEnd := time.Now() l.Parsing = parseEnd.Sub(l.Start) + defer func() { l.Processing = time.Since(parseEnd) resp.Latency = &api.Latency{ @@ -438,9 +457,17 @@ func (s *Server) Mutate(ctx context.Context, mu *api.Mutation) (resp *api.Assign return resp, nil } +func (s *Server) Query(ctx context.Context, req *api.Request) (*api.Response, error) { + if err := authorizeQuery(ctx, req); err != nil { + return nil, err + } + + return s.doQuery(ctx, req) +} + // This method is used to execute the query and return the response to the // client as a protocol buffer message. -func (s *Server) Query(ctx context.Context, req *api.Request) (resp *api.Response, err error) { +func (s *Server) doQuery(ctx context.Context, req *api.Request) (*api.Response, error) { if glog.V(3) { glog.Infof("Got a query: %+v", req) } @@ -448,17 +475,17 @@ func (s *Server) Query(ctx context.Context, req *api.Request) (resp *api.Respons defer span.End() if err := x.HealthCheck(); err != nil { - return resp, err + return nil, err } x.PendingQueries.Add(1) x.NumQueries.Add(1) defer x.PendingQueries.Add(-1) if ctx.Err() != nil { - return resp, ctx.Err() + return nil, ctx.Err() } - resp = new(api.Response) + resp := new(api.Response) if len(req.Query) == 0 { span.Annotate(nil, "Empty query") return resp, fmt.Errorf("Empty query") @@ -475,7 +502,6 @@ func (s *Server) Query(ctx context.Context, req *api.Request) (resp *api.Respons if err != nil { return resp, err } - if req.StartTs == 0 { req.StartTs = State.getTimestamp(req.ReadOnly) } diff --git a/ee/acl/acl_test.go b/ee/acl/acl_test.go index 806c8fcc78c..c08391b726a 100644 --- a/ee/acl/acl_test.go +++ b/ee/acl/acl_test.go @@ -13,8 +13,21 @@ package acl import ( + "context" + "fmt" + "log" + "os" "os/exec" + "path/filepath" + "strconv" "testing" + "time" + + "github.com/dgraph-io/dgo" + "github.com/dgraph-io/dgo/protos/api" + "github.com/dgraph-io/dgraph/x" + "github.com/golang/glog" + "github.com/stretchr/testify/require" ) const ( @@ -23,11 +36,6 @@ const ( dgraphEndpoint = "localhost:9180" ) -func TestAcl(t *testing.T) { - t.Run("create user", CreateAndDeleteUsers) - // t.Run("login", LogIn) -} - func checkOutput(t *testing.T, cmd *exec.Cmd, shouldFail bool) string { out, err := cmd.CombinedOutput() if (!shouldFail && err != nil) || (shouldFail && err == nil) { @@ -38,85 +46,223 @@ func checkOutput(t *testing.T, cmd *exec.Cmd, shouldFail bool) string { return string(out) } -func CreateAndDeleteUsers(t *testing.T) { +func TestCreateAndDeleteUsers(t *testing.T) { + // clean up the user to allow repeated running of this test + cleanUserCmd := exec.Command("dgraph", "acl", "userdel", "-d", dgraphEndpoint, + "-u", userid, "-x", "password") + cleanUserCmd.Run() + glog.Infof("cleaned up db user state") + createUserCmd1 := exec.Command("dgraph", "acl", "useradd", "-d", dgraphEndpoint, "-u", userid, - "-p", userpassword) - createUserOutput1 := checkOutput(t, createUserCmd1, false) - t.Logf("Got output when creating user:%v", createUserOutput1) + "-p", userpassword, "-x", "password") + checkOutput(t, createUserCmd1, false) createUserCmd2 := exec.Command("dgraph", "acl", "useradd", "-d", dgraphEndpoint, "-u", userid, - "-p", userpassword) - + "-p", userpassword, "-x", "password") // create the user again should fail - createUserOutput2 := checkOutput(t, createUserCmd2, true) - t.Logf("Got output when creating user:%v", createUserOutput2) + checkOutput(t, createUserCmd2, true) // delete the user - deleteUserCmd := exec.Command("dgraph", "acl", "userdel", "-d", dgraphEndpoint, "-u", userid) - deleteUserOutput := checkOutput(t, deleteUserCmd, false) - t.Logf("Got output when deleting user:%v", deleteUserOutput) + deleteUserCmd := exec.Command("dgraph", "acl", "userdel", "-d", dgraphEndpoint, "-u", userid, + "-x", "password") + checkOutput(t, deleteUserCmd, false) // now we should be able to create the user again createUserCmd3 := exec.Command("dgraph", "acl", "useradd", "-d", dgraphEndpoint, "-u", userid, - "-p", userpassword) - createUserOutput3 := checkOutput(t, createUserCmd3, false) - t.Logf("Got output when creating user:%v", createUserOutput3) + "-p", userpassword, "-x", "password") + checkOutput(t, createUserCmd3, false) } -// TODO(gitlw): Finish this later. -// func LogIn(t *testing.T) { -// delete and recreate the user to ensure a clean state -/* - deleteUserCmd := exec.Command("dgraph", "acl", "userdel", "-d", dgraphEndpoint, "-u", "lucas") - deleteUserOutput := checkOutput(t, deleteUserCmd, false) - createUserCmd := exec.Command("dgraph", "acl", "useradd", "-d", dgraphEndpoint, "-u", "lucas", - "-p", "haha") - createUserOutput := checkOutput(t, createUserCmd, false) -*/ +func resetUser(t *testing.T) { + // delete and recreate the user to ensure a clean state + deleteUserCmd := exec.Command("dgraph", "acl", "userdel", "-d", dgraphEndpoint, + "-u", userid, "-x", "password") + deleteUserCmd.Run() + glog.Infof("deleted user") -// now try to login with the wrong password + createUserCmd := exec.Command("dgraph", "acl", "useradd", "-d", dgraphEndpoint, "-u", + userid, "-p", userpassword, "-x", "password") + checkOutput(t, createUserCmd, false) + glog.Infof("created user") +} -//loginWithWrongPassword(t, ctx, adminClient) -//loginWithCorrectPassword(t, ctx, adminClient) -// } +func TestAuthorization(t *testing.T) { + glog.Infof("testing with port 9180") + dg1, cancel := x.GetDgraphClientOnPort(9180) + defer cancel() + testAuthorization(t, dg1) + glog.Infof("done") -/* -func loginWithCorrectPassword(t *testing.T, ctx context.Context, - adminClient api.DgraphAccessClient) { - loginRequest := api.LogInRequest{ - Userid: userid, - Password: userpassword, + glog.Infof("testing with port 9182") + dg2, cancel := x.GetDgraphClientOnPort(9182) + defer cancel() + testAuthorization(t, dg2) + glog.Infof("done") +} + +func testAuthorization(t *testing.T, dg *dgo.Dgraph) { + createAccountAndData(t, dg) + ctx := context.Background() + if err := dg.Login(ctx, userid, userpassword); err != nil { + t.Fatalf("unable to login using the account %v", userid) } - response2, err := adminClient.LogIn(ctx, &loginRequest) - require.NoError(t, err) - if response2.Code != api.AclResponseCode_OK { - t.Errorf("Login with the correct password should result in the code %v", - api.AclResponseCode_OK) + + queryPredicateWithUserAccount(t, dg, true) + mutatePredicateWithUserAccount(t, dg, true) + alterPredicateWithUserAccount(t, dg, true) + createGroupAndAcls(t) + // wait for 35 seconds to ensure the new acl have reached all acl caches + log.Println("Sleeping for 35 seconds for acl to catch up") + time.Sleep(35 * time.Second) + queryPredicateWithUserAccount(t, dg, false) + // sleep long enough (10s per the docker-compose.yml in this directory) + // for the accessJwt to expire in order to test auto login through refresh jwt + log.Println("Sleeping for 12 seconds for accessJwt to expire") + time.Sleep(12 * time.Second) + mutatePredicateWithUserAccount(t, dg, false) + log.Println("Sleeping for 12 seconds for accessJwt to expire") + time.Sleep(12 * time.Second) + alterPredicateWithUserAccount(t, dg, false) +} + +var predicateToRead = "predicate_to_read" +var queryAttr = "name" +var predicateToWrite = "predicate_to_write" +var predicateToAlter = "predicate_to_alter" +var group = "dev" +var rootDir = filepath.Join(os.TempDir(), "acl_test") + +func queryPredicateWithUserAccount(t *testing.T, dg *dgo.Dgraph, shouldFail bool) { + // login with alice's account + ctx := context.Background() + txn := dg.NewTxn() + query := fmt.Sprintf(` + { + q(func: eq(%s, "SF")) { + %s + } + }`, predicateToRead, queryAttr) + txn = dg.NewTxn() + _, err := txn.Query(ctx, query) + + if shouldFail { + require.Error(t, err, "the query should have failed") + } else { + require.NoError(t, err, "the query should have succeeded") } - jwt := acl.Jwt{} - jwt.DecodeString(response2.Context.Jwt, false, nil) - if jwt.Payload.Userid != userid { - t.Errorf("the jwt token should have the user id encoded") +} + +func mutatePredicateWithUserAccount(t *testing.T, dg *dgo.Dgraph, shouldFail bool) { + ctx := context.Background() + txn := dg.NewTxn() + _, err := txn.Mutate(ctx, &api.Mutation{ + CommitNow: true, + SetNquads: []byte(fmt.Sprintf(`_:a <%s> "string" .`, predicateToWrite)), + }) + + if shouldFail { + require.Error(t, err, "the mutation should have failed") + } else { + require.NoError(t, err, "the mutation should have succeeded") } - jwtTime := time.Unix(jwt.Payload.Exp, 0) - jwtValidDays := jwtTime.Sub(time.Now()).Round(time.Hour).Hours() / 24 - if jwtValidDays != 30.0 { - t.Errorf("The jwt token should be valid for 30 days, received %v days", jwtValidDays) +} + +func alterPredicateWithUserAccount(t *testing.T, dg *dgo.Dgraph, shouldFail bool) { + ctx := context.Background() + err := dg.Alter(ctx, &api.Operation{ + Schema: fmt.Sprintf(`%s: int .`, predicateToAlter), + }) + if shouldFail { + require.Error(t, err, "the alter should have failed") + } else { + require.NoError(t, err, "the alter should have succeeded") } } -func loginWithWrongPassword(t *testing.T, ctx context.Context, - adminClient api.DgraphAccessClient) { - loginRequestWithWrongPassword := api.LogInRequest{ - Userid: userid, - Password: userpassword + "123", +func createAccountAndData(t *testing.T, dg *dgo.Dgraph) { + // use the groot account to clean the database + ctx := context.Background() + if err := dg.Login(ctx, x.GrootId, "password"); err != nil { + t.Fatalf("unable to login using the groot account:%v", err) + } + op := api.Operation{ + DropAll: true, } + if err := dg.Alter(ctx, &op); err != nil { + t.Fatalf("Unable to cleanup db:%v", err) + } + require.NoError(t, dg.Alter(ctx, &api.Operation{ + Schema: fmt.Sprintf(`%s: string @index(exact) .`, predicateToRead), + })) + + // create some data, e.g. user with name alice + resetUser(t) - response, err := adminClient.LogIn(ctx, &loginRequestWithWrongPassword) + txn := dg.NewTxn() + _, err := txn.Mutate(ctx, &api.Mutation{ + SetNquads: []byte(fmt.Sprintf("_:a <%s> \"SF\" .", predicateToRead)), + }) require.NoError(t, err) - if response.Code != api.AclResponseCode_UNAUTHENTICATED { - t.Errorf("Login with the wrong password should result in the code %v", api.AclResponseCode_UNAUTHENTICATED) - } + require.NoError(t, txn.Commit(ctx)) } -*/ +func createGroupAndAcls(t *testing.T) { + // create a new group + createGroupCmd := exec.Command(os.ExpandEnv("$GOPATH/bin/dgraph"), + "acl", "groupadd", + "-d", dgraphEndpoint, + "-g", group, "-x", "password") + if err := createGroupCmd.Run(); err != nil { + t.Fatalf("Unable to create group:%v", err) + } + + // add the user to the group + addUserToGroupCmd := exec.Command(os.ExpandEnv("$GOPATH/bin/dgraph"), + "acl", "usermod", + "-d", dgraphEndpoint, + "-u", userid, "-g", group, "-x", "password") + if err := addUserToGroupCmd.Run(); err != nil { + t.Fatalf("Unable to add user %s to group %s:%v", userid, group, err) + } + + // add READ permission on the predicateToRead to the group + addReadPermCmd1 := exec.Command(os.ExpandEnv("$GOPATH/bin/dgraph"), + "acl", "chmod", + "-d", dgraphEndpoint, + "-g", group, "-p", predicateToRead, "-P", strconv.Itoa(int(Read.Code)), "-x", + "password") + if err := addReadPermCmd1.Run(); err != nil { + t.Fatalf("Unable to add READ permission on %s to group %s:%v", + predicateToRead, group, err) + } + + // also add read permission to the attribute queryAttr, which is used inside the query block + addReadPermCmd2 := exec.Command(os.ExpandEnv("$GOPATH/bin/dgraph"), + "acl", "chmod", + "-d", dgraphEndpoint, + "-g", group, "-p", queryAttr, "-P", strconv.Itoa(int(Read.Code)), "-x", + "password") + if err := addReadPermCmd2.Run(); err != nil { + t.Fatalf("Unable to add READ permission on %s to group %s:%v", queryAttr, group, err) + } + + // add WRITE permission on the predicateToWrite + addWritePermCmd := exec.Command(os.ExpandEnv("$GOPATH/bin/dgraph"), + "acl", "chmod", + "-d", dgraphEndpoint, + "-g", group, "-p", predicateToWrite, "-P", strconv.Itoa(int(Write.Code)), "-x", + "password") + if err := addWritePermCmd.Run(); err != nil { + t.Fatalf("Unable to add permission on %s to group %s:%v", predicateToWrite, group, err) + } + + // add MODIFY permission on the predicateToAlter + addModifyPermCmd := exec.Command(os.ExpandEnv("$GOPATH/bin/dgraph"), + "acl", "chmod", + "-d", dgraphEndpoint, + "-g", group, "-p", predicateToAlter, "-P", strconv.Itoa(int(Modify.Code)), "-x", + "password") + if err := addModifyPermCmd.Run(); err != nil { + t.Fatalf("Unable to add permission on %s to group %s:%v", predicateToAlter, group, err) + } +} diff --git a/ee/acl/docker-compose.yml b/ee/acl/docker-compose.yml new file mode 100644 index 00000000000..efc15c85ca6 --- /dev/null +++ b/ee/acl/docker-compose.yml @@ -0,0 +1,63 @@ +# Docker compose file for testing. Use it with: +# docker-compose up --force-recreate +# This would pick up dgraph binary from $GOPATH. + +version: "3.5" +services: + zero1: + image: dgraph/dgraph:latest + container_name: acl-dg0.1 + working_dir: /data/dg0.1 + ports: + - 5080:5080 + - 6080:6080 + command: /gobin/dgraph zero --my=zero1:5080 --replicas 1 --idx 1 --bindall --expose_trace --profile_mode block --block_rate 10 --logtostderr -v=2 + volumes: + - type: bind + source: $GOPATH/bin + target: /gobin + read_only: true + labels: + cluster: test + + dg1: + image: dgraph/dgraph:latest + container_name: acl-dg1 + working_dir: /data/dg1 + volumes: + - type: bind + source: $GOPATH/bin + target: /gobin + read_only: true + - type: bind + source: $GOPATH/src/github.com/dgraph-io/dgraph/ee/acl + target: /dgraph-acl + ports: + - 8180:8180 + - 9180:9180 + security_opt: + - seccomp:unconfined + command: /gobin/dgraph alpha --my=dg1:7180 --lru_mb=1024 --zero=zero1:5080 -o 100 --expose_trace --trace 1.0 --profile_mode block --block_rate 10 --logtostderr -v=3 --hmac_secret_file /dgraph-acl/hmac-secret --enterprise_features --acl_access_ttl 10s + labels: + cluster: test + + dg2: + image: dgraph/dgraph:latest + container_name: acl-dg2 + working_dir: /data/dg2 + volumes: + - type: bind + source: $GOPATH/bin + target: /gobin + read_only: true + - type: bind + source: $GOPATH/src/github.com/dgraph-io/dgraph/ee/acl + target: /dgraph-acl + ports: + - 8182:8182 + - 9182:9182 + security_opt: + - seccomp:unconfined + command: /gobin/dgraph alpha --my=dg2:7182 --lru_mb=1024 --zero=zero1:5080 -o 102 --expose_trace --trace 1.0 --profile_mode block --block_rate 10 --logtostderr -v=3 --hmac_secret_file /dgraph-acl/hmac-secret --enterprise_features --acl_access_ttl 10s + labels: + cluster: test diff --git a/ee/acl/groups.go b/ee/acl/groups.go index dd7fa66e179..2ebffea497e 100644 --- a/ee/acl/groups.go +++ b/ee/acl/groups.go @@ -17,7 +17,6 @@ import ( "encoding/json" "fmt" "strings" - "time" "github.com/dgraph-io/dgo" "github.com/dgraph-io/dgo/protos/api" @@ -32,10 +31,13 @@ func groupAdd(conf *viper.Viper) error { return fmt.Errorf("The group id should not be empty") } - dc, close := getDgraphClient(conf) - defer close() - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + dc, cancel, err := getClientWithAdminCtx(conf) defer cancel() + if err != nil { + return fmt.Errorf("unable to get admin context:%v", err) + } + + ctx := context.Background() txn := dc.NewTxn() defer func() { if err := txn.Discard(ctx); err != nil { @@ -77,10 +79,13 @@ func groupDel(conf *viper.Viper) error { return fmt.Errorf("The group id should not be empty") } - dc, close := getDgraphClient(conf) - defer close() - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + dc, cancel, err := getClientWithAdminCtx(conf) defer cancel() + if err != nil { + return fmt.Errorf("unable to get admin context:%v", err) + } + + ctx := context.Background() txn := dc.NewTxn() defer func() { if err := txn.Discard(ctx); err != nil { @@ -156,10 +161,13 @@ func chMod(conf *viper.Viper) error { return fmt.Errorf("The predicate must not be empty") } - dc, close := getDgraphClient(conf) - defer close() - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + dc, cancel, err := getClientWithAdminCtx(conf) defer cancel() + if err != nil { + return fmt.Errorf("unable to get admin context:%v", err) + } + + ctx := context.Background() txn := dc.NewTxn() defer func() { if err := txn.Discard(ctx); err != nil { diff --git a/ee/acl/hmac-secret b/ee/acl/hmac-secret new file mode 100644 index 00000000000..2add0c574b7 --- /dev/null +++ b/ee/acl/hmac-secret @@ -0,0 +1 @@ +1234567890123456789012345678901 diff --git a/ee/acl/run_ee.go b/ee/acl/run_ee.go index 68de7e1caf4..445134f83fd 100644 --- a/ee/acl/run_ee.go +++ b/ee/acl/run_ee.go @@ -18,7 +18,6 @@ import ( "fmt" "os" "strings" - "time" "github.com/dgraph-io/dgo" "github.com/dgraph-io/dgo/protos/api" @@ -32,10 +31,13 @@ type options struct { dgraph string } -var opt options -var tlsConf x.TLSHelperConfig +var ( + opt options + tlsConf x.TLSHelperConfig + CmdAcl x.SubCommand +) -var CmdAcl x.SubCommand +const gPassword = "gpassword" func init() { CmdAcl.Cmd = &cobra.Command{ @@ -44,7 +46,8 @@ func init() { } flag := CmdAcl.Cmd.PersistentFlags() - flag.StringP("dgraph", "d", "127.0.0.1:9080", "Dgraph gRPC server address") + flag.StringP("dgraph", "d", "127.0.0.1:9080", "Dgraph Alpha gRPC server address") + flag.StringP(gPassword, "x", "", "Groot password to authorize this operation") // TLS configuration x.RegisterTLSFlags(flag) @@ -96,22 +99,6 @@ func initSubcommands() []*x.SubCommand { userDelFlags := cmdUserDel.Cmd.Flags() userDelFlags.StringP("user", "u", "", "The user id to be deleted") - // login command - var cmdLogIn x.SubCommand - cmdLogIn.Cmd = &cobra.Command{ - Use: "login", - Short: "Login to dgraph in order to get a jwt token", - Run: func(cmd *cobra.Command, args []string) { - if err := userLogin(cmdLogIn.Conf); err != nil { - glog.Errorf("Unable to login:%v", err) - os.Exit(1) - } - }, - } - loginFlags := cmdLogIn.Cmd.Flags() - loginFlags.StringP("user", "u", "", "The user id to be created") - loginFlags.StringP("password", "p", "", "The password for the user") - // group creation command var cmdGroupAdd x.SubCommand cmdGroupAdd.Cmd = &cobra.Command{ @@ -193,7 +180,7 @@ func initSubcommands() []*x.SubCommand { infoFlags.StringP("user", "u", "", "The user to be shown") infoFlags.StringP("group", "g", "", "The group to be shown") return []*x.SubCommand{ - &cmdUserAdd, &cmdUserDel, &cmdLogIn, &cmdGroupAdd, &cmdGroupDel, &cmdUserMod, + &cmdUserAdd, &cmdUserDel, &cmdGroupAdd, &cmdGroupDel, &cmdUserMod, &cmdChMod, &cmdInfo, } } @@ -231,11 +218,13 @@ func info(conf *viper.Viper) error { (len(userId) != 0 && len(groupId) != 0) { return fmt.Errorf("Either the user or group should be specified, not both") } - - dc, close := getDgraphClient(conf) - defer close() - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + dc, cancel, err := getClientWithAdminCtx(conf) defer cancel() + if err != nil { + return fmt.Errorf("unable to get admin context:%v", err) + } + + ctx := context.Background() txn := dc.NewTxn() defer func() { if err := txn.Discard(ctx); err != nil { diff --git a/ee/acl/users.go b/ee/acl/users.go index f83b9d396d5..1c4174bd200 100644 --- a/ee/acl/users.go +++ b/ee/acl/users.go @@ -16,7 +16,6 @@ import ( "context" "fmt" "strings" - "time" "github.com/dgraph-io/dgo" "github.com/dgraph-io/dgo/protos/api" @@ -36,10 +35,13 @@ func userAdd(conf *viper.Viper) error { return fmt.Errorf("The password must not be empty") } - dc, close := getDgraphClient(conf) - defer close() - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + dc, cancel, err := getClientWithAdminCtx(conf) defer cancel() + if err != nil { + return fmt.Errorf("unable to get admin context:%v", err) + } + + ctx := context.Background() txn := dc.NewTxn() defer func() { if err := txn.Discard(ctx); err != nil { @@ -87,10 +89,13 @@ func userDel(conf *viper.Viper) error { return fmt.Errorf("The user id should not be empty") } - dc, close := getDgraphClient(conf) - defer close() - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + dc, cancel, err := getClientWithAdminCtx(conf) defer cancel() + if err != nil { + return fmt.Errorf("unable to get admin context:%v", err) + } + + ctx := context.Background() txn := dc.NewTxn() defer func() { if err := txn.Discard(ctx); err != nil { @@ -127,37 +132,6 @@ func userDel(conf *viper.Viper) error { return nil } -func userLogin(conf *viper.Viper) error { - userid := conf.GetString("user") - password := conf.GetString("password") - - if len(userid) == 0 { - return fmt.Errorf("The user must not be empty") - } - if len(password) == 0 { - return fmt.Errorf("The password must not be empty") - } - - dc, close := getDgraphClient(conf) - defer close() - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - txn := dc.NewTxn() - defer func() { - if err := txn.Discard(ctx); err != nil { - glog.Errorf("Unable to discard transaction:%v", err) - } - }() - - if err := dc.Login(ctx, userid, password); err != nil { - return fmt.Errorf("Unable to login:%v", err) - } - updatedContext := dc.GetContext(ctx) - glog.Infof("Login successfully.\naccess jwt:\n%v\nrefresh jwt:\n%v", - updatedContext.Value("accessJwt"), updatedContext.Value("refreshJwt")) - return nil -} - func queryUser(ctx context.Context, txn *dgo.Txn, userid string) (user *User, err error) { query := ` query search($userid: string){ @@ -192,10 +166,13 @@ func userMod(conf *viper.Viper) error { return fmt.Errorf("The user must not be empty") } - dc, close := getDgraphClient(conf) - defer close() - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + dc, cancel, err := getClientWithAdminCtx(conf) defer cancel() + if err != nil { + return fmt.Errorf("unable to get admin context:%v", err) + } + + ctx := context.Background() txn := dc.NewTxn() defer func() { if err := txn.Discard(ctx); err != nil { diff --git a/ee/acl/utils.go b/ee/acl/utils.go index 173ab28e673..60a852b8d82 100644 --- a/ee/acl/utils.go +++ b/ee/acl/utils.go @@ -13,12 +13,18 @@ package acl import ( + "context" "encoding/json" "fmt" + "syscall" + "time" + "github.com/dgraph-io/dgo" "github.com/dgraph-io/dgo/protos/api" "github.com/dgraph-io/dgraph/x" "github.com/golang/glog" + "github.com/spf13/viper" + "golang.org/x/crypto/ssh/terminal" ) func GetGroupIDs(groups []Group) []string { @@ -34,6 +40,26 @@ func GetGroupIDs(groups []Group) []string { return jwtGroups } +type Operation struct { + Code int32 + Name string +} + +var ( + Read = &Operation{ + Code: 4, + Name: "Read", + } + Write = &Operation{ + Code: 2, + Name: "Write", + } + Modify = &Operation{ + Code: 1, + Name: "Modify", + } +) + type User struct { Uid string `json:"uid"` UserID string `json:"dgraph.xid"` @@ -48,7 +74,7 @@ func UnmarshalUser(resp *api.Response, userKey string) (user *User, err error) { err = json.Unmarshal(resp.GetJson(), &m) if err != nil { - return nil, fmt.Errorf("Unable to unmarshal the query user response for user:%v", err) + return nil, fmt.Errorf("unable to unmarshal the query user response:%v", err) } users := m[userKey] if len(users) == 0 { @@ -63,18 +89,18 @@ func UnmarshalUser(resp *api.Response, userKey string) (user *User, err error) { // parse the response and check existing of the uid type Group struct { - Uid string `json:"uid"` - GroupID string `json:"dgraph.xid"` - Users []User `json:"~dgraph.user.group"` - Acls string `json:"dgraph.group.acl"` + Uid string `json:"uid"` + GroupID string `json:"dgraph.xid"` + Users []User `json:"~dgraph.user.group"` + Acls string `json:"dgraph.group.acl"` + MappedAcls map[string]int32 // only used in memory for acl enforcement } // Extract the first User pointed by the userKey in the query response func UnmarshalGroup(input []byte, groupKey string) (group *Group, err error) { m := make(map[string][]Group) - err = json.Unmarshal(input, &m) - if err != nil { + if err = json.Unmarshal(input, &m); err != nil { glog.Errorf("Unable to unmarshal the query group response:%v", err) return nil, err } @@ -84,11 +110,66 @@ func UnmarshalGroup(input []byte, groupKey string) (group *Group, err error) { return nil, nil } if len(groups) > 1 { - return nil, x.Errorf("Found multiple groups: %s", input) + return nil, fmt.Errorf("found multiple groups: %s", input) } + return &groups[0], nil } +// convert the acl blob to a map from predicates to permissions +func UnmarshalAcl(aclBytes []byte) (map[string]int32, error) { + var acls []Acl + if len(aclBytes) != 0 { + if err := json.Unmarshal(aclBytes, &acls); err != nil { + return nil, fmt.Errorf("unable to unmarshal the aclBytes: %v", err) + } + } + mappedAcls := make(map[string]int32) + for _, acl := range acls { + mappedAcls[acl.Predicate] = acl.Perm + } + return mappedAcls, nil +} + +// Extract a sequence of groups from the input +func UnmarshalGroups(input []byte, groupKey string) (group []Group, err error) { + m := make(map[string][]Group) + + if err = json.Unmarshal(input, &m); err != nil { + glog.Errorf("Unable to unmarshal the query group response:%v", err) + return nil, err + } + groups := m[groupKey] + return groups, nil +} + type JwtGroup struct { Group string } + +func getClientWithAdminCtx(conf *viper.Viper) (*dgo.Dgraph, CloseFunc, error) { + adminPassword := conf.GetString(gPassword) + if len(adminPassword) == 0 { + fmt.Print("Enter groot password:") + password, err := terminal.ReadPassword(int(syscall.Stdin)) + if err != nil { + return nil, func() {}, fmt.Errorf("error while reading password:%v", err) + } + adminPassword = string(password) + } + + dc, closeClient := getDgraphClient(conf) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + + cleanFunc := func() { + cancel() + closeClient() + } + + if err := dc.Login(ctx, x.GrootId, adminPassword); err != nil { + return dc, cleanFunc, fmt.Errorf("unable to login to the groot account %v", err) + } + glog.Infof("login successfully to the groot account") + // update the context so that it has the admin jwt token + return dc, cleanFunc, nil +} diff --git a/systest/bulk_live_cases_test.go b/systest/bulk_live_cases_test.go index 2b84e8124cd..d2de1b83509 100644 --- a/systest/bulk_live_cases_test.go +++ b/systest/bulk_live_cases_test.go @@ -20,8 +20,6 @@ import ( "os" "testing" "time" - - "github.com/dgraph-io/dgraph/x" ) // TODO: This test was used just to make sure some really basic examples work. @@ -285,8 +283,6 @@ func DONOTRUNTestGoldenData(t *testing.T) { err := matchExportCount(matchExport{ expectedRDF: 1120879, expectedSchema: 10, - dir: s.liveCluster.dir, - port: s.liveCluster.dgraphPortOffset + x.PortHTTP, }) if err != nil { t.Fatal(err) diff --git a/systest/bulk_live_fixture_test.go b/systest/bulk_live_fixture_test.go index b7579142e28..058aead39a6 100644 --- a/systest/bulk_live_fixture_test.go +++ b/systest/bulk_live_fixture_test.go @@ -30,6 +30,9 @@ import ( "testing" "time" + "github.com/dgraph-io/dgo/protos/api" + "github.com/dgraph-io/dgraph/x" + "github.com/pkg/errors" ) @@ -45,12 +48,19 @@ var rootDir = filepath.Join(os.TempDir(), "dgraph_systest") type suite struct { t *testing.T - - liveCluster *DgraphCluster - bulkCluster *DgraphCluster } func newSuite(t *testing.T, schema, rdfs string) *suite { + dg, close := x.GetDgraphClient() + defer close() + + err := dg.Alter(context.Background(), &api.Operation{ + DropAll: true, + }) + if err != nil { + t.Fatalf("Could not drop old data: %v", err) + } + if testing.Short() { t.Skip("Skipping system test with long runtime.") } @@ -83,44 +93,30 @@ func (s *suite) setup(schemaFile, rdfFile string) { makeDirEmpty(liveDir), ) - s.bulkCluster = NewDgraphCluster(bulkDir) - s.checkFatal(s.bulkCluster.StartZeroOnly()) - bulkCmd := exec.Command(os.ExpandEnv("$GOPATH/bin/dgraph"), "bulk", "-r", rdfFile, "-s", schemaFile, "--http", ":"+strconv.Itoa(freePort(0)), - "-z", ":"+s.bulkCluster.zeroPort, "-j=1", "-x=true", ) - bulkCmd.Stdout = os.Stdout - bulkCmd.Stderr = os.Stdout bulkCmd.Dir = bulkDir if err := bulkCmd.Run(); err != nil { s.cleanup() s.t.Fatalf("Bulkloader didn't run: %v\n", err) } - s.bulkCluster.zero.Process.Kill() - s.bulkCluster.zero.Wait() + s.checkFatal(os.Rename( filepath.Join(bulkDir, "out", "0", "p"), filepath.Join(bulkDir, "p"), )) - s.liveCluster = NewDgraphCluster(liveDir) - s.checkFatal(s.liveCluster.Start()) - s.checkFatal(s.bulkCluster.Start()) - liveCmd := exec.Command(os.ExpandEnv("$GOPATH/bin/dgraph"), "live", "--rdfs", rdfFile, "--schema", schemaFile, - "--dgraph", ":"+s.liveCluster.dgraphPort, - "--zero", ":"+s.liveCluster.zeroPort, + "--dgraph", ":9180", ) liveCmd.Dir = liveDir - liveCmd.Stdout = os.Stdout - liveCmd.Stderr = os.Stdout if err := liveCmd.Run(); err != nil { s.cleanup() s.t.Fatalf("Live Loader didn't run: %v\n", err) @@ -137,27 +133,22 @@ func makeDirEmpty(dir string) error { func (s *suite) cleanup() { // NOTE: Shouldn't raise any errors here or fail a test, since this is // called when we detect an error (don't want to mask the original problem). - if s.liveCluster != nil { - s.liveCluster.Close() - } - if s.bulkCluster != nil { - s.bulkCluster.Close() - } _ = os.RemoveAll(rootDir) } func (s *suite) testCase(query, wantResult string) func(*testing.T) { return func(t *testing.T) { - for _, cluster := range []*DgraphCluster{s.bulkCluster, s.liveCluster} { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - txn := cluster.client.NewTxn() - resp, err := txn.Query(ctx, query) - if err != nil { - t.Fatalf("Could not query: %v", err) - } - CompareJSON(t, wantResult, string(resp.GetJson())) + dg, close := x.GetDgraphClient() + defer close() + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + txn := dg.NewTxn() + resp, err := txn.Query(ctx, query) + if err != nil { + t.Fatalf("Could not query: %v", err) } + CompareJSON(t, wantResult, string(resp.GetJson())) } } diff --git a/systest/cluster_setup_test.go b/systest/cluster_setup_test.go index 2b2e849bffc..5ddb077d1bf 100644 --- a/systest/cluster_setup_test.go +++ b/systest/cluster_setup_test.go @@ -64,8 +64,8 @@ func (d *DgraphCluster) StartZeroOnly() error { "--replicas", "3", ) d.zero.Dir = d.dir - d.zero.Stdout = os.Stdout - d.zero.Stderr = os.Stderr + //d.zero.Stdout = os.Stdout + //d.zero.Stderr = os.Stderr if err := d.zero.Start(); err != nil { return err @@ -89,8 +89,6 @@ func (d *DgraphCluster) Start() error { "--custom_tokenizers", d.TokenizerPluginsArg, ) d.dgraph.Dir = d.dir - d.dgraph.Stdout = os.Stdout - d.dgraph.Stderr = os.Stderr if err := d.dgraph.Start(); err != nil { return err } diff --git a/systest/cluster_test.go b/systest/cluster_test.go index 8be02422a90..0dc0638c909 100644 --- a/systest/cluster_test.go +++ b/systest/cluster_test.go @@ -180,8 +180,6 @@ func DONOTRUNTestClusterSnapshot(t *testing.T) { "--zero", ":"+cluster.zeroPort, ) liveCmd.Dir = tmpDir - liveCmd.Stdout = os.Stdout - liveCmd.Stderr = os.Stdout if err := liveCmd.Run(); err != nil { cluster.Close() t.Fatalf("Live Loader didn't run: %v\n", err) diff --git a/systest/loader_test.go b/systest/loader_test.go index 75e70b07b0d..c0b91b8f21a 100644 --- a/systest/loader_test.go +++ b/systest/loader_test.go @@ -44,8 +44,6 @@ func TestLoaderXidmap(t *testing.T) { "-x", "x", ) liveCmd.Dir = tmpDir - liveCmd.Stdout = os.Stdout - liveCmd.Stderr = os.Stdout if err := liveCmd.Run(); err != nil { cluster.Close() t.Fatalf("Live Loader didn't run: %v\n", err) @@ -114,6 +112,6 @@ func TestLoaderXidmap(t *testing.T) { if string(out) != expected { cluster.Close() - t.Fatalf("Export is not as expected.") + t.Fatalf("Export is not as expected. Want:%v\nGot:%v\n", expected, string(out)) } } diff --git a/systest/queries_test.go b/systest/queries_test.go index 264cd4897c4..5f33358e4d2 100644 --- a/systest/queries_test.go +++ b/systest/queries_test.go @@ -335,28 +335,7 @@ func SchemaQueryTest(t *testing.T, c *dgo.Dgraph) { "type": "string", "list": true }, - { - "predicate": "dgraph.group.acl", - "type": "string" - }, - { - "predicate": "dgraph.password", - "type": "password" - }, - { - "predicate": "dgraph.user.group", - "type": "uid", - "reverse": true, - "list": true - }, - { - "predicate": "dgraph.xid", - "type": "string", - "index": true, - "tokenizer": [ - "exact" - ] - }, +` + x.AclPredsJson + `, { "predicate": "name", "type": "string", @@ -541,28 +520,7 @@ func SchemaQueryTestHTTP(t *testing.T, c *dgo.Dgraph) { "type": "string", "list": true }, - { - "predicate": "dgraph.group.acl", - "type": "string" - }, - { - "predicate": "dgraph.password", - "type": "password" - }, - { - "predicate": "dgraph.user.group", - "type": "uid", - "reverse": true, - "list": true - }, - { - "predicate": "dgraph.xid", - "type": "string", - "index": true, - "tokenizer": [ - "exact" - ] - }, +` + x.AclPredsJson + `, { "predicate": "name", "type": "string", diff --git a/types/password.go b/types/password.go index 526bc4802f5..f40902978fd 100644 --- a/types/password.go +++ b/types/password.go @@ -28,7 +28,7 @@ const ( func Encrypt(plain string) (string, error) { if len(plain) < pwdLenLimit { - return "", x.Errorf("Password too short, i.e. should has at least 6 chars") + return "", x.Errorf("Password too short, i.e. should have at least 6 chars") } encrypted, err := bcrypt.GenerateFromPassword([]byte(plain), bcrypt.DefaultCost) diff --git a/worker/groups.go b/worker/groups.go index 787f59b68b8..b98ed2682be 100644 --- a/worker/groups.go +++ b/worker/groups.go @@ -158,6 +158,7 @@ func (g *groupi) proposeInitialSchema() { Predicate: "dgraph.xid", ValueType: pb.Posting_STRING, Directive: pb.SchemaUpdate_INDEX, + Upsert: true, Tokenizer: []string{"exact"}, }) diff --git a/x/x.go b/x/x.go index c01fd24aa34..2943e59cf06 100644 --- a/x/x.go +++ b/x/x.go @@ -21,6 +21,7 @@ import ( "bytes" "encoding/json" "fmt" + "log" "math" "math/rand" "net" @@ -32,6 +33,8 @@ import ( "strings" "time" + "github.com/dgraph-io/dgo" + "github.com/dgraph-io/dgo/protos/api" "go.opencensus.io/trace" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -73,6 +76,8 @@ const ( TlsClientCert = "client.crt" TlsClientKey = "client.key" + + GrootId = "groot" ) var ( @@ -89,7 +94,7 @@ var ( {"predicate":"dgraph.group.acl", "type":"string"}, {"predicate":"dgraph.password", "type":"password"}, {"reverse":true, "predicate":"dgraph.user.group", "type":"uid", "list":true}, -{"index":true, "tokenizer":["exact"], "predicate":"dgraph.xid", "type":"string"} +{"index":true, "tokenizer":["exact"], "predicate":"dgraph.xid", "type":"string", "upsert":true} ` Nilbyte []byte ) @@ -488,3 +493,25 @@ func SpanTimer(span *trace.Span, name string) func() { span.Annotatef(attrs, "End. Took %s", time.Since(start)) } } + +type CancelFunc func() + +const DgraphAlphaPort = 9180 + +func GetDgraphClient() (*dgo.Dgraph, CancelFunc) { + return GetDgraphClientOnPort(DgraphAlphaPort) +} + +func GetDgraphClientOnPort(alphaPort int) (*dgo.Dgraph, CancelFunc) { + conn, err := grpc.Dial(fmt.Sprintf("127.0.0.1:%d", alphaPort), grpc.WithInsecure()) + if err != nil { + log.Fatal("While trying to dial gRPC") + } + + dc := api.NewDgraphClient(conn) + return dgo.NewDgraphClient(dc), func() { + if err := conn.Close(); err != nil { + log.Printf("Error while closing connection:%v", err) + } + } +}