diff --git a/dgraph/cmd/alpha/run.go b/dgraph/cmd/alpha/run.go index fac1587da77..99df5120c8d 100644 --- a/dgraph/cmd/alpha/run.go +++ b/dgraph/cmd/alpha/run.go @@ -36,6 +36,8 @@ import ( "time" badgerpb "github.com/dgraph-io/badger/v3/pb" + "github.com/dgraph-io/dgraph/ee/audit" + "github.com/dgraph-io/dgo/v200/protos/api" "github.com/dgraph-io/dgraph/edgraph" "github.com/dgraph-io/dgraph/ee/enc" @@ -191,6 +193,14 @@ they form a Raft group and provide synchronous replication. `Cache percentages summing up to 100 for various caches (FORMAT: PostingListCache,PstoreBlockCache,PstoreIndexCache,WAL).`) + flag.String("audit", "", + `Various audit options. + dir=/path/to/audits to define the path where to store the audit logs. + compress=true/false to enabled the compression of old audit logs (default behaviour is false). + encrypt_file=enc/key/file enables the audit log encryption with the key path provided with the + flag. + Sample flag could look like --audit dir=aa;encrypt_file=/filepath;compress=true`) + // TLS configurations x.RegisterServerTLSFlags(flag) } @@ -379,6 +389,7 @@ func serveGRPC(l net.Listener, tlsCfg *tls.Config, closer *z.Closer) { grpc.MaxSendMsgSize(x.GrpcMaxSize), grpc.MaxConcurrentStreams(1000), grpc.StatsHandler(&ocgrpc.ServerHandler{}), + grpc.UnaryInterceptor(audit.AuditRequestGRPC), } if tlsCfg != nil { opt = append(opt, grpc.Creds(credentials.NewTLS(tlsCfg))) @@ -417,15 +428,18 @@ func setupServer(closer *z.Closer) { log.Fatal(err) } - http.HandleFunc("/query", queryHandler) - http.HandleFunc("/query/", queryHandler) - http.HandleFunc("/mutate", mutationHandler) - http.HandleFunc("/mutate/", mutationHandler) - http.HandleFunc("/commit", commitHandler) - http.HandleFunc("/alter", alterHandler) - http.HandleFunc("/health", healthCheck) - http.HandleFunc("/state", stateHandler) - http.HandleFunc("/jemalloc", x.JemallocHandler) + baseMux := http.NewServeMux() + http.Handle("/", audit.AuditRequestHttp(baseMux)) + + baseMux.HandleFunc("/query", queryHandler) + baseMux.HandleFunc("/query/", queryHandler) + baseMux.HandleFunc("/mutate", mutationHandler) + baseMux.HandleFunc("/mutate/", mutationHandler) + baseMux.HandleFunc("/commit", commitHandler) + baseMux.HandleFunc("/alter", alterHandler) + baseMux.HandleFunc("/health", healthCheck) + baseMux.HandleFunc("/state", stateHandler) + baseMux.HandleFunc("/jemalloc", x.JemallocHandler) // TODO: Figure out what this is for? http.HandleFunc("/debug/store", storeStatsHandler) @@ -451,8 +465,9 @@ func setupServer(closer *z.Closer) { var gqlHealthStore *admin.GraphQLHealthStore // Do not use := notation here because adminServer is a global variable. mainServer, adminServer, gqlHealthStore = admin.NewServers(introspection, &globalEpoch, closer) - http.Handle("/graphql", mainServer.HTTPHandler()) - http.HandleFunc("/probe/graphql", func(w http.ResponseWriter, r *http.Request) { + baseMux.Handle("/graphql", mainServer.HTTPHandler()) + baseMux.HandleFunc("/probe/graphql", func(w http.ResponseWriter, + r *http.Request) { healthStatus := gqlHealthStore.GetHealth() httpStatusCode := http.StatusOK if !healthStatus.Healthy { @@ -463,18 +478,19 @@ func setupServer(closer *z.Closer) { x.Check2(w.Write([]byte(fmt.Sprintf(`{"status":"%s","schemaUpdateCounter":%d}`, healthStatus.StatusMsg, atomic.LoadUint64(&globalEpoch))))) }) - http.Handle("/admin", allowedMethodsHandler(allowedMethods{ + baseMux.Handle("/admin", allowedMethodsHandler(allowedMethods{ http.MethodGet: true, http.MethodPost: true, http.MethodOptions: true, }, adminAuthHandler(adminServer.HTTPHandler()))) - http.Handle("/admin/schema", adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, + baseMux.Handle("/admin/schema", adminAuthHandler(http.HandlerFunc(func( + w http.ResponseWriter, r *http.Request) { adminSchemaHandler(w, r, adminServer) }))) - http.Handle("/admin/schema/validate", http.HandlerFunc(func(w http.ResponseWriter, + baseMux.HandleFunc("/admin/schema/validate", func(w http.ResponseWriter, r *http.Request) { schema := readRequest(w, r) w.Header().Set("Content-Type", "application/json") @@ -489,26 +505,28 @@ func setupServer(closer *z.Closer) { w.WriteHeader(http.StatusBadRequest) errs := strings.Split(strings.TrimSpace(err.Error()), "\n") x.SetStatusWithErrors(w, x.ErrorInvalidRequest, errs) - })) + }) - http.Handle("/admin/shutdown", allowedMethodsHandler(allowedMethods{http.MethodGet: true}, + baseMux.Handle("/admin/shutdown", allowedMethodsHandler(allowedMethods{http. + MethodGet: true}, adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { shutDownHandler(w, r, adminServer) })))) - http.Handle("/admin/draining", allowedMethodsHandler(allowedMethods{ + baseMux.Handle("/admin/draining", allowedMethodsHandler(allowedMethods{ http.MethodPut: true, http.MethodPost: true, }, adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { drainingHandler(w, r, adminServer) })))) - http.Handle("/admin/export", allowedMethodsHandler(allowedMethods{http.MethodGet: true}, + baseMux.Handle("/admin/export", allowedMethodsHandler( + allowedMethods{http.MethodGet: true}, adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { exportHandler(w, r, adminServer) })))) - http.Handle("/admin/config/cache_mb", allowedMethodsHandler(allowedMethods{ + baseMux.Handle("/admin/config/cache_mb", allowedMethodsHandler(allowedMethods{ http.MethodGet: true, http.MethodPut: true, }, adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -520,10 +538,10 @@ func setupServer(closer *z.Closer) { glog.Infof("Bringing up GraphQL HTTP admin API at %s/admin", addr) // Add OpenCensus z-pages. - zpages.Handle(http.DefaultServeMux, "/z") + zpages.Handle(baseMux, "/z") - http.HandleFunc("/", homeHandler) - http.HandleFunc("/ui/keywords", keywordHandler) + baseMux.Handle("/", http.HandlerFunc(homeHandler)) + baseMux.Handle("/ui/keywords", http.HandlerFunc(keywordHandler)) // Initialize the servers. admin.ServerCloser.AddRunning(3) @@ -585,6 +603,8 @@ func run() { walCache := (cachePercent[3] * (totalCache << 20)) / 100 ctype, clevel := x.ParseCompression(Alpha.Conf.GetString("badger.compression")) + + conf := audit.GetAuditConf(Alpha.Conf.GetString("audit")) opts := worker.Options{ PostingDir: Alpha.Conf.GetString("postings"), WALDir: Alpha.Conf.GetString("wal"), @@ -597,6 +617,7 @@ func run() { MutationsMode: worker.AllowMutations, AuthToken: Alpha.Conf.GetString("auth_token"), + Audit: conf, } secretFile := Alpha.Conf.GetString("acl_secret_file") @@ -658,6 +679,8 @@ func run() { LudicrousConcurrency: Alpha.Conf.GetInt("ludicrous_concurrency"), TLSClientConfig: tlsClientConf, TLSServerConfig: tlsServerConf, + HmacSecret: opts.HmacSecret, + Audit: opts.Audit != nil, } x.WorkerConfig.Parse(Alpha.Conf) @@ -699,6 +722,9 @@ func run() { worker.InitServerState() + // Audit is enterprise feature. + x.Check(audit.InitAuditorIfNecessary(opts.Audit, worker.EnterpriseEnabled)) + if Alpha.Conf.GetBool("expose_trace") { // TODO: Remove this once we get rid of event logs. trace.AuthRequest = func(req *http.Request) (any, sensitive bool) { @@ -792,6 +818,8 @@ func run() { adminCloser.SignalAndWait() glog.Infoln("adminCloser closed.") + audit.Close() + worker.State.Dispose() x.RemoveCidFile() glog.Info("worker.State disposed.") diff --git a/dgraph/cmd/bulk/count_index.go b/dgraph/cmd/bulk/count_index.go index 334679b7451..763bed0e363 100644 --- a/dgraph/cmd/bulk/count_index.go +++ b/dgraph/cmd/bulk/count_index.go @@ -156,7 +156,7 @@ func (c *countIndexer) writeIndex(buf *z.Buffer) { encoder = codec.Encoder{BlockSize: 256, Alloc: alloc} pl.Reset() - // Flush out the buffer. + // flush out the buffer. if outBuf.LenNoPadding() > 4<<20 { x.Check(c.writer.Write(outBuf)) outBuf.Reset() diff --git a/dgraph/cmd/bulk/reduce.go b/dgraph/cmd/bulk/reduce.go index d84e00d8789..c656f29f870 100644 --- a/dgraph/cmd/bulk/reduce.go +++ b/dgraph/cmd/bulk/reduce.go @@ -292,7 +292,7 @@ func (r *reducer) writeTmpSplits(ci *countIndexer, wg *sync.WaitGroup) { } for i := 0; i < len(kvs.Kv); i += maxSplitBatchLen { - // Flush the write batch when the max batch length is reached to prevent the + // flush the write batch when the max batch length is reached to prevent the // value log from growing over the allowed limit. if splitBatchLen >= maxSplitBatchLen { x.Check(ci.splitWriter.Flush()) diff --git a/dgraph/cmd/live/batch.go b/dgraph/cmd/live/batch.go index d26386e1e29..2a56087e536 100644 --- a/dgraph/cmd/live/batch.go +++ b/dgraph/cmd/live/batch.go @@ -34,7 +34,6 @@ import ( "github.com/dgraph-io/badger/v3" "github.com/dgraph-io/dgo/v200" "github.com/dgraph-io/dgo/v200/protos/api" - "github.com/dgraph-io/dgraph/dgraph/cmd/zero" "github.com/dgraph-io/dgraph/gql" "github.com/dgraph-io/dgraph/protos/pb" "github.com/dgraph-io/dgraph/tok" @@ -132,7 +131,7 @@ func handleError(err error, isRetry bool) { dur := time.Duration(1+rand.Intn(10)) * time.Minute fmt.Printf("Server is overloaded. Will retry after %s.\n", dur.Round(time.Minute)) time.Sleep(dur) - case err != zero.ErrConflict && err != dgo.ErrAborted: + case err != x.ErrConflict && err != dgo.ErrAborted: fmt.Printf("Error while mutating: %v s.Code %v\n", s.Message(), s.Code()) } } diff --git a/dgraph/cmd/root_ee.go b/dgraph/cmd/root_ee.go index b1bdc727ed2..5660d8c3450 100644 --- a/dgraph/cmd/root_ee.go +++ b/dgraph/cmd/root_ee.go @@ -14,6 +14,7 @@ package cmd import ( acl "github.com/dgraph-io/dgraph/ee/acl" + "github.com/dgraph-io/dgraph/ee/audit" "github.com/dgraph-io/dgraph/ee/backup" ) @@ -24,5 +25,6 @@ func init() { &backup.LsBackup, &backup.ExportBackup, &acl.CmdAcl, + &audit.CmdAudit, ) } diff --git a/dgraph/cmd/zero/license_ee.go b/dgraph/cmd/zero/license_ee.go index 95f8ef3c3d9..c9c714b14db 100644 --- a/dgraph/cmd/zero/license_ee.go +++ b/dgraph/cmd/zero/license_ee.go @@ -20,6 +20,8 @@ import ( "net/http" "time" + "github.com/dgraph-io/dgraph/ee/audit" + "github.com/dgraph-io/dgraph/protos/pb" "github.com/dgraph-io/dgraph/x" "github.com/dgraph-io/ristretto/z" @@ -91,6 +93,7 @@ func (n *node) updateEnterpriseState(closer *z.Closer) { active := time.Now().UTC().Before(expiry) if !active { n.server.expireLicense() + audit.Close() glog.Warningf("Your enterprise license has expired and enterprise features are " + "disabled. To continue using enterprise features, apply a valid license. To receive " + "a new license, contact us at https://dgraph.io/contact.") diff --git a/dgraph/cmd/zero/oracle.go b/dgraph/cmd/zero/oracle.go index 24771b9abf3..056815674e9 100644 --- a/dgraph/cmd/zero/oracle.go +++ b/dgraph/cmd/zero/oracle.go @@ -134,7 +134,7 @@ func (o *Oracle) commit(src *api.TxnContext) error { defer o.Unlock() if o.hasConflict(src) { - return ErrConflict + return x.ErrConflict } // We store src.Keys as string to ensure compatibility with all the various language clients we // have. But, really they are just uint64s encoded as strings. We use base 36 during creation of @@ -310,9 +310,6 @@ func (o *Oracle) MaxPending() uint64 { return o.maxAssigned } -// ErrConflict is returned when commit couldn't succeed due to conflicts. -var ErrConflict = errors.New("Transaction conflict") - // proposeTxn proposes a txn update, and then updates src to reflect the state // of the commit after proposal is run. func (s *Server) proposeTxn(ctx context.Context, src *api.TxnContext) error { diff --git a/dgraph/cmd/zero/raft.go b/dgraph/cmd/zero/raft.go index a8ec086bd83..a27d63d8f67 100644 --- a/dgraph/cmd/zero/raft.go +++ b/dgraph/cmd/zero/raft.go @@ -28,6 +28,7 @@ import ( "time" "github.com/dgraph-io/dgraph/conn" + "github.com/dgraph-io/dgraph/ee/audit" "github.com/dgraph-io/dgraph/protos/pb" "github.com/dgraph-io/dgraph/x" "github.com/dgraph-io/ristretto/z" @@ -404,6 +405,11 @@ func (n *node) applyProposal(e raftpb.Entry) (uint64, error) { // Check expiry and set enabled accordingly. expiry := time.Unix(state.License.ExpiryTs, 0).UTC() state.License.Enabled = time.Now().UTC().Before(expiry) + if state.License.Enabled && opts.audit != nil { + if err := audit.InitAuditor(opts.audit); err != nil { + glog.Errorf("error while initializing audit logs %+v", err) + } + } } if p.Snapshot != nil { if err := n.applySnapshot(p.Snapshot); err != nil { diff --git a/dgraph/cmd/zero/run.go b/dgraph/cmd/zero/run.go index bb45c5a3020..c6b09383a50 100644 --- a/dgraph/cmd/zero/run.go +++ b/dgraph/cmd/zero/run.go @@ -25,9 +25,12 @@ import ( "net/http" "os" "os/signal" + "path/filepath" "syscall" "time" + "github.com/dgraph-io/dgraph/ee/audit" + "go.opencensus.io/plugin/ocgrpc" otrace "go.opencensus.io/trace" "go.opencensus.io/zpages" @@ -54,6 +57,7 @@ type options struct { w string rebalanceInterval time.Duration tlsClientConfig *tls.Config + audit *audit.AuditConf } var opts options @@ -95,6 +99,15 @@ instances to achieve high-availability. flag.StringP("wal", "w", "zw", "Directory storing WAL.") flag.Duration("rebalance_interval", 8*time.Minute, "Interval for trying a predicate move.") flag.String("enterprise_license", "", "Path to the enterprise license file.") + + flag.String("audit", "", + `Various audit options. + dir=/path/to/audits to define the path where to store the audit logs. + compress=true/false to enabled the compression of old audit logs (default behaviour is false). + encrypt_file=enc/key/file enables the audit log encryption with the key path provided with the + flag. + Sample flag could look like --audit dir=aa;encrypt_file=/filepath;compress=true`) + // TLS configurations x.RegisterServerTLSFlags(flag) } @@ -118,6 +131,7 @@ func (st *state) serveGRPC(l net.Listener, store *raftwal.DiskStorage) { grpc.MaxSendMsgSize(x.GrpcMaxSize), grpc.MaxConcurrentStreams(1000), grpc.StatsHandler(&ocgrpc.ServerHandler{}), + grpc.UnaryInterceptor(audit.AuditRequestGRPC), } tlsConf, err := x.LoadServerTLSConfigForInternalPort(Zero.Conf) @@ -186,6 +200,7 @@ func run() { x.Check(err) raft := x.NewSuperFlag(Zero.Conf.GetString("raft")).MergeAndCheckDefault(raftDefault) + conf := audit.GetAuditConf(Zero.Conf.GetString("audit")) opts = options{ bindall: Zero.Conf.GetBool("bindall"), portOffset: Zero.Conf.GetInt("port_offset"), @@ -195,6 +210,7 @@ func run() { w: Zero.Conf.GetString("wal"), rebalanceInterval: Zero.Conf.GetDuration("rebalance_interval"), tlsClientConfig: tlsConf, + audit: conf, } glog.Infof("Setting Config to: %+v", opts) x.WorkerConfig.Parse(Zero.Conf) @@ -215,6 +231,15 @@ func run() { } } + if opts.audit != nil { + wd, err := filepath.Abs(opts.w) + x.Check(err) + ad, err := filepath.Abs(opts.audit.Dir) + x.Check(err) + x.AssertTruef(ad != wd, + "WAL and Audit directory cannot be the same ('%s').", opts.audit.Dir) + } + if opts.rebalanceInterval <= 0 { log.Fatalf("ERROR: Rebalance interval must be greater than zero. Found: %d", opts.rebalanceInterval) @@ -255,14 +280,17 @@ func run() { x.Check(err) go x.StartListenHttpAndHttps(httpListener, tlsCfg, st.zero.closer) - http.HandleFunc("/health", st.pingResponse) - http.HandleFunc("/state", st.getState) - http.HandleFunc("/removeNode", st.removeNode) - http.HandleFunc("/moveTablet", st.moveTablet) - http.HandleFunc("/assign", st.assign) - http.HandleFunc("/enterpriseLicense", st.applyEnterpriseLicense) - http.HandleFunc("/jemalloc", x.JemallocHandler) - zpages.Handle(http.DefaultServeMux, "/z") + baseMux := http.NewServeMux() + http.Handle("/", audit.AuditRequestHttp(baseMux)) + + baseMux.HandleFunc("/health", st.pingResponse) + baseMux.HandleFunc("/state", st.getState) + baseMux.HandleFunc("/removeNode", st.removeNode) + baseMux.HandleFunc("/moveTablet", st.moveTablet) + baseMux.HandleFunc("/assign", st.assign) + baseMux.HandleFunc("/enterpriseLicense", st.applyEnterpriseLicense) + baseMux.HandleFunc("/jemalloc", x.JemallocHandler) + zpages.Handle(baseMux, "/z") // This must be here. It does not work if placed before Grpc init. x.Check(st.node.initAndStartNode()) @@ -320,6 +348,9 @@ func run() { err = store.Close() glog.Infof("Raft WAL closed with err: %v\n", err) + + audit.Close() + st.zero.orc.close() glog.Infoln("All done. Goodbye!") } diff --git a/edgraph/server.go b/edgraph/server.go index dfa69243416..579effead7b 100644 --- a/edgraph/server.go +++ b/edgraph/server.go @@ -46,7 +46,6 @@ import ( "github.com/dgraph-io/dgo/v200/protos/api" "github.com/dgraph-io/dgraph/chunker" "github.com/dgraph-io/dgraph/conn" - "github.com/dgraph-io/dgraph/dgraph/cmd/zero" "github.com/dgraph-io/dgraph/ee" "github.com/dgraph-io/dgraph/gql" "github.com/dgraph-io/dgraph/posting" @@ -571,7 +570,7 @@ func (s *Server) doMutate(ctx context.Context, qc *queryContext, resp *api.Respo } if !qc.req.CommitNow { calculateMutationMetrics() - if err == zero.ErrConflict { + if err == x.ErrConflict { err = status.Error(codes.FailedPrecondition, err.Error()) } @@ -590,7 +589,7 @@ func (s *Server) doMutate(ctx context.Context, qc *queryContext, resp *api.Respo resp.Txn.Aborted = true _, _ = worker.CommitOverNetwork(ctx, resp.Txn) - if err == zero.ErrConflict { + if err == x.ErrConflict { // We have already aborted the transaction, so the error message should reflect that. return dgo.ErrAborted } diff --git a/ee/audit/audit.go b/ee/audit/audit.go new file mode 100644 index 00000000000..772259331a3 --- /dev/null +++ b/ee/audit/audit.go @@ -0,0 +1,39 @@ +// +build oss + +/* + * Copyright 2021 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package audit + +type AuditConf struct { + Dir string +} + +func GetAuditConf(conf string) *AuditConf { + return nil +} + +func InitAuditorIfNecessary(conf *AuditConf, eeEnabled func() bool) error { + return nil +} + +func InitAuditor(conf *AuditConf) error { + return nil +} + +func Close() { + return +} diff --git a/ee/audit/audit_ee.go b/ee/audit/audit_ee.go new file mode 100644 index 00000000000..890f4c9b2d5 --- /dev/null +++ b/ee/audit/audit_ee.go @@ -0,0 +1,188 @@ +// +build !oss + +/* + * Copyright 2021 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Dgraph Community License (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * https://github.com/dgraph-io/dgraph/blob/master/licenses/DCL.txt + */ + +package audit + +import ( + "io/ioutil" + "path/filepath" + "sync/atomic" + "time" + + "github.com/dgraph-io/ristretto/z" + + "github.com/dgraph-io/dgraph/x" + "github.com/golang/glog" +) + +const ( + defaultAuditConf = "dir=; compress=false; encrypt-file=" + defaultAuditFilename = "dgraph_audit.log" +) + +var auditEnabled uint32 + +type AuditConf struct { + Compress bool + Dir string + EncryptBytes []byte +} + +type AuditEvent struct { + User string + ServerHost string + ClientHost string + Endpoint string + ReqType string + Req string + Status string + QueryParams map[string][]string +} + +const ( + UnauthorisedUser = "UnauthorisedUser" + UnknownUser = "UnknownUser" + PoorManAuth = "PoorManAuth" + Grpc = "Grpc" + Http = "Http" +) + +var auditor *auditLogger = &auditLogger{} + +type auditLogger struct { + log *x.Logger + tick *time.Ticker + closer *z.Closer +} + +func GetAuditConf(conf string) *AuditConf { + if conf == "" { + return nil + } + auditFlag := x.NewSuperFlag(conf).MergeAndCheckDefault(defaultAuditConf) + dir := auditFlag.GetString("dir") + x.AssertTruef(dir != "", "dir flag is not provided for the audit logs") + encBytes, err := readAuditEncKey(auditFlag) + x.Check(err) + return &AuditConf{ + Compress: auditFlag.GetBool("compress"), + Dir: dir, + EncryptBytes: encBytes, + } +} + +func readAuditEncKey(conf *x.SuperFlag) ([]byte, error) { + encFile := conf.GetString("encrypt-file") + if encFile == "" { + return nil, nil + } + path, err := filepath.Abs(encFile) + if err != nil { + return nil, err + } + encKey, err := ioutil.ReadFile(path) + if err != nil { + return nil, err + } + return encKey, nil +} + +// InitAuditorIfNecessary accepts conf and enterprise edition check function. +// This method keep tracks whether cluster is part of enterprise edition or not. +// It pools eeEnabled function every five minutes to check if the license is still valid or not. +func InitAuditorIfNecessary(conf *AuditConf, eeEnabled func() bool) error { + if conf == nil { + return nil + } + if eeEnabled() { + if err := InitAuditor(conf); err != nil { + return err + } + } + auditor.tick = time.NewTicker(time.Minute * 5) + auditor.closer = z.NewCloser(1) + go trackIfEEValid(conf, eeEnabled) + return nil +} + +// InitAuditor initializes the auditor. +// This method doesnt keep track of whether cluster is part of enterprise edition or not. +// Client has to keep track of that. +func InitAuditor(conf *AuditConf) error { + var err error + if auditor.log, err = x.InitLogger(conf.Dir, defaultAuditFilename, conf.EncryptBytes, + conf.Compress); err != nil { + return err + } + atomic.StoreUint32(&auditEnabled, 1) + glog.Infoln("audit logs are enabled") + return nil +} + +// trackIfEEValid tracks enterprise license of the cluster. +// Right now alpha doesn't know about the enterprise/licence. +// That's why we needed to track if the current node is part of enterprise edition cluster +func trackIfEEValid(conf *AuditConf, eeEnabledFunc func() bool) { + defer auditor.closer.Done() + var err error + for { + select { + case <-auditor.tick.C: + if !eeEnabledFunc() && atomic.CompareAndSwapUint32(&auditEnabled, 1, 0) { + glog.Infof("audit logs are disabled") + auditor.log.Sync() + auditor.log = nil + continue + } + + if atomic.LoadUint32(&auditEnabled) != 1 { + if auditor.log, err = x.InitLogger(conf.Dir, defaultAuditFilename, + conf.EncryptBytes, conf.Compress); err != nil { + continue + } + atomic.StoreUint32(&auditEnabled, 1) + glog.Infof("audit logs are enabled") + } + case <-auditor.closer.HasBeenClosed(): + return + } + } +} + +// Close stops the ticker and sync the pending logs in buffer. +// It also sets the log to nil, because its being called by zero when license expires. +// If license added, InitLogger will take care of the file. +func Close() { + if atomic.LoadUint32(&auditEnabled) == 0 { + return + } + if auditor.tick != nil { + auditor.tick.Stop() + } + if auditor.closer != nil { + auditor.closer.SignalAndWait() + } + auditor.log.Sync() + auditor.log = nil + glog.Infoln("audit logs are closed.") +} + +func (a *auditLogger) Audit(event *AuditEvent) { + a.log.AuditI(event.Endpoint, + "user", event.User, + "server", event.ServerHost, + "client", event.ClientHost, + "req_type", event.ReqType, + "req_body", event.Req, + "query_param", event.QueryParams, + "status", event.Status) +} diff --git a/ee/audit/interceptor.go b/ee/audit/interceptor.go new file mode 100644 index 00000000000..e663957434d --- /dev/null +++ b/ee/audit/interceptor.go @@ -0,0 +1,37 @@ +// +build oss + +/* + * Copyright 2021 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package audit + +import ( + "context" + "net/http" + + "google.golang.org/grpc" +) + +func AuditRequestGRPC(ctx context.Context, req interface{}, + info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + return handler(ctx, req) +} + +func AuditRequestHttp(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r) + }) +} diff --git a/ee/audit/interceptor_ee.go b/ee/audit/interceptor_ee.go new file mode 100644 index 00000000000..25eb6e5df1e --- /dev/null +++ b/ee/audit/interceptor_ee.go @@ -0,0 +1,192 @@ +// +build !oss + +/* + * Copyright 2021 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Dgraph Community License (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * https://github.com/dgraph-io/dgraph/blob/master/licenses/DCL.txt + */ +package audit + +import ( + "bytes" + "context" + "fmt" + "io" + "io/ioutil" + "net/http" + "strings" + "sync/atomic" + + "github.com/dgraph-io/dgraph/x" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" + + "google.golang.org/grpc" +) + +const ( + maxReqLength = 4 << 10 // 4 KB +) + +var skipApis = map[string]bool{ + // raft server + "Heartbeat": true, + "RaftMessage": true, + "JoinCluster": true, + "IsPeer": true, + // zero server + "StreamMembership": true, + "UpdateMembership": true, + "Oracle": true, + "Timestamps": true, + "ShouldServe": true, + "Connect": true, + // health server + "Check": true, + "Watch": true, +} + +var skipEPs = map[string]bool{ + // list of endpoints that needs to be skipped + "/health": true, + "/jemalloc": true, + "/state": true, +} + +func AuditRequestGRPC(ctx context.Context, req interface{}, + info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + skip := func(method string) bool { + return skipApis[info.FullMethod[strings.LastIndex(info.FullMethod, "/")+1:]] + } + + if atomic.LoadUint32(&auditEnabled) == 0 || skip(info.FullMethod) { + return handler(ctx, req) + } + response, err := handler(ctx, req) + auditGrpc(ctx, req, info, err) + return response, err +} + +func AuditRequestHttp(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + skip := func(method string) bool { + return skipEPs[r.URL.Path] + } + + if atomic.LoadUint32(&auditEnabled) == 0 || skip(r.URL.Path) { + next.ServeHTTP(w, r) + return + } + + rw := NewResponseWriter(w) + var buf bytes.Buffer + tee := io.TeeReader(r.Body, &buf) + r.Body = ioutil.NopCloser(tee) + next.ServeHTTP(rw, r) + r.Body = ioutil.NopCloser(bytes.NewReader(buf.Bytes())) + auditHttp(rw, r) + }) +} + +func auditGrpc(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, err error) { + clientHost := "" + if p, ok := peer.FromContext(ctx); ok { + clientHost = p.Addr.String() + } + + userId := "" + if md, ok := metadata.FromIncomingContext(ctx); ok { + if t := md.Get("accessJwt"); len(t) > 0 { + userId = getUserId(t[0], false) + } else if t := md.Get("auth-token"); len(t) > 0 { + userId = getUserId(t[0], true) + } + } + + cd := codes.Unknown + if serr, ok := status.FromError(err); ok { + cd = serr.Code() + } + auditor.Audit(&AuditEvent{ + User: userId, + ServerHost: x.WorkerConfig.MyAddr, + ClientHost: clientHost, + Endpoint: info.FullMethod, + ReqType: Grpc, + Req: truncate(fmt.Sprintf("%+v", req), maxReqLength), + Status: cd.String(), + }) +} + +func auditHttp(w *ResponseWriter, r *http.Request) { + rb, err := ioutil.ReadAll(r.Body) + if err != nil { + rb = []byte(err.Error()) + } + + userId := "" + if token := r.Header.Get("X-Dgraph-AccessToken"); token != "" { + userId = getUserId(token, false) + } else if token := r.Header.Get("X-Dgraph-AuthToken"); token != "" { + userId = getUserId(token, true) + } else { + userId = getUserId("", false) + } + auditor.Audit(&AuditEvent{ + User: userId, + ServerHost: x.WorkerConfig.MyAddr, + ClientHost: r.RemoteAddr, + Endpoint: r.URL.Path, + ReqType: Http, + Req: truncate(string(rb), maxReqLength), + Status: http.StatusText(w.statusCode), + QueryParams: r.URL.Query(), + }) +} + +func getUserId(token string, poorman bool) string { + if poorman { + return PoorManAuth + } + var userId string + var err error + if token == "" { + if x.WorkerConfig.AclEnabled { + userId = UnauthorisedUser + } + } else { + if userId, err = x.ExtractUserName(token); err != nil { + userId = UnknownUser + } + } + return userId +} + +type ResponseWriter struct { + http.ResponseWriter + statusCode int +} + +func NewResponseWriter(w http.ResponseWriter) *ResponseWriter { + // WriteHeader(int) is not called if our response implicitly returns 200 OK, so + // we default to that status code. + return &ResponseWriter{w, http.StatusOK} +} + +func (rw *ResponseWriter) WriteHeader(code int) { + rw.statusCode = code + rw.ResponseWriter.WriteHeader(code) +} + +func truncate(s string, l int) string { + if len(s) > l { + return s[:l] + } + return s +} diff --git a/ee/audit/run.go b/ee/audit/run.go new file mode 100644 index 00000000000..78545acb16d --- /dev/null +++ b/ee/audit/run.go @@ -0,0 +1,33 @@ +// +build oss + +/* + * Copyright 2021 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package audit + +import ( + "github.com/dgraph-io/dgraph/x" + "github.com/spf13/cobra" +) + +var CmdAudit x.SubCommand + +func init() { + CmdAudit.Cmd = &cobra.Command{ + Use: "audit", + Short: "Enterprise feature. Not supported in oss version", + } +} diff --git a/ee/audit/run_ee.go b/ee/audit/run_ee.go new file mode 100644 index 00000000000..a4d9ee5734c --- /dev/null +++ b/ee/audit/run_ee.go @@ -0,0 +1,116 @@ +// +build !oss + +/* + * Copyright 2021 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Dgraph Community License (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * https://github.com/dgraph-io/dgraph/blob/master/licenses/DCL.txt + */ + +package audit + +import ( + "crypto/aes" + "crypto/cipher" + "encoding/binary" + "errors" + "fmt" + "io/ioutil" + "os" + + "github.com/dgraph-io/dgraph/x" + "github.com/golang/glog" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +var CmdAudit x.SubCommand + +func init() { + CmdAudit.Cmd = &cobra.Command{ + Use: "audit", + Short: "Dgraph audit tool", + } + + subcommands := initSubcommands() + for _, sc := range subcommands { + CmdAudit.Cmd.AddCommand(sc.Cmd) + sc.Conf = viper.New() + if err := sc.Conf.BindPFlags(sc.Cmd.Flags()); err != nil { + glog.Fatalf("Unable to bind flags for command %v: %v", sc, err) + } + if err := sc.Conf.BindPFlags(CmdAudit.Cmd.PersistentFlags()); err != nil { + glog.Fatalf( + "Unable to bind persistent flags from audit for command %v: %v", sc, err) + } + sc.Conf.SetEnvPrefix(sc.EnvPrefix) + } +} + +var decryptCmd x.SubCommand + +func initSubcommands() []*x.SubCommand { + decryptCmd.Cmd = &cobra.Command{ + Use: "decrypt", + Short: "Run Dgraph Audit tool to decrypt audit files", + Run: func(cmd *cobra.Command, args []string) { + if err := run(); err != nil { + fmt.Printf("%v\n", err) + os.Exit(1) + } + }, + } + + decFlags := decryptCmd.Cmd.Flags() + decFlags.String("in", "", "input file that needs to decrypted.") + decFlags.String("out", "audit_log_out.log", + "output file to which decrypted output will be dumped.") + decFlags.String("encryption_key_file", "", "path to encrypt files.") + return []*x.SubCommand{&decryptCmd} +} + +func run() error { + key, err := ioutil.ReadFile(decryptCmd.Conf.GetString("encryption_key_file")) + x.Check(err) + if key == nil { + return errors.New("no encryption key provided") + } + + file, err := os.Open(decryptCmd.Conf.GetString("in")) + x.Check(err) + defer file.Close() + + outfile, err := os.OpenFile(decryptCmd.Conf.GetString("out"), + os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.ModePerm) + x.Check(err) + defer outfile.Close() + + block, err := aes.NewCipher(key) + stat, err := os.Stat(decryptCmd.Conf.GetString("in")) + x.Check(err) + iv := make([]byte, aes.BlockSize) + x.Check2(file.ReadAt(iv, 0)) + + var iterator int64 = 16 + for { + content := make([]byte, binary.BigEndian.Uint32(iv[12:])) + x.Check2(file.ReadAt(content, iterator)) + iterator = iterator + int64(binary.BigEndian.Uint32(iv[12:])) + stream := cipher.NewCTR(block, iv) + stream.XORKeyStream(content, content) + x.Check2(outfile.Write(content)) + // if its the end of data. finish decrypting + if iterator >= stat.Size() { + break + } + x.Check2(file.ReadAt(iv[12:], iterator)) + iterator = iterator + 4 + } + glog.Infof("Decryption of Audit file %s is Done. Decrypted file is %s", + decryptCmd.Conf.GetString("in"), + decryptCmd.Conf.GetString("out")) + return nil +} diff --git a/ee/utils_ee.go b/ee/utils_ee.go index ad546478de7..178ff266022 100644 --- a/ee/utils_ee.go +++ b/ee/utils_ee.go @@ -37,5 +37,8 @@ func GetEEFeaturesList() []string { } else { ee = append(ee, "backup_restore") } + if x.WorkerConfig.Audit { + ee = append(ee, "audit") + } return ee } diff --git a/go.mod b/go.mod index 1b1f928a043..8f36219f404 100644 --- a/go.mod +++ b/go.mod @@ -60,6 +60,7 @@ require ( github.com/twpayne/go-geom v1.0.5 go.etcd.io/etcd v0.0.0-20190228193606-a943ad0ee4c9 go.opencensus.io v0.22.5 + go.uber.org/zap v1.16.0 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 golang.org/x/net v0.0.0-20201021035429-f5854403a974 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 diff --git a/go.sum b/go.sum index 658e11ca15a..4488b3ba348 100644 --- a/go.sum +++ b/go.sum @@ -97,8 +97,10 @@ github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8Nz github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e h1:Wf6HqHfScWJN9/ZjdUKyjop4mf3Qdd+1TvvltAvM3m8= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= +github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f h1:lBNOc5arjvs8E5mO2tbpBpLoyyu8B6e44T7hJy6potg= github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= github.com/cosiner/argv v0.1.0/go.mod h1:EusR6TucWKX+zFgtdUsKT2Cvg45K5rtpCcWz4hK06d8= github.com/couchbase/ghistogram v0.1.0/go.mod h1:s1Jhy76zqfEecpNWJfWUiKZookAFaiGOEoyzgHt9i7k= @@ -583,10 +585,21 @@ go.opencensus.io v0.22.5 h1:dntmOdLpSpHlVqbW5Eay97DelsZHe+55D+xC6i0dDS0= go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= go.starlark.net v0.0.0-20190702223751-32f345186213/go.mod h1:c1/X6cHgvdXj6pUlmWKMkuqRnW4K8x2vwt6JAaaircg= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.4.0 h1:cxzIVoETapQEqDhQu3QfnvXAV4AlzcvUCxkVUFw3+EU= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/multierr v1.1.0 h1:HoEmRHQPVSqub6w2z2d2EOVs2fjyFRGyofhKuyDq0QI= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/multierr v1.5.0 h1:KCa4XfM8CWFCpxXRGok+Q0SS/0XBhMDbHHGABQLvD2A= +go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.10.0 h1:ORx85nbTijNz8ljznvCMR1ZBIPKFn3jQrag10X2AsuM= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.16.0 h1:uFRZXykJGK9lLY4HtgSw44DnIcAM+kRBP7x5m+NpAOM= +go.uber.org/zap v1.16.0/go.mod h1:MA8QOfq0BHJwdXa996Y4dYkAqRKB8/1K1QMMZVaNZjQ= golang.org/x/arch v0.0.0-20190927153633-4e8777c89be4/go.mod h1:flIaEI6LNU6xOCD5PaJvn9wGP0agmIOqjrtsKGRguv4= golang.org/x/arch v0.0.0-20201008161808-52c3e6f60cff/go.mod h1:flIaEI6LNU6xOCD5PaJvn9wGP0agmIOqjrtsKGRguv4= golang.org/x/crypto v0.0.0-20180608092829-8ac0e0d97ce4/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -615,6 +628,7 @@ golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHl golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20200302205851-738671d3881b h1:Wh+f8QHJXR411sJR8/vRBTZ7YapZaRvUcLFFJhusH0k= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= @@ -722,6 +736,8 @@ golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgw golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191112195655-aa38f8e97acc/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191127201027-ecd32218bd7f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= @@ -802,6 +818,7 @@ honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/graphql/admin/current_user.go b/graphql/admin/current_user.go index d7e92b1c854..6bbcc43ee13 100644 --- a/graphql/admin/current_user.go +++ b/graphql/admin/current_user.go @@ -22,10 +22,7 @@ import ( "github.com/dgraph-io/dgraph/gql" "github.com/dgraph-io/dgraph/graphql/resolve" "github.com/dgraph-io/dgraph/graphql/schema" - "github.com/dgraph-io/dgraph/worker" "github.com/dgraph-io/dgraph/x" - "github.com/dgrijalva/jwt-go" - "github.com/pkg/errors" ) type currentUserResolver struct { @@ -38,31 +35,7 @@ func extractName(ctx context.Context) (string, error) { return "", err } - // Code copied from access_ee.go. Couldn't put the code in x, because of dependency on - // worker. (worker.Config.HmacSecret) - token, err := jwt.Parse(accessJwt[0], func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, errors.Errorf("unexpected signing method: %v", - token.Header["alg"]) - } - return []byte(worker.Config.HmacSecret), nil - }) - - if err != nil { - return "", errors.Wrapf(err, "unable to parse jwt token") - } - - claims, ok := token.Claims.(jwt.MapClaims) - if !ok || !token.Valid { - return "", errors.Errorf("claims in jwt token is not map claims") - } - - userId, ok := claims["userid"].(string) - if !ok { - return "", errors.Errorf("userid in claims is not a string:%v", userId) - } - - return userId, nil + return x.ExtractUserName(accessJwt[0]) } func (gsr *currentUserResolver) Rewrite(ctx context.Context, diff --git a/posting/list.go b/posting/list.go index d48f255e02c..c3116a0f1c5 100644 --- a/posting/list.go +++ b/posting/list.go @@ -31,7 +31,6 @@ import ( "github.com/dgraph-io/badger/v3/y" "github.com/dgraph-io/dgraph/algo" "github.com/dgraph-io/dgraph/codec" - "github.com/dgraph-io/dgraph/dgraph/cmd/zero" "github.com/dgraph-io/dgraph/protos/pb" "github.com/dgraph-io/dgraph/schema" "github.com/dgraph-io/dgraph/types" @@ -501,7 +500,7 @@ func (l *List) addMutationInternal(ctx context.Context, txn *Txn, t *pb.Directed l.AssertLock() if txn.ShouldAbort() { - return zero.ErrConflict + return x.ErrConflict } mpost := NewPosting(t) diff --git a/systest/audit/audit_dir/.gitkeep b/systest/audit/audit_dir/.gitkeep new file mode 100644 index 00000000000..e69de29bb2d diff --git a/systest/audit/audit_dir/aa/.gitkeep b/systest/audit/audit_dir/aa/.gitkeep new file mode 100644 index 00000000000..e69de29bb2d diff --git a/systest/audit/audit_dir/za/.gitkeep b/systest/audit/audit_dir/za/.gitkeep new file mode 100644 index 00000000000..e69de29bb2d diff --git a/systest/audit/audit_test.go b/systest/audit/audit_test.go new file mode 100644 index 00000000000..eec0ea57c3a --- /dev/null +++ b/systest/audit/audit_test.go @@ -0,0 +1,136 @@ +/* + * Copyright 2017-2021 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package audit + +import ( + "bufio" + "encoding/json" + "fmt" + "os" + "os/exec" + "path/filepath" + "testing" + + "github.com/dgraph-io/dgraph/testutil" + "github.com/stretchr/testify/require" +) + +func TestZeroAudit(t *testing.T) { + defer os.RemoveAll("audit_dir/za/dgraph_audit.log") + zeroCmd := map[string][]string{ + "/removeNode": []string{`--location`, "--request", "GET", + fmt.Sprintf("%s/removeNode?id=3&group=1", testutil.SockAddrZeroHttp)}, + "/assign": []string{"--location", "--request", "GET", + fmt.Sprintf("%s/assign?what=uids&num=100", testutil.SockAddrZeroHttp)}, + "/moveTablet": []string{"--location", "--request", "GET", + fmt.Sprintf("%s/moveTablet?tablet=name&group=2", testutil.SockAddrZeroHttp)}} + + msgs := make([]string, 0) + // logger is buffered. make calls in bunch so that dont want to wait for flush + for i := 0; i < 500; i++ { + for req, c := range zeroCmd { + msgs = append(msgs, req) + cmd := exec.Command("curl", c...) + if out, err := cmd.CombinedOutput(); err != nil { + fmt.Println(string(out)) + t.Fatal(err) + } + } + } + + verifyLogs(t, "./audit_dir/za/dgraph_audit.log", msgs) +} +func TestAlphaAudit(t *testing.T) { + defer os.Remove("audit_dir/aa/dgraph_audit.log") + testCommand := map[string][]string{ + "/admin": []string{"--location", "--request", "POST", + fmt.Sprintf("%s/admin", testutil.SockAddrHttp), + "--header", "Content-Type: application/json", + "--data-raw", `'{"query":"mutation {\n backup( +input: {destination: \"/Users/sankalanparajuli/work/backup\"}) {\n response {\n message\n code\n }\n }\n}\n","variables":{}}'`}, + + "/graphql": []string{"--location", "--request", "POST", fmt.Sprintf("%s/graphql", testutil.SockAddrHttp), + "--header", "Content-Type: application/json", + "--data-raw", `'{"query":"query {\n __schema {\n __typename\n }\n}","variables":{}}'`}, + + "/alter": []string{"-X", "POST", fmt.Sprintf("%s/alter", testutil.SockAddrHttp), "-d", + `'name: string @index(term) . + type Person { + name + }'`}, + "/query": []string{"-H", "'Content-Type: application/dql'", "-X", "POST", fmt.Sprintf("%s/query", testutil.SockAddrHttp), + "-d", `$' + { + balances(func: anyofterms(name, "Alice Bob")) { + uid + name + balance + } + }'`}, + "/mutate": []string{"-H", "'Content-Type: application/rdf'", "-X", + "POST", fmt.Sprintf("%s/mutate?startTs=4", testutil.SockAddrHttp), "-d", `$' + { + set { + <0x1> "110" . + <0x1> "Balance" . + <0x2> "60" . + <0x2> "Balance" . + } + } + '`}, + } + + msgs := make([]string, 0) + // logger is buffered. make calls in bunch so that dont want to wait for flush + for i := 0; i < 200; i++ { + for req, c := range testCommand { + msgs = append(msgs, req) + cmd := exec.Command("curl", c...) + if out, err := cmd.CombinedOutput(); err != nil { + fmt.Println(string(out)) + t.Fatal(err) + } + } + } + verifyLogs(t, "./audit_dir/aa/dgraph_audit.log", msgs) +} + +func verifyLogs(t *testing.T, path string, cmds []string) { + abs, err := filepath.Abs(path) + require.Nil(t, err) + f, err := os.Open(abs) + require.Nil(t, err) + + type log struct { + Msg string `json:"msg"` + } + logMap := make(map[string]bool) + + var fileScanner *bufio.Scanner + fileScanner = bufio.NewScanner(f) + for fileScanner.Scan() { + bytes := fileScanner.Bytes() + l := new(log) + _ = json.Unmarshal(bytes, l) + logMap[l.Msg] = true + } + for _, m := range cmds { + if !logMap[m] { + t.Fatalf("audit logs not present for command %s", m) + } + } +} diff --git a/systest/audit/docker-compose.yml b/systest/audit/docker-compose.yml new file mode 100644 index 00000000000..6f181dcdd4c --- /dev/null +++ b/systest/audit/docker-compose.yml @@ -0,0 +1,39 @@ +version: "3.5" +services: + alpha1: + image: dgraph/dgraph:latest + working_dir: /data/alpha1 + labels: + cluster: test + ports: + - "8080" + - "9080" + volumes: + - type: bind + source: $GOPATH/bin + target: /gobin + read_only: true + - type: bind + source: ./audit_dir/aa + target: /audit_dir + command: /gobin/dgraph alpha --my=alpha1:7080 --zero=zero1:5080 --logtostderr + --audit "dir=/audit_dir" -v=2 --whitelist=10.0.0.0/8,172.16.0.0/12,192.168.0.0/16 + zero1: + image: dgraph/dgraph:latest + working_dir: /data/zero1 + labels: + cluster: test + ports: + - "5080" + - "6080" + volumes: + - type: bind + source: $GOPATH/bin + target: /gobin + read_only: true + - type: bind + source: ./audit_dir/za + target: /audit_dir + command: /gobin/dgraph zero --raft="idx=1" --my=zero1:5080 --logtostderr -v=2 --bindall + --audit "dir=/audit_dir" +volumes: {} diff --git a/worker/config.go b/worker/config.go index ada465cee26..67eef58614c 100644 --- a/worker/config.go +++ b/worker/config.go @@ -21,7 +21,7 @@ import ( "time" bo "github.com/dgraph-io/badger/v3/options" - + "github.com/dgraph-io/dgraph/ee/audit" "github.com/dgraph-io/dgraph/x" ) @@ -70,6 +70,8 @@ type Options struct { CachePercentage string // CacheMb is the total memory allocated between all the caches. CacheMb int64 + + Audit *audit.AuditConf } // Config holds an instance of the server options.. @@ -94,7 +96,20 @@ func (opt *Options) validate() { x.Check(err) td, err := filepath.Abs(x.WorkerConfig.TmpDir) x.Check(err) - x.AssertTruef(pd != wd, "Posting and WAL directory cannot be the same ('%s').", opt.PostingDir) - x.AssertTruef(pd != td, "Posting and Tmp directory cannot be the same ('%s').", opt.PostingDir) - x.AssertTruef(wd != td, "WAL and Tmp directory cannot be the same ('%s').", opt.WALDir) + x.AssertTruef(pd != wd, + "Posting and WAL directory cannot be the same ('%s').", opt.PostingDir) + x.AssertTruef(pd != td, + "Posting and Tmp directory cannot be the same ('%s').", opt.PostingDir) + x.AssertTruef(wd != td, + "WAL and Tmp directory cannot be the same ('%s').", opt.WALDir) + if opt.Audit != nil { + ad, err := filepath.Abs(opt.Audit.Dir) + x.Check(err) + x.AssertTruef(ad != pd, + "Posting and Audit Directory cannot be the same ('%s').", opt.Audit.Dir) + x.AssertTruef(ad != wd, + "WAL and Audit directory cannot be the same ('%s').", opt.Audit.Dir) + x.AssertTruef(ad != td, + "Tmp and Audit directory cannot be the same ('%s').", opt.Audit.Dir) + } } diff --git a/worker/draft.go b/worker/draft.go index 4d5f844c12a..4404630eaa1 100644 --- a/worker/draft.go +++ b/worker/draft.go @@ -42,7 +42,6 @@ import ( "github.com/dgraph-io/badger/v3" bpb "github.com/dgraph-io/badger/v3/pb" "github.com/dgraph-io/dgraph/conn" - "github.com/dgraph-io/dgraph/dgraph/cmd/zero" "github.com/dgraph-io/dgraph/posting" "github.com/dgraph-io/dgraph/protos/pb" "github.com/dgraph-io/dgraph/raftwal" @@ -481,7 +480,7 @@ func (n *node) applyMutations(ctx context.Context, proposal *pb.Proposal) (rerr txn := posting.Oracle().RegisterStartTs(m.StartTs) if txn.ShouldAbort() { span.Annotatef(nil, "Txn %d should abort.", m.StartTs) - return zero.ErrConflict + return x.ErrConflict } // Discard the posting lists from cache to release memory at the end. defer txn.Update() diff --git a/x/config.go b/x/config.go index 9a571fa4ced..ac4d7c11f40 100644 --- a/x/config.go +++ b/x/config.go @@ -87,6 +87,8 @@ type WorkerOptions struct { StrictMutations bool // AclEnabled indicates whether the enterprise ACL feature is turned on. AclEnabled bool + // HmacSecret stores the secret used to sign JSON Web Tokens (JWT). + HmacSecret SensitiveByteSlice // AbortOlderThan tells Dgraph to discard transactions that are older than this duration. AbortOlderThan time.Duration // ProposedGroupId will be used if there's a file in the p directory called group_id with the @@ -107,6 +109,9 @@ type WorkerOptions struct { LogRequest int32 // If true, we should call msync or fsync after every write to survive hard reboots. HardSync bool + + // Audit contains the audit flags that enables the audit. + Audit bool } // WorkerConfig stores the global instance of the worker package's options. diff --git a/x/flags.go b/x/flags.go index 06c0d37db25..718858cc0af 100644 --- a/x/flags.go +++ b/x/flags.go @@ -113,14 +113,8 @@ func (sf *SuperFlag) MergeAndCheckDefault(flag string) *SuperFlag { } return sf } -func (sf *SuperFlag) Get(opt string) string { - if sf == nil { - return "" - } - return sf.m[opt] -} func (sf *SuperFlag) GetBool(opt string) bool { - val := sf.Get(opt) + val := sf.GetString(opt) if val == "" { return false } @@ -129,7 +123,7 @@ func (sf *SuperFlag) GetBool(opt string) bool { return b } func (sf *SuperFlag) GetUint64(opt string) uint64 { - val := sf.Get(opt) + val := sf.GetString(opt) if val == "" { return 0 } @@ -138,7 +132,7 @@ func (sf *SuperFlag) GetUint64(opt string) uint64 { return u } func (sf *SuperFlag) GetUint32(opt string) uint32 { - val := sf.Get(opt) + val := sf.GetString(opt) if val == "" { return 0 } @@ -146,3 +140,9 @@ func (sf *SuperFlag) GetUint32(opt string) uint32 { Checkf(err, "Unable to parse %s as uint32 for key: %s. Options: %s\n", val, opt, sf) return uint32(u) } +func (sf *SuperFlag) GetString(opt string) string { + if sf == nil { + return "" + } + return sf.m[opt] +} diff --git a/x/flags_test.go b/x/flags_test.go index 13e2b8c8a78..d126310a44b 100644 --- a/x/flags_test.go +++ b/x/flags_test.go @@ -39,6 +39,6 @@ func TestFlag(t *testing.T) { require.Panics(t, c) require.Equal(t, true, sf.GetBool("bool-key")) require.Equal(t, uint64(5), sf.GetUint64("int-key")) - require.Equal(t, "value", sf.Get("string-key")) + require.Equal(t, "value", sf.GetString("string-key")) require.Equal(t, uint64(5), sf.GetUint64("other-key")) } diff --git a/x/jwt_helper.go b/x/jwt_helper.go new file mode 100644 index 00000000000..a8b6d3472aa --- /dev/null +++ b/x/jwt_helper.go @@ -0,0 +1,48 @@ +/* + * Copyright 2017-2021 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package x + +import ( + "github.com/dgrijalva/jwt-go" + "github.com/pkg/errors" +) + +func ExtractUserName(jwtToken string) (string, error) { + token, err := jwt.Parse(jwtToken, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, errors.Errorf("unexpected signing method: %v", + token.Header["alg"]) + } + return []byte(WorkerConfig.HmacSecret), nil + }) + + if err != nil { + return "", errors.Wrapf(err, "unable to parse jwt token") + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + return "", errors.Errorf("claims in jwt token is not map claims") + } + + userId, ok := claims["userid"].(string) + if !ok { + return "", errors.Errorf("userid in claims is not a string:%v", userId) + } + + return userId, nil +} diff --git a/x/log_writer.go b/x/log_writer.go new file mode 100644 index 00000000000..690074f59cc --- /dev/null +++ b/x/log_writer.go @@ -0,0 +1,358 @@ +/* + * Copyright 2021 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package x + +import ( + "bufio" + "compress/gzip" + "encoding/binary" + "fmt" + "io" + "io/ioutil" + "math/rand" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/dgraph-io/ristretto/z" + + "github.com/dgraph-io/badger/v3/y" +) + +const ( + backupTimeFormat = "2006-01-02T15-04-05.000" + bufferSize = 256 * 1024 + flushInterval = 10 * time.Second +) + +// This is done to ensure LogWriter always implement io.WriterCloser +var _ io.WriteCloser = (*LogWriter)(nil) + +type LogWriter struct { + FilePath string + MaxSize int64 + MaxAge int // number of days + Compress bool + EncryptionKey []byte + + baseIv [12]byte + mu sync.Mutex + size int64 + file *os.File + writer *bufio.Writer + flushTicker *time.Ticker + closer *z.Closer + // To manage order of cleaning old logs files + manageChannel chan bool +} + +func (l *LogWriter) Init() (*LogWriter, error) { + l.manageOldLogs() + if err := l.open(); err != nil { + return nil, fmt.Errorf("not able to create new file %v", err) + } + l.closer = z.NewCloser(2) + l.manageChannel = make(chan bool, 1) + go func() { + defer l.closer.Done() + for { + select { + case <-l.manageChannel: + l.manageOldLogs() + case <-l.closer.HasBeenClosed(): + return + } + } + }() + + l.flushTicker = time.NewTicker(flushInterval) + go l.flushPeriodic() + return l, nil +} + +func (l *LogWriter) Write(p []byte) (int, error) { + l.mu.Lock() + defer l.mu.Unlock() + + if l.size+int64(len(p)) >= l.MaxSize*1024*1024 { + if err := l.rotate(); err != nil { + return 0, err + } + } + + // if encryption is enabled store the data in encyrpted way + if l.EncryptionKey != nil { + bytes, err := encrypt(l.EncryptionKey, l.baseIv, p) + if err != nil { + return 0, err + } + n, err := l.writer.Write(bytes) + l.size = l.size + int64(n) + return n, err + } + + n, err := l.writer.Write(p) + l.size = l.size + int64(n) + return n, err +} + +func (l *LogWriter) Close() error { + // close all go routines first before acquiring the lock to avoid contention + l.closer.SignalAndWait() + + l.mu.Lock() + defer l.mu.Unlock() + if l.file == nil { + return nil + } + l.flush() + l.flushTicker.Stop() + close(l.manageChannel) + _ = l.file.Close() + l.writer = nil + l.file = nil + return nil +} + +// flushPeriodic periodically flushes the log file buffers. +func (l *LogWriter) flushPeriodic() { + defer l.closer.Done() + for { + select { + case <-l.flushTicker.C: + l.mu.Lock() + l.flush() + l.mu.Unlock() + case <-l.closer.HasBeenClosed(): + return + } + } +} + +// LogWriter should be locked while calling this +func (l *LogWriter) flush() { + _ = l.writer.Flush() + _ = l.file.Sync() +} + +func encrypt(key []byte, baseIv [12]byte, src []byte) ([]byte, error) { + iv := make([]byte, 16) + copy(iv, baseIv[:]) + binary.BigEndian.PutUint32(iv[12:], uint32(len(src))) + allocate, err := y.XORBlockAllocate(src, key, iv) + if err != nil { + return nil, err + } + allocate = append(iv[12:], allocate...) + return allocate, nil +} + +func (l *LogWriter) rotate() error { + l.flush() + if err := l.file.Close(); err != nil { + return err + } + + if _, err := os.Stat(l.FilePath); err == nil { + // move the existing file + newname := backupName(l.FilePath) + if err := os.Rename(l.FilePath, newname); err != nil { + return fmt.Errorf("can't rename log file: %s", err) + } + } + + l.manageChannel <- true + return l.open() +} + +func (l *LogWriter) open() error { + if err := os.MkdirAll(filepath.Dir(l.FilePath), 0755); err != nil { + return err + } + + size := func() int64 { + info, err := os.Stat(l.FilePath) + if err != nil { + return 0 + } + return info.Size() + } + + openNew := func() error { + f, err := os.OpenFile(l.FilePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.ModePerm) + if err != nil { + return err + } + l.file = f + l.writer = bufio.NewWriterSize(f, bufferSize) + + if l.EncryptionKey != nil { + rand.Read(l.baseIv[:]) + if _, err = l.writer.Write(l.baseIv[:]); err != nil { + return err + } + } + l.size = size() + return nil + } + + info, err := os.Stat(l.FilePath) + if err != nil { // if any error try to open new log file itself + return openNew() + } + + // encryption is enabled and file is corrupted as not able to read the IV + if l.EncryptionKey != nil && info.Size() < 12 { + return openNew() + } + + f, err := os.OpenFile(l.FilePath, os.O_APPEND|os.O_RDWR, os.ModePerm) + if err != nil { + return openNew() + } + + if l.EncryptionKey != nil { + // If not able to read the baseIv, then this file might be corrupted. + // open the new file in that case + if _, err = f.ReadAt(l.baseIv[:], 0); err != nil { + _ = f.Close() + return openNew() + } + } + + l.file = f + l.writer = bufio.NewWriterSize(f, bufferSize) + l.size = size() + return nil +} + +func backupName(name string) string { + dir := filepath.Dir(name) + prefix, ext := prefixAndExt(name) + timestamp := time.Now().Format(backupTimeFormat) + return filepath.Join(dir, fmt.Sprintf("%s-%s%s", prefix, timestamp, ext)) +} + +func compress(src string) error { + f, err := os.Open(src) + if err != nil { + return err + } + + defer f.Close() + gzf, err := os.OpenFile(src+".gz", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.ModePerm) + if err != nil { + return err + } + + defer gzf.Close() + gz := gzip.NewWriter(gzf) + defer gz.Close() + if _, err := io.Copy(gz, f); err != nil { + os.Remove(src + ".gz") + return err + } + // close the descriptors because we need to delete the file + if err := f.Close(); err != nil { + return err + } + if err := os.Remove(src); err != nil { + return err + } + return nil +} + +// this should be called in a serial order +func (l *LogWriter) manageOldLogs() { + toRemove, toKeep, err := processOldLogFiles(l.FilePath, l.MaxSize) + if err != nil { + return + } + + for _, f := range toRemove { + errRemove := os.Remove(filepath.Join(filepath.Dir(l.FilePath), f)) + if err == nil && errRemove != nil { + err = errRemove + } + } + + // if compression enabled do compress + if l.Compress { + for _, f := range toKeep { + // already compressed no need + if strings.HasSuffix(f, ".gz") { + continue + } + fn := filepath.Join(filepath.Dir(l.FilePath), f) + errCompress := compress(fn) + if err == nil && errCompress != nil { + err = errCompress + } + } + } + + if err != nil { + fmt.Printf("error while managing old log files %+v\n", err) + } +} + +func prefixAndExt(file string) (prefix, ext string) { + filename := filepath.Base(file) + ext = filepath.Ext(filename) + prefix = filename[:len(filename)-len(ext)] + return prefix, ext +} + +func processOldLogFiles(fp string, maxAge int64) ([]string, []string, error) { + dir := filepath.Dir(fp) + files, err := ioutil.ReadDir(dir) + if err != nil { + return nil, nil, fmt.Errorf("can't read log file directory: %s", err) + } + + defPrefix, defExt := prefixAndExt(fp) + // check only for old files. Those files have - before the time + defPrefix = defPrefix + "-" + toRemove := make([]string, 0) + toKeep := make([]string, 0) + + diff := time.Duration(int64(24*time.Hour) * int64(maxAge)) + cutoff := time.Now().Add(-1 * diff) + + for _, f := range files { + if f.IsDir() || // f is directory + !strings.HasPrefix(f.Name(), defPrefix) || // f doesnt start with prefix + !(strings.HasSuffix(f.Name(), defExt) || strings.HasSuffix(f.Name(), defExt+".gz")) { + continue + } + + _, e := prefixAndExt(fp) + ts, err := time.Parse(backupTimeFormat, f.Name()[len(defPrefix):len(f.Name())-len(e)]) + if err != nil { + continue + } + if ts.Before(cutoff) { + toRemove = append(toRemove, f.Name()) + } else { + toKeep = append(toKeep, f.Name()) + } + } + + return toRemove, toKeep, nil +} diff --git a/x/log_writer_test.go b/x/log_writer_test.go new file mode 100644 index 00000000000..5f54434b060 --- /dev/null +++ b/x/log_writer_test.go @@ -0,0 +1,163 @@ +/* + * Copyright 2021 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package x + +import ( + "bufio" + "bytes" + "compress/gzip" + "crypto/aes" + "crypto/cipher" + "encoding/binary" + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestLogWriter(t *testing.T) { + path, _ := filepath.Abs("./log_test/audit.log") + defer os.RemoveAll(filepath.Dir(path)) + lw := &LogWriter{ + FilePath: path, + MaxSize: 1, + MaxAge: 1, + Compress: false, + } + + lw, _ = lw.Init() + writeToLogWriterAndVerify(t, lw, path) +} + +func TestLogWriterWithCompression(t *testing.T) { + path, _ := filepath.Abs("./log_test/audit.log") + defer os.RemoveAll(filepath.Dir(path)) + lw := &LogWriter{ + FilePath: path, + MaxSize: 1, + MaxAge: 1, + Compress: true, + } + + lw, _ = lw.Init() + writeToLogWriterAndVerify(t, lw, path) +} + +// if this test failed and you changed anything, please check the dgraph audit decrypt command. +// The dgraph audit decrypt command uses the same decryption method +func TestLogWriterWithEncryption(t *testing.T) { + path, _ := filepath.Abs("./log_test/audit.log.enc") + defer os.RemoveAll(filepath.Dir(path)) + lw := &LogWriter{ + FilePath: path, + MaxSize: 1, + MaxAge: 1, + Compress: false, + EncryptionKey: []byte("1234567890123456"), + } + + lw, _ = lw.Init() + msg := []byte("abcd") + msg = bytes.Repeat(msg, 256) + msg[1023] = '\n' + for i := 0; i < 10000; i++ { + n, err := lw.Write(msg) + require.Nil(t, err) + require.Equal(t, n, len(msg)+4, "write length is not equal") + } + + time.Sleep(time.Second * 10) + require.NoError(t, lw.Close()) + file, err := os.Open(path) + require.Nil(t, err) + defer file.Close() + outPath, _ := filepath.Abs("./log_test/audit_out.log") + outfile, err := os.OpenFile(outPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600) + require.Nil(t, err) + defer outfile.Close() + + block, err := aes.NewCipher(lw.EncryptionKey) + stat, err := os.Stat(path) + require.Nil(t, err) + iv := make([]byte, aes.BlockSize) + _, err = file.ReadAt(iv, 0) + require.Nil(t, err) + + var iterator int64 = 16 + for { + content := make([]byte, binary.BigEndian.Uint32(iv[12:])) + _, err = file.ReadAt(content, iterator) + require.Nil(t, err) + iterator = iterator + int64(binary.BigEndian.Uint32(iv[12:])) + stream := cipher.NewCTR(block, iv) + stream.XORKeyStream(content, content) + //require.True(t, bytes.Equal(content, msg)) + _, err = outfile.Write(content) + require.Nil(t, err) + if iterator >= stat.Size() { + break + } + _, err = file.ReadAt(iv[12:], iterator) + require.Nil(t, err) + iterator = iterator + 4 + } +} + +func writeToLogWriterAndVerify(t *testing.T, lw *LogWriter, path string) { + msg := []byte("abcd") + msg = bytes.Repeat(msg, 256) + msg[1023] = '\n' + for i := 0; i < 10; i++ { + go func() { + for i := 0; i < 1000; i++ { + n, err := lw.Write(msg) + require.Nil(t, err) + require.Equal(t, n, len(msg), "write length is not equal") + } + }() + } + time.Sleep(time.Second * 10) + require.NoError(t, lw.Close()) + files, err := ioutil.ReadDir("./log_test") + require.Nil(t, err) + + lineCount := 0 + for _, f := range files { + file, _ := os.Open(filepath.Join(filepath.Dir(path), f.Name())) + + var fileScanner *bufio.Scanner + if strings.HasSuffix(file.Name(), ".gz") { + gz, err := gzip.NewReader(file) + require.NoError(t, err) + all, err := ioutil.ReadAll(gz) + require.NoError(t, err) + fileScanner = bufio.NewScanner(bytes.NewReader(all)) + gz.Close() + } else { + fileScanner = bufio.NewScanner(file) + } + for fileScanner.Scan() { + lineCount = lineCount + 1 + } + } + + require.Equal(t, lineCount, 10000) +} diff --git a/x/logger.go b/x/logger.go new file mode 100644 index 00000000000..08610805322 --- /dev/null +++ b/x/logger.go @@ -0,0 +1,90 @@ +/* + * Copyright 2021 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package x + +import ( + "os" + "path/filepath" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +func InitLogger(dir string, filename string, key []byte, compress bool) (*Logger, error) { + if err := os.MkdirAll(dir, 0700); err != nil { + return nil, err + } + if key != nil { + filename = filename + ".enc" + } + + path, err := filepath.Abs(filepath.Join(dir, filename)) + if err != nil { + return nil, err + } + w := &LogWriter{ + FilePath: path, + MaxSize: 100, + MaxAge: 10, + EncryptionKey: key, + Compress: compress, + } + if w, err = w.Init(); err != nil { + return nil, err + } + return &Logger{ + logger: zap.New(zapcore.NewCore(zapcore.NewJSONEncoder(zap.NewProductionEncoderConfig()), + zapcore.AddSync(w), zap.DebugLevel)), + writer: w, + }, nil +} + +type Logger struct { + logger *zap.Logger + writer *LogWriter +} + +// AuditI logs audit message as info. args are key value pairs with key as string value +func (l *Logger) AuditI(msg string, args ...interface{}) { + if l == nil { + return + } + flds := make([]zap.Field, 0) + for i := 0; i < len(args); i = i + 2 { + flds = append(flds, zap.Any(args[i].(string), args[i+1])) + } + l.logger.Info(msg, flds...) +} + +func (l *Logger) AuditE(msg string, args ...interface{}) { + if l == nil { + return + } + flds := make([]zap.Field, 0) + for i := 0; i < len(args); i = i + 2 { + flds = append(flds, zap.Any(args[i].(string), args[i+1])) + } + l.logger.Error(msg, flds...) +} + +func (l *Logger) Sync() { + if l == nil { + return + } + _ = l.logger.Sync() + _ = l.writer.Close() +} diff --git a/x/x.go b/x/x.go index 4a5c56be9e0..67e6d0851e7 100644 --- a/x/x.go +++ b/x/x.go @@ -68,6 +68,8 @@ var ( ErrNoJwt = errors.New("no accessJwt available") // ErrorInvalidLogin is returned when username or password is incorrect in login ErrorInvalidLogin = errors.New("invalid username or password") + // ErrConflict is returned when commit couldn't succeed due to conflicts. + ErrConflict = errors.New("Transaction conflict") ) const (