From a75034f2561a510ef17ff2aa2004089a3173c3b6 Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Tue, 1 May 2018 11:47:45 -0700 Subject: [PATCH] Serialize parallel downloads, fixes #1774 In case if multiple requests to get session event data were issued to the auth server at the same time, multiple download requests were originated, and sometimes partial data was returned. This commit serializes downloads of the session in the context of the same auth server. --- lib/events/auditlog.go | 82 ++++++++++++++++++++++++++++++++++--- lib/events/auditlog_test.go | 67 +++++++++++++++++++++++------- 2 files changed, 129 insertions(+), 20 deletions(-) diff --git a/lib/events/auditlog.go b/lib/events/auditlog.go index cfa2ed8c63c09..12682bd42c45a 100644 --- a/lib/events/auditlog.go +++ b/lib/events/auditlog.go @@ -95,6 +95,16 @@ type AuditLog struct { // fileTime is a rounded (to a day, by default) timestamp of the // currently opened file fileTime time.Time + + // activeDownloads helps to serialize simultaneous downloads + // from the session record server + activeDownloads map[string]context.Context + + // ctx signals close of the audit log + ctx context.Context + + // cancel triggers closing of the signal context + cancel context.CancelFunc } // AuditLogConfig specifies configuration for AuditLog server @@ -190,12 +200,16 @@ func NewAuditLog(cfg AuditLogConfig) (*AuditLog, error) { if err := cfg.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } + ctx, cancel := context.WithCancel(context.TODO()) al := &AuditLog{ playbackDir: filepath.Join(cfg.DataDir, PlaybackDir, SessionLogsDir, defaults.Namespace), AuditLogConfig: cfg, Entry: log.WithFields(log.Fields{ trace.Component: teleport.ComponentAuditLog, }), + activeDownloads: make(map[string]context.Context), + ctx: ctx, + cancel: cancel, } loggers, err := ttlmap.New(defaults.AuditLogSessions, ttlmap.CallOnExpire(al.closeSessionLogger), ttlmap.Clock(cfg.Clock)) @@ -340,7 +354,7 @@ func (l *AuditLog) processSlice(sl SessionLogger, slice *SessionSlice) error { defer func() { l.removeLogger(slice.SessionID) if err := sl.Finalize(); err != nil { - log.Warningf("Failed to finalize logger: %v", trace.DebugReport(err)) + log.Warningf("Failed to finalize logger: %v.", trace.DebugReport(err)) } }() } @@ -443,6 +457,15 @@ func (idx *sessionIndex) eventsFile(afterN int) (int, error) { return -1, trace.NotFound("%v not found", afterN) } +// chunkFileNames returns file names of all session chunk files +func (idx *sessionIndex) chunkFileNames() []string { + fileNames := make([]string, len(idx.chunks)) + for i := 0; i < len(idx.chunks); i++ { + fileNames[i] = idx.chunksFileName(i) + } + return fileNames +} + func (idx *sessionIndex) chunksFile(offset int64) (string, int64, error) { for i := len(idx.chunks) - 1; i >= 0; i-- { entry := idx.chunks[i] @@ -521,9 +544,44 @@ func readIndexEntries(file *os.File, authServer string) (events []indexEntry, ch return } +// createOrGetDownload creates a new download sync entry for a given session, +// if there is no active download in progress, or returns an existing one. +// if the new context has been created, cancel function is returned as a +// second argument. Caller should call this function to signal that download has been +// completed or failed. +func (l *AuditLog) createOrGetDownload(path string) (context.Context, context.CancelFunc) { + l.Lock() + defer l.Unlock() + ctx, ok := l.activeDownloads[path] + if ok { + return ctx, nil + } + ctx, cancel := context.WithCancel(context.TODO()) + l.activeDownloads[path] = ctx + return ctx, func() { + cancel() + l.Lock() + defer l.Unlock() + delete(l.activeDownloads, path) + } +} + func (l *AuditLog) downloadSession(namespace string, sid session.ID) error { tarballPath := filepath.Join(l.playbackDir, string(sid)+".tar") + ctx, cancel := l.createOrGetDownload(tarballPath) + // means that another download is in progress, so simply wait until + // it finishes + if cancel == nil { + l.Debugf("Another download is in progress for %v, waiting until it gets completed.", sid) + select { + case <-ctx.Done(): + return nil + case <-l.ctx.Done(): + return trace.BadParameter("audit log is closing, aborting the download") + } + } + defer cancel() _, err := os.Stat(tarballPath) err = trace.ConvertSystemError(err) if err == nil { @@ -540,7 +598,7 @@ func (l *AuditLog) downloadSession(namespace string, sid session.ID) error { return trace.ConvertSystemError(err) } defer tarball.Close() - if err := l.UploadHandler.Download(context.TODO(), sid, tarball); err != nil { + if err := l.UploadHandler.Download(l.ctx, sid, tarball); err != nil { // remove partially downloaded tarball os.Remove(tarball.Name()) return trace.Wrap(err) @@ -552,12 +610,25 @@ func (l *AuditLog) downloadSession(namespace string, sid session.ID) error { if err != nil { return trace.ConvertSystemError(err) } - if err := utils.Extract(tarball, l.playbackDir); err != nil { return trace.Wrap(err) } + // Extract every chunks file on disk while holding the context, + // otherwise parallel downloads will try to unpack the file at the same time. + idx, err := l.readSessionIndex(namespace, sid) + if err != nil { + return trace.Wrap(err) + } + for _, fileName := range idx.chunkFileNames() { + reader, err := l.unpackFile(fileName) + if err != nil { + return trace.Wrap(err) + } + if err := reader.Close(); err != nil { + l.Warningf("Failed to close file: %v.", err) + } + } l.WithFields(log.Fields{"duration": time.Now().Sub(start)}).Debugf("Unpacked %v to %v.", tarballPath, l.playbackDir) - return nil } @@ -644,7 +715,7 @@ func (l *AuditLog) unpackFile(fileName string) (readSeekCloser, error) { return nil, trace.ConvertSystemError(err) } // no new data has been added - if unpackedInfo.ModTime().After(packedInfo.ModTime()) { + if unpackedInfo.ModTime().Unix() >= packedInfo.ModTime().Unix() { return os.OpenFile(unpackedFile, os.O_RDONLY, 0640) } } @@ -1341,6 +1412,7 @@ func (l *AuditLog) Close() error { log.Warningf("Close failure: %v", err) } } + l.cancel() l.Lock() defer l.Unlock() diff --git a/lib/events/auditlog_test.go b/lib/events/auditlog_test.go index ca9a6e4e7abe9..1b8b666c5023f 100644 --- a/lib/events/auditlog_test.go +++ b/lib/events/auditlog_test.go @@ -1233,24 +1233,61 @@ func (a *AuditTestSuite) TestForwardAndUpload(c *check.C) { c.Fatalf("Timeout wating for the upload event") } - // read the session bytes - history, err := alog.GetSessionEvents(defaults.Namespace, session.ID(sessionID), 0, true) - c.Assert(err, check.IsNil) - c.Assert(history, check.HasLen, 3) + compare := func() error { + history, err := alog.GetSessionEvents(defaults.Namespace, session.ID(sessionID), 0, true) + if err != nil { + return trace.Wrap(err) + } + if len(history) != 3 { + return trace.BadParameter("expected history of 3, got %v", len(history)) + } - // make sure offsets were properly set (0 for the first event and 5 bytes for hello): - c.Assert(history[1][SessionByteOffset], check.Equals, float64(0)) - c.Assert(history[1][SessionEventTimestamp], check.Equals, float64(0)) + // make sure offsets were properly set (0 for the first event and 5 bytes for hello): + if history[1][SessionByteOffset].(float64) != float64(0) { + return trace.BadParameter("expected offset of 0, got %v", history[1][SessionByteOffset]) + } + if history[1][SessionEventTimestamp].(float64) != float64(0) { + return trace.BadParameter("expected timestamp of 0, got %v", history[1][SessionEventTimestamp]) + } - // fetch all bytes - buff, err := alog.GetSessionChunk(defaults.Namespace, session.ID(sessionID), 0, 5000) - c.Assert(err, check.IsNil) - c.Assert(string(buff), check.Equals, string(firstMessage)) + // fetch all bytes + buff, err := alog.GetSessionChunk(defaults.Namespace, session.ID(sessionID), 0, 5000) + if err != nil { + return trace.Wrap(err) + } + if string(buff) != string(firstMessage) { + return trace.CompareFailed("%q != %q", string(buff), string(firstMessage)) + } - // with offset - buff, err = alog.GetSessionChunk(defaults.Namespace, session.ID(sessionID), 2, 5000) - c.Assert(err, check.IsNil) - c.Assert(string(buff), check.Equals, string(firstMessage[2:])) + // with offset + buff, err = alog.GetSessionChunk(defaults.Namespace, session.ID(sessionID), 2, 5000) + if err != nil { + return trace.Wrap(err) + } + if string(buff) != string(firstMessage[2:]) { + return trace.CompareFailed("%q != %q", string(buff), string(firstMessage[2:])) + } + return nil + } + + // trigger several parallel downloads, they should not fail + iterations := 50 + resultsC := make(chan error, iterations) + for i := 0; i < iterations; i++ { + go func() { + resultsC <- compare() + }() + } + + timeout := time.After(time.Second) + for i := 0; i < iterations; i++ { + select { + case err := <-resultsC: + c.Assert(err, check.IsNil) + case <-timeout: + c.Fatalf("timeout waiting for goroutines to finish") + } + } } func marshal(f EventFields) []byte {