Serialize parallel downloads, fixes #1774
When multiple requests for session event data
were issued to the auth server at the same time,
multiple downloads were started in parallel,
and sometimes partial data was returned.

This commit serializes downloads of a session
within the same auth server.
klizhentas committed May 2, 2018
1 parent 49c1198 commit a75034f
Showing 2 changed files with 129 additions and 20 deletions.
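The pattern this commit introduces can be summarized with a minimal, self-contained sketch. This is not the Teleport code itself; downloadGroup, createOrGet, and download below are hypothetical names used only to illustrate the create-or-join idea: the first caller for a given key registers a cancellable context and performs the download, while later callers receive that context with a nil cancel function and simply wait for it to be cancelled.

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// downloadGroup tracks in-flight downloads keyed by path (illustrative only).
type downloadGroup struct {
	mu     sync.Mutex
	active map[string]context.Context
}

// createOrGet registers a new download for key and returns a cancel function
// the owner must call when done, or returns the existing context with a nil
// cancel function so the caller can wait for the in-flight download.
func (g *downloadGroup) createOrGet(key string) (context.Context, context.CancelFunc) {
	g.mu.Lock()
	defer g.mu.Unlock()
	if ctx, ok := g.active[key]; ok {
		return ctx, nil
	}
	ctx, cancel := context.WithCancel(context.Background())
	g.active[key] = ctx
	return ctx, func() {
		cancel()
		g.mu.Lock()
		defer g.mu.Unlock()
		delete(g.active, key)
	}
}

func (g *downloadGroup) download(key string) error {
	ctx, cancel := g.createOrGet(key)
	if cancel == nil {
		// another goroutine already owns the download; wait for it to finish
		<-ctx.Done()
		return nil
	}
	defer cancel()
	time.Sleep(10 * time.Millisecond) // stand-in for the real fetch and unpack
	fmt.Println("downloaded", key)
	return nil
}

func main() {
	g := &downloadGroup{active: make(map[string]context.Context)}
	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			_ = g.download("session-1.tar") // only the first goroutine performs the download
		}()
	}
	wg.Wait()
}

The actual change below additionally waits on the audit log's own closing context, so a server shutdown aborts any goroutine that is blocked waiting for another download to complete.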
82 changes: 77 additions & 5 deletions lib/events/auditlog.go
@@ -95,6 +95,16 @@ type AuditLog struct {
// fileTime is a rounded (to a day, by default) timestamp of the
// currently opened file
fileTime time.Time

// activeDownloads helps to serialize simultaneous downloads
// from the session record server
activeDownloads map[string]context.Context

// ctx signals close of the audit log
ctx context.Context

// cancel triggers closing of the signal context
cancel context.CancelFunc
}

// AuditLogConfig specifies configuration for AuditLog server
@@ -190,12 +200,16 @@ func NewAuditLog(cfg AuditLogConfig) (*AuditLog, error) {
if err := cfg.CheckAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}
ctx, cancel := context.WithCancel(context.TODO())
al := &AuditLog{
playbackDir: filepath.Join(cfg.DataDir, PlaybackDir, SessionLogsDir, defaults.Namespace),
AuditLogConfig: cfg,
Entry: log.WithFields(log.Fields{
trace.Component: teleport.ComponentAuditLog,
}),
activeDownloads: make(map[string]context.Context),
ctx: ctx,
cancel: cancel,
}
loggers, err := ttlmap.New(defaults.AuditLogSessions,
ttlmap.CallOnExpire(al.closeSessionLogger), ttlmap.Clock(cfg.Clock))
@@ -340,7 +354,7 @@ func (l *AuditLog) processSlice(sl SessionLogger, slice *SessionSlice) error {
defer func() {
l.removeLogger(slice.SessionID)
if err := sl.Finalize(); err != nil {
log.Warningf("Failed to finalize logger: %v", trace.DebugReport(err))
log.Warningf("Failed to finalize logger: %v.", trace.DebugReport(err))
}
}()
}
@@ -443,6 +457,15 @@ func (idx *sessionIndex) eventsFile(afterN int) (int, error) {
return -1, trace.NotFound("%v not found", afterN)
}

// chunkFileNames returns file names of all session chunk files
func (idx *sessionIndex) chunkFileNames() []string {
fileNames := make([]string, len(idx.chunks))
for i := 0; i < len(idx.chunks); i++ {
fileNames[i] = idx.chunksFileName(i)
}
return fileNames
}

func (idx *sessionIndex) chunksFile(offset int64) (string, int64, error) {
for i := len(idx.chunks) - 1; i >= 0; i-- {
entry := idx.chunks[i]
@@ -521,9 +544,44 @@ func readIndexEntries(file *os.File, authServer string) (events []indexEntry, ch
return
}

// createOrGetDownload creates a new download sync entry for a given session
// if there is no active download in progress, or returns the existing one.
// If a new context has been created, a cancel function is returned as the
// second argument. The caller should call this function to signal that the
// download has completed or failed.
func (l *AuditLog) createOrGetDownload(path string) (context.Context, context.CancelFunc) {
l.Lock()
defer l.Unlock()
ctx, ok := l.activeDownloads[path]
if ok {
return ctx, nil
}
ctx, cancel := context.WithCancel(context.TODO())
l.activeDownloads[path] = ctx
return ctx, func() {
cancel()
l.Lock()
defer l.Unlock()
delete(l.activeDownloads, path)
}
}

func (l *AuditLog) downloadSession(namespace string, sid session.ID) error {
tarballPath := filepath.Join(l.playbackDir, string(sid)+".tar")

ctx, cancel := l.createOrGetDownload(tarballPath)
// a nil cancel function means that another download is in progress, so simply
// wait until it finishes
if cancel == nil {
l.Debugf("Another download is in progress for %v, waiting until it gets completed.", sid)
select {
case <-ctx.Done():
return nil
case <-l.ctx.Done():
return trace.BadParameter("audit log is closing, aborting the download")
}
}
defer cancel()
_, err := os.Stat(tarballPath)
err = trace.ConvertSystemError(err)
if err == nil {
@@ -540,7 +598,7 @@ func (l *AuditLog) downloadSession(namespace string, sid session.ID) error {
return trace.ConvertSystemError(err)
}
defer tarball.Close()
if err := l.UploadHandler.Download(context.TODO(), sid, tarball); err != nil {
if err := l.UploadHandler.Download(l.ctx, sid, tarball); err != nil {
// remove partially downloaded tarball
os.Remove(tarball.Name())
return trace.Wrap(err)
@@ -552,12 +610,25 @@ func (l *AuditLog) downloadSession(namespace string, sid session.ID) error {
if err != nil {
return trace.ConvertSystemError(err)
}

if err := utils.Extract(tarball, l.playbackDir); err != nil {
return trace.Wrap(err)
}
// Unpack every chunks file to disk while still holding the download context,
// otherwise parallel downloads will try to unpack the same file at the same time.
idx, err := l.readSessionIndex(namespace, sid)
if err != nil {
return trace.Wrap(err)
}
for _, fileName := range idx.chunkFileNames() {
reader, err := l.unpackFile(fileName)
if err != nil {
return trace.Wrap(err)
}
if err := reader.Close(); err != nil {
l.Warningf("Failed to close file: %v.", err)
}
}
l.WithFields(log.Fields{"duration": time.Now().Sub(start)}).Debugf("Unpacked %v to %v.", tarballPath, l.playbackDir)

return nil
}

@@ -644,7 +715,7 @@ func (l *AuditLog) unpackFile(fileName string) (readSeekCloser, error) {
return nil, trace.ConvertSystemError(err)
}
// no new data has been added
if unpackedInfo.ModTime().After(packedInfo.ModTime()) {
if unpackedInfo.ModTime().Unix() >= packedInfo.ModTime().Unix() {
return os.OpenFile(unpackedFile, os.O_RDONLY, 0640)
}
}
@@ -1341,6 +1412,7 @@ func (l *AuditLog) Close() error {
log.Warningf("Close failure: %v", err)
}
}
l.cancel()
l.Lock()
defer l.Unlock()

67 changes: 52 additions & 15 deletions lib/events/auditlog_test.go
@@ -1233,24 +1233,61 @@ func (a *AuditTestSuite) TestForwardAndUpload(c *check.C) {
c.Fatalf("Timeout waiting for the upload event")
}

// read the session bytes
history, err := alog.GetSessionEvents(defaults.Namespace, session.ID(sessionID), 0, true)
c.Assert(err, check.IsNil)
c.Assert(history, check.HasLen, 3)
compare := func() error {
history, err := alog.GetSessionEvents(defaults.Namespace, session.ID(sessionID), 0, true)
if err != nil {
return trace.Wrap(err)
}
if len(history) != 3 {
return trace.BadParameter("expected history of 3, got %v", len(history))
}

// make sure offsets were properly set (0 for the first event and 5 bytes for hello):
c.Assert(history[1][SessionByteOffset], check.Equals, float64(0))
c.Assert(history[1][SessionEventTimestamp], check.Equals, float64(0))
// make sure offsets were properly set (0 for the first event and 5 bytes for hello):
if history[1][SessionByteOffset].(float64) != float64(0) {
return trace.BadParameter("expected offset of 0, got %v", history[1][SessionByteOffset])
}
if history[1][SessionEventTimestamp].(float64) != float64(0) {
return trace.BadParameter("expected timestamp of 0, got %v", history[1][SessionEventTimestamp])
}

// fetch all bytes
buff, err := alog.GetSessionChunk(defaults.Namespace, session.ID(sessionID), 0, 5000)
c.Assert(err, check.IsNil)
c.Assert(string(buff), check.Equals, string(firstMessage))
// fetch all bytes
buff, err := alog.GetSessionChunk(defaults.Namespace, session.ID(sessionID), 0, 5000)
if err != nil {
return trace.Wrap(err)
}
if string(buff) != string(firstMessage) {
return trace.CompareFailed("%q != %q", string(buff), string(firstMessage))
}

// with offset
buff, err = alog.GetSessionChunk(defaults.Namespace, session.ID(sessionID), 2, 5000)
c.Assert(err, check.IsNil)
c.Assert(string(buff), check.Equals, string(firstMessage[2:]))
// with offset
buff, err = alog.GetSessionChunk(defaults.Namespace, session.ID(sessionID), 2, 5000)
if err != nil {
return trace.Wrap(err)
}
if string(buff) != string(firstMessage[2:]) {
return trace.CompareFailed("%q != %q", string(buff), string(firstMessage[2:]))
}
return nil
}

// trigger several parallel downloads; they should not fail
iterations := 50
resultsC := make(chan error, iterations)
for i := 0; i < iterations; i++ {
go func() {
resultsC <- compare()
}()
}

timeout := time.After(time.Second)
for i := 0; i < iterations; i++ {
select {
case err := <-resultsC:
c.Assert(err, check.IsNil)
case <-timeout:
c.Fatalf("timeout waiting for goroutines to finish")
}
}
}

func marshal(f EventFields) []byte {