diff --git a/lib/srv/desktop/audit.go b/lib/srv/desktop/audit.go index 688c9c6126544..7845dcccbccbe 100644 --- a/lib/srv/desktop/audit.go +++ b/lib/srv/desktop/audit.go @@ -21,6 +21,7 @@ import ( "time" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/events" @@ -29,316 +30,312 @@ import ( "github.com/gravitational/teleport/lib/tlsca" ) -func (s *WindowsService) onSessionStart(ctx context.Context, emitter events.Emitter, id *tlsca.Identity, startTime time.Time, windowsUser, sessionID string, desktop types.WindowsDesktop, err error) { - userMetadata := id.GetUserMetadata() - userMetadata.Login = windowsUser +// desktopSessionAuditor is used to build session-related events +// which are emitted to Teleport's audit log +type desktopSessionAuditor struct { + clock clockwork.Clock + + sessionID string + identity *tlsca.Identity + windowsUser string + desktop types.WindowsDesktop + + startTime time.Time + clusterName string + desktopServiceUUID string + + auditCache sharedDirectoryAuditCache +} + +func (s *WindowsService) newSessionAuditor( + sessionID string, + identity *tlsca.Identity, + windowsUser string, + desktop types.WindowsDesktop) *desktopSessionAuditor { + return &desktopSessionAuditor{ + clock: s.cfg.Clock, + + sessionID: sessionID, + identity: identity, + windowsUser: windowsUser, + desktop: desktop, + + startTime: s.cfg.Clock.Now().UTC().Round(time.Millisecond), + clusterName: s.clusterName, + desktopServiceUUID: s.cfg.Heartbeat.HostUUID, + + auditCache: newSharedDirectoryAuditCache(), + } +} + +func (d *desktopSessionAuditor) makeSessionStart(err error) *events.WindowsDesktopSessionStart { + userMetadata := d.identity.GetUserMetadata() + userMetadata.Login = d.windowsUser event := &events.WindowsDesktopSessionStart{ Metadata: events.Metadata{ Type: libevents.WindowsDesktopSessionStartEvent, Code: libevents.DesktopSessionStartCode, - ClusterName: s.clusterName, - Time: startTime, + ClusterName: d.clusterName, + Time: d.startTime, }, UserMetadata: userMetadata, SessionMetadata: events.SessionMetadata{ - SessionID: sessionID, - WithMFA: id.MFAVerified, + SessionID: d.sessionID, + WithMFA: d.identity.MFAVerified, }, ConnectionMetadata: events.ConnectionMetadata{ - LocalAddr: id.LoginIP, - RemoteAddr: desktop.GetAddr(), + LocalAddr: d.identity.LoginIP, + RemoteAddr: d.desktop.GetAddr(), Protocol: libevents.EventProtocolTDP, }, - Status: events.Status{ - Success: err == nil, - }, - WindowsDesktopService: s.cfg.Heartbeat.HostUUID, - DesktopName: desktop.GetName(), - DesktopAddr: desktop.GetAddr(), - Domain: desktop.GetDomain(), - WindowsUser: windowsUser, - DesktopLabels: desktop.GetAllLabels(), + Status: events.Status{Success: err == nil}, + WindowsDesktopService: d.desktopServiceUUID, + DesktopName: d.desktop.GetName(), + DesktopAddr: d.desktop.GetAddr(), + Domain: d.desktop.GetDomain(), + WindowsUser: d.windowsUser, + DesktopLabels: d.desktop.GetAllLabels(), } + if err != nil { event.Code = libevents.DesktopSessionStartFailureCode event.Error = trace.Unwrap(err).Error() event.UserMessage = err.Error() } - s.emit(ctx, emitter, event) -} -func (s *WindowsService) onSessionEnd(ctx context.Context, emitter events.Emitter, id *tlsca.Identity, startedAt time.Time, recorded bool, windowsUser, sid string, desktop types.WindowsDesktop) { - // Ensure audit cache gets cleaned up - s.auditCache.Delete(sessionID(sid)) + return event +} - userMetadata := id.GetUserMetadata() - userMetadata.Login = windowsUser +func (d *desktopSessionAuditor) makeSessionEnd(recorded bool) *events.WindowsDesktopSessionEnd { + userMetadata := d.identity.GetUserMetadata() + userMetadata.Login = d.windowsUser - event := &events.WindowsDesktopSessionEnd{ + return &events.WindowsDesktopSessionEnd{ Metadata: events.Metadata{ Type: libevents.WindowsDesktopSessionEndEvent, Code: libevents.DesktopSessionEndCode, - ClusterName: s.clusterName, + ClusterName: d.clusterName, }, UserMetadata: userMetadata, SessionMetadata: events.SessionMetadata{ - SessionID: sid, - WithMFA: id.MFAVerified, - }, - WindowsDesktopService: s.cfg.Heartbeat.HostUUID, - DesktopAddr: desktop.GetAddr(), - Domain: desktop.GetDomain(), - WindowsUser: windowsUser, - DesktopLabels: desktop.GetAllLabels(), - StartTime: startedAt, - EndTime: s.cfg.Clock.Now().UTC().Round(time.Millisecond), - DesktopName: desktop.GetName(), + SessionID: d.sessionID, + WithMFA: d.identity.MFAVerified, + }, + WindowsDesktopService: d.desktopServiceUUID, + DesktopAddr: d.desktop.GetAddr(), + Domain: d.desktop.GetDomain(), + WindowsUser: d.windowsUser, + DesktopLabels: d.desktop.GetAllLabels(), + StartTime: d.startTime, + EndTime: d.clock.Now().UTC(), + DesktopName: d.desktop.GetName(), Recorded: recorded, // There can only be 1 participant, desktop sessions are not join-able. Participants: []string{userMetadata.User}, } - s.emit(ctx, emitter, event) } -func (s *WindowsService) onClipboardSend(ctx context.Context, emitter events.Emitter, id *tlsca.Identity, sessionID string, desktopAddr string, length int32) { - event := &events.DesktopClipboardSend{ +func (d *desktopSessionAuditor) makeClipboardSend(length int32) *events.DesktopClipboardSend { + return &events.DesktopClipboardSend{ Metadata: events.Metadata{ Type: libevents.DesktopClipboardSendEvent, Code: libevents.DesktopClipboardSendCode, - ClusterName: s.clusterName, - Time: s.cfg.Clock.Now().UTC(), + ClusterName: d.clusterName, + Time: d.clock.Now().UTC(), }, - UserMetadata: id.GetUserMetadata(), + UserMetadata: d.identity.GetUserMetadata(), SessionMetadata: events.SessionMetadata{ - SessionID: sessionID, - WithMFA: id.MFAVerified, + SessionID: d.sessionID, + WithMFA: d.identity.MFAVerified, }, ConnectionMetadata: events.ConnectionMetadata{ - LocalAddr: id.LoginIP, - RemoteAddr: desktopAddr, + LocalAddr: d.identity.LoginIP, + RemoteAddr: d.desktop.GetAddr(), Protocol: libevents.EventProtocolTDP, }, - DesktopAddr: desktopAddr, + DesktopAddr: d.desktop.GetAddr(), Length: length, } - s.emit(ctx, emitter, event) } -func (s *WindowsService) onClipboardReceive(ctx context.Context, emitter events.Emitter, id *tlsca.Identity, sessionID string, desktopAddr string, length int32) { - event := &events.DesktopClipboardReceive{ +func (d *desktopSessionAuditor) makeClipboardReceive(length int32) *events.DesktopClipboardReceive { + return &events.DesktopClipboardReceive{ Metadata: events.Metadata{ Type: libevents.DesktopClipboardReceiveEvent, Code: libevents.DesktopClipboardReceiveCode, - ClusterName: s.clusterName, - Time: s.cfg.Clock.Now().UTC(), + ClusterName: d.clusterName, + Time: d.clock.Now().UTC(), }, - UserMetadata: id.GetUserMetadata(), + UserMetadata: d.identity.GetUserMetadata(), SessionMetadata: events.SessionMetadata{ - SessionID: sessionID, - WithMFA: id.MFAVerified, + SessionID: d.sessionID, + WithMFA: d.identity.MFAVerified, }, ConnectionMetadata: events.ConnectionMetadata{ - LocalAddr: id.LoginIP, - RemoteAddr: desktopAddr, + LocalAddr: d.identity.LoginIP, + RemoteAddr: d.desktop.GetAddr(), Protocol: libevents.EventProtocolTDP, }, - DesktopAddr: desktopAddr, + DesktopAddr: d.desktop.GetAddr(), Length: length, } - s.emit(ctx, emitter, event) } -// onSharedDirectoryAnnounce adds the shared directory's name to the auditCache. -func (s *WindowsService) onSharedDirectoryAnnounce( - ctx context.Context, - emitter events.Emitter, - id *tlsca.Identity, - sid string, - desktopAddr string, - m tdp.SharedDirectoryAnnounce, - tdpConn *tdp.Conn, -) { - if err := s.auditCache.SetName(sessionID(sid), directoryID(m.DirectoryID), directoryName(m.Name)); err != nil { - // An error means the audit cache entry for this sid exceeded its maximum allowable size. - errMsg := err.Error() - - // Close the connection as a security precaution. - if err := tdpConn.Close(); err != nil { - s.cfg.Log.WithError(err).Errorf("error when terminating sessionID(%v) for audit cache maximum size violation", sid) - } +// onSharedDirectoryAnnounce handles a shared directory announcement. +// In the happy path, no event is emitted here, but details from the announcement +// are cached for future audit events. An event is returned only if there was +// an error. +func (d *desktopSessionAuditor) onSharedDirectoryAnnounce(m tdp.SharedDirectoryAnnounce) *events.DesktopSharedDirectoryStart { + err := d.auditCache.SetName(directoryID(m.DirectoryID), directoryName(m.Name)) + if err == nil { + // no work to do yet, but data is cached for future events + return nil + } - event := &events.DesktopSharedDirectoryStart{ - Metadata: events.Metadata{ - Type: libevents.DesktopSharedDirectoryStartEvent, - Code: libevents.DesktopSharedDirectoryStartFailureCode, - ClusterName: s.clusterName, - Time: s.cfg.Clock.Now().UTC(), - }, - UserMetadata: id.GetUserMetadata(), - SessionMetadata: events.SessionMetadata{ - SessionID: sid, - WithMFA: id.MFAVerified, - }, - ConnectionMetadata: events.ConnectionMetadata{ - LocalAddr: id.LoginIP, - RemoteAddr: desktopAddr, - Protocol: libevents.EventProtocolTDP, - }, - Status: events.Status{ - Success: false, - Error: errMsg, - UserMessage: "Teleport failed the request and terminated the session as a security precaution", - }, - DesktopAddr: desktopAddr, - DirectoryName: m.Name, - DirectoryID: m.DirectoryID, - } + // An error means the audit cache exceeded its maximum allowable size. + errMsg := err.Error() - s.emit(ctx, emitter, event) + return &events.DesktopSharedDirectoryStart{ + Metadata: events.Metadata{ + Type: libevents.DesktopSharedDirectoryStartEvent, + Code: libevents.DesktopSharedDirectoryStartFailureCode, + ClusterName: d.clusterName, + Time: d.clock.Now().UTC(), + }, + UserMetadata: d.identity.GetUserMetadata(), + SessionMetadata: events.SessionMetadata{ + SessionID: d.sessionID, + WithMFA: d.identity.MFAVerified, + }, + ConnectionMetadata: events.ConnectionMetadata{ + LocalAddr: d.identity.LoginIP, + RemoteAddr: d.desktop.GetAddr(), + Protocol: libevents.EventProtocolTDP, + }, + Status: events.Status{ + Success: false, + Error: errMsg, + UserMessage: "Teleport failed the request and terminated the session as a security precaution", + }, + DesktopAddr: d.desktop.GetAddr(), + DirectoryName: m.Name, + DirectoryID: m.DirectoryID, } } -// onSharedDirectoryAcknowledge emits a DesktopSharedDirectoryStart on a successful receipt of a -// successful tdp.SharedDirectoryAcknowledge. -func (s *WindowsService) onSharedDirectoryAcknowledge( - ctx context.Context, - emitter events.Emitter, - id *tlsca.Identity, - sid string, - desktopAddr string, - m tdp.SharedDirectoryAcknowledge, -) { +// makeSharedDirectoryStart creates a DesktopSharedDirectoryStart event. +func (d *desktopSessionAuditor) makeSharedDirectoryStart(m tdp.SharedDirectoryAcknowledge) *events.DesktopSharedDirectoryStart { code := libevents.DesktopSharedDirectoryStartCode - name, ok := s.auditCache.GetName(sessionID(sid), directoryID(m.DirectoryID)) + name, ok := d.auditCache.GetName(directoryID(m.DirectoryID)) if !ok { code = libevents.DesktopSharedDirectoryStartFailureCode name = "unknown" - s.cfg.Log.Warnf("failed to find a directory name corresponding to sessionID(%v), directoryID(%v)", sid, m.DirectoryID) } if m.ErrCode != tdp.ErrCodeNil { code = libevents.DesktopSharedDirectoryStartFailureCode } - event := &events.DesktopSharedDirectoryStart{ + return &events.DesktopSharedDirectoryStart{ Metadata: events.Metadata{ Type: libevents.DesktopSharedDirectoryStartEvent, Code: code, - ClusterName: s.clusterName, - Time: s.cfg.Clock.Now().UTC(), + ClusterName: d.clusterName, + Time: d.clock.Now().UTC(), }, - UserMetadata: id.GetUserMetadata(), + UserMetadata: d.identity.GetUserMetadata(), SessionMetadata: events.SessionMetadata{ - SessionID: sid, - WithMFA: id.MFAVerified, + SessionID: d.sessionID, + WithMFA: d.identity.MFAVerified, }, ConnectionMetadata: events.ConnectionMetadata{ - LocalAddr: id.LoginIP, - RemoteAddr: desktopAddr, + LocalAddr: d.identity.LoginIP, + RemoteAddr: d.desktop.GetAddr(), Protocol: libevents.EventProtocolTDP, }, Status: statusFromErrCode(m.ErrCode), - DesktopAddr: desktopAddr, + DesktopAddr: d.desktop.GetAddr(), DirectoryName: string(name), DirectoryID: m.DirectoryID, } - - s.emit(ctx, emitter, event) } -// onSharedDirectoryReadRequest adds ReadRequestInfo to the auditCache. -func (s *WindowsService) onSharedDirectoryReadRequest( - ctx context.Context, - emitter events.Emitter, - id *tlsca.Identity, - sid string, - desktopAddr string, - m tdp.SharedDirectoryReadRequest, - tdpConn *tdp.Conn, -) { +// onSharedDirectoryReadRequest handles shared directory reads. +// In the happy path, no event is emitted here, but details from the operation +// are cached for future audit events. An event is returned only if there was +// an error. +func (d *desktopSessionAuditor) onSharedDirectoryReadRequest(m tdp.SharedDirectoryReadRequest) *events.DesktopSharedDirectoryRead { did := directoryID(m.DirectoryID) path := m.Path offset := m.Offset - if err := s.auditCache.SetReadRequestInfo(sessionID(sid), completionID(m.CompletionID), readRequestInfo{ + err := d.auditCache.SetReadRequestInfo(completionID(m.CompletionID), readRequestInfo{ directoryID: did, path: path, offset: offset, - }); err != nil { - // An error means the audit cache entry for this sid exceeded its maximum allowable size. - errMsg := err.Error() - - // Close the connection as a security precaution. - if err := tdpConn.Close(); err != nil { - s.cfg.Log.WithError(err).Errorf("error when terminating sessionID(%v) for audit cache maximum size violation", sid) - } - - name, ok := s.auditCache.GetName(sessionID(sid), did) - if !ok { - name = "unknown" - s.cfg.Log.Warnf("failed to find a directory name corresponding to sessionID(%v), directoryID(%v)", sid, did) - } + }) + if err == nil { + // no work to do yet, but data is cached for future events + return nil + } - event := &events.DesktopSharedDirectoryRead{ - Metadata: events.Metadata{ - Type: libevents.DesktopSharedDirectoryReadEvent, - Code: libevents.DesktopSharedDirectoryReadFailureCode, - ClusterName: s.clusterName, - Time: s.cfg.Clock.Now().UTC(), - }, - UserMetadata: id.GetUserMetadata(), - SessionMetadata: events.SessionMetadata{ - SessionID: sid, - WithMFA: id.MFAVerified, - }, - ConnectionMetadata: events.ConnectionMetadata{ - LocalAddr: id.LoginIP, - RemoteAddr: desktopAddr, - Protocol: libevents.EventProtocolTDP, - }, - Status: events.Status{ - Success: false, - Error: errMsg, - UserMessage: "Teleport failed the request and terminated the session as a security precaution", - }, - DesktopAddr: desktopAddr, - DirectoryName: string(name), - DirectoryID: uint32(did), - Path: path, - Length: m.Length, - Offset: offset, - } + name, ok := d.auditCache.GetName(did) + if !ok { + name = "unknown" + } - s.emit(ctx, emitter, event) + return &events.DesktopSharedDirectoryRead{ + Metadata: events.Metadata{ + Type: libevents.DesktopSharedDirectoryReadEvent, + Code: libevents.DesktopSharedDirectoryReadFailureCode, + ClusterName: d.clusterName, + Time: d.clock.Now().UTC(), + }, + UserMetadata: d.identity.GetUserMetadata(), + SessionMetadata: events.SessionMetadata{ + SessionID: d.sessionID, + WithMFA: d.identity.MFAVerified, + }, + ConnectionMetadata: events.ConnectionMetadata{ + LocalAddr: d.identity.LoginIP, + RemoteAddr: d.desktop.GetAddr(), + Protocol: libevents.EventProtocolTDP, + }, + Status: events.Status{ + Success: false, + Error: err.Error(), + UserMessage: "Teleport failed the request and terminated the session as a security precaution", + }, + DesktopAddr: d.desktop.GetAddr(), + DirectoryName: string(name), + DirectoryID: uint32(did), + Path: path, + Length: m.Length, + Offset: offset, } } -// onSharedDirectoryReadResponse emits a DesktopSharedDirectoryRead audit event. -func (s *WindowsService) onSharedDirectoryReadResponse( - ctx context.Context, - emitter events.Emitter, - id *tlsca.Identity, - sid string, - desktopAddr string, - m tdp.SharedDirectoryReadResponse, -) { +// makeSharedDirectoryReadResponse creates a DesktopSharedDirectoryRead audit event. +func (d *desktopSessionAuditor) makeSharedDirectoryReadResponse(m tdp.SharedDirectoryReadResponse) *events.DesktopSharedDirectoryRead { var did directoryID + var name directoryName + var path string var offset uint64 - var name directoryName + code := libevents.DesktopSharedDirectoryReadCode + // Gather info from the audit cache - info, ok := s.auditCache.TakeReadRequestInfo(sessionID(sid), completionID(m.CompletionID)) + info, ok := d.auditCache.TakeReadRequestInfo(completionID(m.CompletionID)) if ok { did = info.directoryID - // Only search for the directory name if we retrieved the directoryID from the audit cache. - name, ok = s.auditCache.GetName(sessionID(sid), did) + // Only search for the directory name if we retrieved the directory ID from the audit cache. + name, ok = d.auditCache.GetName(did) if !ok { code = libevents.DesktopSharedDirectoryReadFailureCode name = "unknown" - s.cfg.Log.Warnf("failed to find a directory name corresponding to sessionID(%v), directoryID(%v)", sid, did) } path = info.path offset = info.offset @@ -346,133 +343,114 @@ func (s *WindowsService) onSharedDirectoryReadResponse( code = libevents.DesktopSharedDirectoryReadFailureCode path = "unknown" name = "unknown" - s.cfg.Log.Warnf("failed to find audit information corresponding to sessionID(%v), completionID(%v)", sid, m.CompletionID) } if m.ErrCode != tdp.ErrCodeNil { code = libevents.DesktopSharedDirectoryWriteFailureCode } - event := &events.DesktopSharedDirectoryRead{ + return &events.DesktopSharedDirectoryRead{ Metadata: events.Metadata{ Type: libevents.DesktopSharedDirectoryReadEvent, Code: code, - ClusterName: s.clusterName, - Time: s.cfg.Clock.Now().UTC(), + ClusterName: d.clusterName, + Time: d.clock.Now().UTC(), }, - UserMetadata: id.GetUserMetadata(), + UserMetadata: d.identity.GetUserMetadata(), SessionMetadata: events.SessionMetadata{ - SessionID: sid, - WithMFA: id.MFAVerified, + SessionID: d.sessionID, + WithMFA: d.identity.MFAVerified, }, ConnectionMetadata: events.ConnectionMetadata{ - LocalAddr: id.LoginIP, - RemoteAddr: desktopAddr, + LocalAddr: d.identity.LoginIP, + RemoteAddr: d.desktop.GetAddr(), Protocol: libevents.EventProtocolTDP, }, Status: statusFromErrCode(m.ErrCode), - DesktopAddr: desktopAddr, + DesktopAddr: d.desktop.GetAddr(), DirectoryName: string(name), DirectoryID: uint32(did), Path: path, Length: m.ReadDataLength, Offset: offset, } - - s.emit(ctx, emitter, event) } -// onSharedDirectoryWriteRequest adds WriteRequestInfo to the auditCache. -func (s *WindowsService) onSharedDirectoryWriteRequest( - ctx context.Context, - emitter events.Emitter, - id *tlsca.Identity, - sid string, - desktopAddr string, - m tdp.SharedDirectoryWriteRequest, - tdpConn *tdp.Conn, -) { +// onSharedDirectoryWriteRequest handles shared directory writes. +// In the happy path, no event is emitted here, but details from the operation +// are cached for future audit events. An event is returned only if there was +// an error. +func (d *desktopSessionAuditor) onSharedDirectoryWriteRequest(m tdp.SharedDirectoryWriteRequest) *events.DesktopSharedDirectoryWrite { did := directoryID(m.DirectoryID) path := m.Path offset := m.Offset - if err := s.auditCache.SetWriteRequestInfo(sessionID(sid), completionID(m.CompletionID), writeRequestInfo{ - directoryID: did, - path: path, - offset: offset, - }); err != nil { - // An error means the audit cache entry for this sid exceeded its maximum allowable size. - errMsg := err.Error() - - // Close the connection as a security precaution. - if err := tdpConn.Close(); err != nil { - s.cfg.Log.WithError(err).Errorf("error when terminating sessionID(%v) for audit cache maximum size violation", sid) - } - - name, ok := s.auditCache.GetName(sessionID(sid), did) - if !ok { - name = "unknown" - s.cfg.Log.Warnf("failed to find a directory name corresponding to sessionID(%v), directoryID(%v)", sid, did) - } + err := d.auditCache.SetWriteRequestInfo( + completionID(m.CompletionID), + writeRequestInfo{ + directoryID: did, + path: path, + offset: offset, + }) + if err == nil { + // no work to do yet, but data is cached for future events + return nil + } - event := &events.DesktopSharedDirectoryWrite{ - Metadata: events.Metadata{ - Type: libevents.DesktopSharedDirectoryWriteEvent, - Code: libevents.DesktopSharedDirectoryWriteFailureCode, - ClusterName: s.clusterName, - Time: s.cfg.Clock.Now().UTC(), - }, - UserMetadata: id.GetUserMetadata(), - SessionMetadata: events.SessionMetadata{ - SessionID: sid, - WithMFA: id.MFAVerified, - }, - ConnectionMetadata: events.ConnectionMetadata{ - LocalAddr: id.LoginIP, - RemoteAddr: desktopAddr, - Protocol: libevents.EventProtocolTDP, - }, - Status: events.Status{ - Success: false, - Error: errMsg, - UserMessage: "Teleport failed the request and terminated the session as a security precaution", - }, - DesktopAddr: desktopAddr, - DirectoryName: string(name), - DirectoryID: uint32(did), - Path: path, - Length: m.WriteDataLength, - Offset: offset, - } + name, ok := d.auditCache.GetName(did) + if !ok { + name = "unknown" + } - s.emit(ctx, emitter, event) + return &events.DesktopSharedDirectoryWrite{ + Metadata: events.Metadata{ + Type: libevents.DesktopSharedDirectoryWriteEvent, + Code: libevents.DesktopSharedDirectoryWriteFailureCode, + ClusterName: d.clusterName, + Time: d.clock.Now().UTC(), + }, + UserMetadata: d.identity.GetUserMetadata(), + SessionMetadata: events.SessionMetadata{ + SessionID: d.sessionID, + WithMFA: d.identity.MFAVerified, + }, + ConnectionMetadata: events.ConnectionMetadata{ + LocalAddr: d.identity.LoginIP, + RemoteAddr: d.desktop.GetAddr(), + Protocol: libevents.EventProtocolTDP, + }, + Status: events.Status{ + Success: false, + Error: err.Error(), + UserMessage: "Teleport failed the request and terminated the session as a security precaution", + }, + DesktopAddr: d.desktop.GetAddr(), + DirectoryName: string(name), + DirectoryID: uint32(did), + Path: path, + Length: m.WriteDataLength, + Offset: offset, } } -// onSharedDirectoryWriteResponse emits a DesktopSharedDirectoryWrite audit event. -func (s *WindowsService) onSharedDirectoryWriteResponse( - ctx context.Context, - emitter events.Emitter, - id *tlsca.Identity, - sid string, - desktopAddr string, - m tdp.SharedDirectoryWriteResponse, -) { +// makeSharedDirectoryWriteResponse creates a DesktopSharedDirectoryWrite audit event. +func (d *desktopSessionAuditor) makeSharedDirectoryWriteResponse(m tdp.SharedDirectoryWriteResponse) *events.DesktopSharedDirectoryWrite { var did directoryID + var name directoryName + var path string var offset uint64 - var name directoryName + code := libevents.DesktopSharedDirectoryWriteCode // Gather info from the audit cache - info, ok := s.auditCache.TakeWriteRequestInfo(sessionID(sid), completionID(m.CompletionID)) + info, ok := d.auditCache.TakeWriteRequestInfo(completionID(m.CompletionID)) if ok { did = info.directoryID // Only search for the directory name if we retrieved the directoryID from the audit cache. - name, ok = s.auditCache.GetName(sessionID(sid), did) + name, ok = d.auditCache.GetName(did) if !ok { code = libevents.DesktopSharedDirectoryWriteFailureCode name = "unknown" - s.cfg.Log.Warnf("failed to find a directory name corresponding to sessionID(%v), directoryID(%v)", sid, did) } path = info.path offset = info.offset @@ -480,44 +458,41 @@ func (s *WindowsService) onSharedDirectoryWriteResponse( code = libevents.DesktopSharedDirectoryWriteFailureCode path = "unknown" name = "unknown" - s.cfg.Log.Warnf("failed to find audit information corresponding to sessionID(%v), completionID(%v)", sid, m.CompletionID) } if m.ErrCode != tdp.ErrCodeNil { code = libevents.DesktopSharedDirectoryWriteFailureCode } - event := &events.DesktopSharedDirectoryWrite{ + return &events.DesktopSharedDirectoryWrite{ Metadata: events.Metadata{ Type: libevents.DesktopSharedDirectoryWriteEvent, Code: code, - ClusterName: s.clusterName, - Time: s.cfg.Clock.Now().UTC(), + ClusterName: d.clusterName, + Time: d.clock.Now().UTC(), }, - UserMetadata: id.GetUserMetadata(), + UserMetadata: d.identity.GetUserMetadata(), SessionMetadata: events.SessionMetadata{ - SessionID: sid, - WithMFA: id.MFAVerified, + SessionID: d.sessionID, + WithMFA: d.identity.MFAVerified, }, ConnectionMetadata: events.ConnectionMetadata{ - LocalAddr: id.LoginIP, - RemoteAddr: desktopAddr, + LocalAddr: d.identity.LoginIP, + RemoteAddr: d.desktop.GetAddr(), Protocol: libevents.EventProtocolTDP, }, Status: statusFromErrCode(m.ErrCode), - DesktopAddr: desktopAddr, + DesktopAddr: d.desktop.GetAddr(), DirectoryName: string(name), DirectoryID: uint32(did), Path: path, Length: m.BytesWritten, Offset: offset, } - - s.emit(ctx, emitter, event) } func (s *WindowsService) emit(ctx context.Context, emitter events.Emitter, event events.AuditEvent) { - if err := emitter.EmitAuditEvent(ctx, event); err != nil { + if err := s.cfg.Emitter.EmitAuditEvent(ctx, event); err != nil { s.cfg.Log.WithError(err).Errorf("Failed to emit audit event %v", event) } } @@ -547,7 +522,6 @@ func statusFromErrCode(errCode uint32) events.Status { Error: msg, UserMessage: msg, } - } const ( diff --git a/lib/srv/desktop/audit_cache.go b/lib/srv/desktop/audit_cache.go index f65df00b89aa5..60e6b5b400687 100644 --- a/lib/srv/desktop/audit_cache.go +++ b/lib/srv/desktop/audit_cache.go @@ -22,7 +22,6 @@ import ( "github.com/gravitational/trace" ) -type sessionID string type directoryID uint32 type completionID uint32 type directoryName string @@ -35,169 +34,110 @@ type readRequestInfo struct { type writeRequestInfo readRequestInfo -const ( - // entryMaxItems is the maximum number of items we want - // to allow in a single sharedDirectoryAuditCacheEntry. - // - // It's not a precise value, just one that should give us - // prevent the cache from growing too large due to a - // misbehaving client. - entryMaxItems = 2000 -) - -type sharedDirectoryAuditCacheEntry struct { - nameCache map[directoryID]directoryName - readRequestCache map[completionID]readRequestInfo - writeRequestCache map[completionID]writeRequestInfo -} - -func newSharedDirectoryAuditCacheEntry() *sharedDirectoryAuditCacheEntry { - return &sharedDirectoryAuditCacheEntry{ - nameCache: make(map[directoryID]directoryName), - readRequestCache: make(map[completionID]readRequestInfo), - writeRequestCache: make(map[completionID]writeRequestInfo), - } -} +// maxAuditCacheItems is the maximum number of items we want +// to allow in a single sharedDirectoryAuditCacheEntry. +// +// It's not a precise value, just one that should prevent the +// cache from growing too large due to a misbehaving client. +const maxAuditCacheItems = 2000 -// totalItems returns the total numbewr of items held in the entry. -func (e *sharedDirectoryAuditCacheEntry) totalItems() int { +// totalItems returns the total number of items held in the cache. +// The caller should hold a lock on the cache prior to calling this method. +func (e *sharedDirectoryAuditCache) totalItems() int { return len(e.nameCache) + len(e.readRequestCache) + len(e.writeRequestCache) } // sharedDirectoryAuditCache is a data structure for caching information -// from shared directory TDP messages so that it can be used later for +// from shared directory messages so that it can be used later for // creating shared directory audit events. type sharedDirectoryAuditCache struct { - m map[sessionID]*sharedDirectoryAuditCacheEntry sync.Mutex + + nameCache map[directoryID]directoryName + readRequestCache map[completionID]readRequestInfo + writeRequestCache map[completionID]writeRequestInfo } func newSharedDirectoryAuditCache() sharedDirectoryAuditCache { return sharedDirectoryAuditCache{ - m: make(map[sessionID]*sharedDirectoryAuditCacheEntry), - } -} - -// getInitialized gets an initialized sharedDirectoryAuditCacheEntry, mapped to sid. -// If an entry at sid already exists, it returns that, otherwise it returns an empty, initialized entry. -// -// This should be called at the start of any SetX method to ensure that we never get a -// "panic: assignment to entry in nil map". -// -// It is the responsibility of the caller to ensure that it has obtained the Lock before calling -// getInitialized, and that it calls Unlock once the entry returned by getInitialized is no longer going to -// be modified or otherwise used. -func (c *sharedDirectoryAuditCache) getInitialized(sid sessionID) (entry *sharedDirectoryAuditCacheEntry) { - entry, ok := c.m[sid] - - if !ok { - entry = newSharedDirectoryAuditCacheEntry() - c.m[sid] = entry + nameCache: make(map[directoryID]directoryName), + readRequestCache: make(map[completionID]readRequestInfo), + writeRequestCache: make(map[completionID]writeRequestInfo), } - - return entry } // SetName returns a non-nil error if the audit cache entry for sid exceeds its maximum size. // It is the responsibility of the caller to terminate the session if a non-nil error is returned. -func (c *sharedDirectoryAuditCache) SetName(sid sessionID, did directoryID, name directoryName) error { +func (c *sharedDirectoryAuditCache) SetName(did directoryID, name directoryName) error { c.Lock() defer c.Unlock() - entry := c.getInitialized(sid) - if entry.totalItems() >= entryMaxItems { - return trace.LimitExceeded("audit cache for sessionID(%v) exceeded maximum size", sid) + if c.totalItems() >= maxAuditCacheItems { + return trace.LimitExceeded("audit cache exceeded maximum size") } - entry.nameCache[did] = name - + c.nameCache[did] = name return nil } -// SetReadRequestInfo returns a non-nil error if the audit cache entry for sid exceeds its maximum size. +// SetReadRequestInfo returns a non-nil error if the audit cache exceeds its maximum size. // It is the responsibility of the caller to terminate the session if a non-nil error is returned. -func (c *sharedDirectoryAuditCache) SetReadRequestInfo(sid sessionID, cid completionID, info readRequestInfo) error { +func (c *sharedDirectoryAuditCache) SetReadRequestInfo(cid completionID, info readRequestInfo) error { c.Lock() defer c.Unlock() - entry := c.getInitialized(sid) - if entry.totalItems() >= entryMaxItems { - return trace.LimitExceeded("audit cache for sessionID(%v) exceeded maximum size", sid) + if c.totalItems() >= maxAuditCacheItems { + return trace.LimitExceeded("audit cache exceeded maximum size") } - entry.readRequestCache[cid] = info - + c.readRequestCache[cid] = info return nil } -// SetWriteRequestInfo returns a non-nil error if the audit cache entry for sid exceeds its maximum size. +// SetWriteRequestInfo returns a non-nil error if the audit cache exceeds its maximum size. // It is the responsibility of the caller to terminate the session if a non-nil error is returned. -func (c *sharedDirectoryAuditCache) SetWriteRequestInfo(sid sessionID, cid completionID, info writeRequestInfo) error { +func (c *sharedDirectoryAuditCache) SetWriteRequestInfo(cid completionID, info writeRequestInfo) error { c.Lock() defer c.Unlock() - entry := c.getInitialized(sid) - if entry.totalItems() >= entryMaxItems { - return trace.LimitExceeded("audit cache for sessionID(%v) exceeded maximum size", sid) + if c.totalItems() >= maxAuditCacheItems { + return trace.LimitExceeded("audit cache exceeded maximum size") } - entry.writeRequestCache[cid] = info - + c.writeRequestCache[cid] = info return nil } -func (c *sharedDirectoryAuditCache) GetName(sid sessionID, did directoryID) (name directoryName, ok bool) { +func (c *sharedDirectoryAuditCache) GetName(did directoryID) (name directoryName, ok bool) { c.Lock() defer c.Unlock() - entry, ok := c.m[sid] - if !ok { - return - } - - name, ok = entry.nameCache[did] + name, ok = c.nameCache[did] return } -// TakeReadRequestInfo gets the readRequestInfo for completion id cid of session id sid, +// TakeReadRequestInfo gets the readRequestInfo for completion ID cid, // removing the readRequestInfo from the cache in the process. -func (c *sharedDirectoryAuditCache) TakeReadRequestInfo(sid sessionID, cid completionID) (info readRequestInfo, ok bool) { +func (c *sharedDirectoryAuditCache) TakeReadRequestInfo(cid completionID) (info readRequestInfo, ok bool) { c.Lock() defer c.Unlock() - entry, ok := c.m[sid] - if !ok { - return - } - - info, ok = entry.readRequestCache[cid] + info, ok = c.readRequestCache[cid] if ok { - delete(entry.readRequestCache, cid) + delete(c.readRequestCache, cid) } return } -// TakeWriteRequestInfo gets the writeRequestInfo for completion id cid of session id sid, +// TakeWriteRequestInfo gets the writeRequestInfo for completion ID cid, // removing the writeRequestInfo from the cache in the process. -func (c *sharedDirectoryAuditCache) TakeWriteRequestInfo(sid sessionID, cid completionID) (info writeRequestInfo, ok bool) { +func (c *sharedDirectoryAuditCache) TakeWriteRequestInfo(cid completionID) (info writeRequestInfo, ok bool) { c.Lock() defer c.Unlock() - entry, ok := c.m[sid] - if !ok { - return - } - - info, ok = entry.writeRequestCache[cid] + info, ok = c.writeRequestCache[cid] if ok { - delete(entry.writeRequestCache, cid) + delete(c.writeRequestCache, cid) } return } - -func (c *sharedDirectoryAuditCache) Delete(sid sessionID) { - c.Lock() - defer c.Unlock() - - delete(c.m, sid) -} diff --git a/lib/srv/desktop/audit_test.go b/lib/srv/desktop/audit_test.go index 8fa78e8ad0049..c7e3e0eec5367 100644 --- a/lib/srv/desktop/audit_test.go +++ b/lib/srv/desktop/audit_test.go @@ -17,9 +17,6 @@ limitations under the License. package desktop import ( - "bytes" - "context" - "fmt" "io" "testing" "time" @@ -33,25 +30,51 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/events" libevents "github.com/gravitational/teleport/lib/events" - "github.com/gravitational/teleport/lib/events/eventstest" "github.com/gravitational/teleport/lib/srv/desktop/tdp" "github.com/gravitational/teleport/lib/tlsca" ) -func setup() (*WindowsService, *tlsca.Identity, *eventstest.MockEmitter) { - emitter := &eventstest.MockEmitter{} +const ( + testDirectoryID directoryID = 2 + testCompletionID completionID = 999 + + testOffset uint64 = 500 + testLength uint32 = 1000 + + testDirName = "test-dir" + testFilePath = "test/path/test-file.txt" +) + +// testDesktop is a dummy desktop used to populate +// audit events for testing +var testDesktop = &types.WindowsDesktopV3{ + ResourceHeader: types.ResourceHeader{ + Metadata: types.Metadata{ + Name: "test-desktop", + Labels: map[string]string{"env": "production"}, + }, + }, + Spec: types.WindowsDesktopSpecV3{ + Addr: "192.168.100.12", + Domain: "test.example.com", + }, +} + +func setup(desktop types.WindowsDesktop) (*tlsca.Identity, *desktopSessionAuditor) { log := logrus.New() log.SetOutput(io.Discard) + startTime := time.Now() + s := &WindowsService{ clusterName: "test-cluster", cfg: WindowsServiceConfig{ Log: log, - Emitter: emitter, + Emitter: libevents.NewDiscardEmitter(), Heartbeat: HeartbeatConfig{ HostUUID: "test-host-id", }, - Clock: clockwork.NewFakeClockAt(time.Now()), + Clock: clockwork.NewFakeClockAt(startTime), }, auditCache: newSharedDirectoryAuditCache(), } @@ -63,33 +86,36 @@ func setup() (*WindowsService, *tlsca.Identity, *eventstest.MockEmitter) { LoginIP: "127.0.0.1", } - return s, id, emitter + d := &desktopSessionAuditor{ + clock: s.cfg.Clock, + + sessionID: "sessionID", + identity: id, + windowsUser: "Administrator", + desktop: desktop, + + startTime: startTime, + clusterName: s.clusterName, + desktopServiceUUID: s.cfg.Heartbeat.HostUUID, + + auditCache: newSharedDirectoryAuditCache(), + } + + return id, d } func TestSessionStartEvent(t *testing.T) { - s, id, emitter := setup() - desktop := &types.WindowsDesktopV3{ - ResourceHeader: types.ResourceHeader{ - Metadata: types.Metadata{ - Name: "test-desktop", - Labels: map[string]string{"env": "production"}, - }, - }, - Spec: types.WindowsDesktopSpecV3{ - Addr: "192.168.100.12", - Domain: "test.example.com", - }, - } + id, audit := setup(testDesktop) userMeta := id.GetUserMetadata() userMeta.Login = "Administrator" expected := &events.WindowsDesktopSessionStart{ Metadata: events.Metadata{ - ClusterName: s.clusterName, + ClusterName: audit.clusterName, Type: libevents.WindowsDesktopSessionStartEvent, Code: libevents.DesktopSessionStartCode, - Time: s.cfg.Clock.Now().UTC().Round(time.Millisecond), + Time: audit.startTime, }, UserMetadata: userMeta, SessionMetadata: events.SessionMetadata{ @@ -98,16 +124,16 @@ func TestSessionStartEvent(t *testing.T) { }, ConnectionMetadata: events.ConnectionMetadata{ LocalAddr: id.LoginIP, - RemoteAddr: desktop.GetAddr(), + RemoteAddr: testDesktop.GetAddr(), Protocol: libevents.EventProtocolTDP, }, Status: events.Status{ Success: true, }, - WindowsDesktopService: s.cfg.Heartbeat.HostUUID, + WindowsDesktopService: audit.desktopServiceUUID, DesktopName: "test-desktop", - DesktopAddr: desktop.GetAddr(), - Domain: desktop.GetDomain(), + DesktopAddr: testDesktop.GetAddr(), + Domain: testDesktop.GetDomain(), WindowsUser: "Administrator", DesktopLabels: map[string]string{"env": "production"}, } @@ -136,70 +162,25 @@ func TestSessionStartEvent(t *testing.T) { }, } { t.Run(test.desc, func(t *testing.T) { - s.onSessionStart( - context.Background(), - s.cfg.Emitter, - id, - s.cfg.Clock.Now().UTC().Round(time.Millisecond), - "Administrator", - "sessionID", - desktop, - test.err, - ) - - event := emitter.LastEvent() - require.NotNil(t, event) - - startEvent, ok := event.(*events.WindowsDesktopSessionStart) - require.True(t, ok) - + startEvent := audit.makeSessionStart(test.err) require.Empty(t, cmp.Diff(test.exp(), *startEvent)) }) } } func TestSessionEndEvent(t *testing.T) { - s, id, emitter := setup() - desktop := &types.WindowsDesktopV3{ - ResourceHeader: types.ResourceHeader{ - Metadata: types.Metadata{ - Name: "test-desktop", - Labels: map[string]string{"env": "production"}, - }, - }, - Spec: types.WindowsDesktopSpecV3{ - Addr: "192.168.100.12", - Domain: "test.example.com", - }, - } + id, audit := setup(testDesktop) - c := clockwork.NewFakeClockAt(time.Now()) - s.cfg.Clock = c - startTime := s.cfg.Clock.Now().UTC().Round(time.Millisecond) - c.Advance(30 * time.Second) - - s.onSessionEnd( - context.Background(), - s.cfg.Emitter, - id, - startTime, - true, - "Administrator", - "sessionID", - desktop, - ) - - event := emitter.LastEvent() - require.NotNil(t, event) - endEvent, ok := event.(*events.WindowsDesktopSessionEnd) - require.True(t, ok) + audit.clock.(clockwork.FakeClock).Advance(30 * time.Second) + + endEvent := audit.makeSessionEnd(true) userMeta := id.GetUserMetadata() userMeta.Login = "Administrator" expected := &events.WindowsDesktopSessionEnd{ Metadata: events.Metadata{ - ClusterName: s.clusterName, + ClusterName: audit.clusterName, Type: libevents.WindowsDesktopSessionEndEvent, Code: libevents.DesktopSessionEndCode, }, @@ -208,14 +189,14 @@ func TestSessionEndEvent(t *testing.T) { SessionID: "sessionID", WithMFA: id.MFAVerified, }, - WindowsDesktopService: s.cfg.Heartbeat.HostUUID, - DesktopAddr: desktop.GetAddr(), - Domain: desktop.GetDomain(), + WindowsDesktopService: audit.desktopServiceUUID, + DesktopAddr: testDesktop.GetAddr(), + Domain: testDesktop.GetDomain(), WindowsUser: "Administrator", DesktopLabels: map[string]string{"env": "production"}, - StartTime: startTime, - EndTime: c.Now().UTC().Round(time.Millisecond), - DesktopName: desktop.GetName(), + StartTime: audit.startTime, + EndTime: audit.clock.Now().UTC(), + DesktopName: testDesktop.GetName(), Recorded: true, Participants: []string{"foo"}, } @@ -223,15 +204,10 @@ func TestSessionEndEvent(t *testing.T) { } func TestDesktopSharedDirectoryStartEvent(t *testing.T) { - sid := "session-0" - desktopAddr := "windows.example.com" - testDirName := "test-dir" - var did uint32 = 2 - for _, test := range []struct { name string - // sendsSda determines whether a SharedDirectoryAnnounce is sent. - sendsSda bool + // sendsAnnounce determines whether a SharedDirectoryAnnounce is sent. + sendsAnnounce bool // errCode is the error code in the simulated SharedDirectoryAcknowledge errCode uint32 // expected returns the event we expect to be emitted by modifying baseEvent @@ -240,18 +216,18 @@ func TestDesktopSharedDirectoryStartEvent(t *testing.T) { }{ { // when everything is working as expected - name: "typical operation", - sendsSda: true, - errCode: tdp.ErrCodeNil, + name: "typical operation", + sendsAnnounce: true, + errCode: tdp.ErrCodeNil, expected: func(baseEvent *events.DesktopSharedDirectoryStart) *events.DesktopSharedDirectoryStart { return baseEvent }, }, { // the announce operation failed - name: "announce failed", - sendsSda: true, - errCode: tdp.ErrCodeFailed, + name: "announce failed", + sendsAnnounce: true, + errCode: tdp.ErrCodeFailed, expected: func(baseEvent *events.DesktopSharedDirectoryStart) *events.DesktopSharedDirectoryStart { baseEvent.Metadata.Code = libevents.DesktopSharedDirectoryStartFailureCode return baseEvent @@ -259,9 +235,9 @@ func TestDesktopSharedDirectoryStartEvent(t *testing.T) { }, { // should never happen but just in case - name: "directory name unknown", - sendsSda: false, - errCode: tdp.ErrCodeNil, + name: "directory name unknown", + sendsAnnounce: false, + errCode: tdp.ErrCodeNil, expected: func(baseEvent *events.DesktopSharedDirectoryStart) *events.DesktopSharedDirectoryStart { baseEvent.Metadata.Code = libevents.DesktopSharedDirectoryStartFailureCode baseEvent.DirectoryName = "unknown" @@ -270,82 +246,56 @@ func TestDesktopSharedDirectoryStartEvent(t *testing.T) { }, } { t.Run(test.name, func(t *testing.T) { - s, id, emitter := setup() - recvHandler := s.makeTDPReceiveHandler(context.Background(), - emitter, func() int64 { return 0 }, - id, sid, desktopAddr, &tdp.Conn{}) - sendHandler := s.makeTDPSendHandler(context.Background(), - emitter, func() int64 { return 0 }, - id, sid, desktopAddr, &tdp.Conn{}) - - if test.sendsSda { + id, audit := setup(testDesktop) + + if test.sendsAnnounce { // SharedDirectoryAnnounce initializes the nameCache. - sda := tdp.SharedDirectoryAnnounce{ - DirectoryID: did, + audit.onSharedDirectoryAnnounce(tdp.SharedDirectoryAnnounce{ + DirectoryID: uint32(testDirectoryID), Name: testDirName, - } - recvHandler(sda) + }) } // SharedDirectoryAcknowledge causes the event to be emitted - // (or not, on failure). - ack := tdp.SharedDirectoryAcknowledge{ - DirectoryID: did, + startEvent := audit.makeSharedDirectoryStart(tdp.SharedDirectoryAcknowledge{ + DirectoryID: uint32(testDirectoryID), ErrCode: test.errCode, - } - encoded, err := ack.Encode() - require.NoError(t, err) - sendHandler(ack, encoded) + }) baseEvent := &events.DesktopSharedDirectoryStart{ Metadata: events.Metadata{ Type: libevents.DesktopSharedDirectoryStartEvent, Code: libevents.DesktopSharedDirectoryStartCode, - ClusterName: s.clusterName, - Time: s.cfg.Clock.Now().UTC(), + ClusterName: audit.clusterName, + Time: audit.clock.Now().UTC(), }, UserMetadata: id.GetUserMetadata(), SessionMetadata: events.SessionMetadata{ - SessionID: sid, + SessionID: audit.sessionID, WithMFA: id.MFAVerified, }, ConnectionMetadata: events.ConnectionMetadata{ LocalAddr: id.LoginIP, - RemoteAddr: desktopAddr, + RemoteAddr: audit.desktop.GetAddr(), Protocol: libevents.EventProtocolTDP, }, Status: statusFromErrCode(test.errCode), - DesktopAddr: desktopAddr, + DesktopAddr: audit.desktop.GetAddr(), DirectoryName: testDirName, - DirectoryID: did, + DirectoryID: uint32(testDirectoryID), } expected := test.expected(baseEvent) - event := emitter.LastEvent() - - require.NotNil(t, event) - startEvent, ok := event.(*events.DesktopSharedDirectoryStart) - require.True(t, ok) - require.Empty(t, cmp.Diff(expected, startEvent)) }) } } func TestDesktopSharedDirectoryReadEvent(t *testing.T) { - sid := "session-0" - desktopAddr := "windows.example.com" - testDirName := "test-dir" - path := "test/path/test-file.txt" - var did uint32 = 2 - var cid uint32 = 999 - var offset uint64 = 500 - var length uint32 = 1000 - for _, test := range []struct { name string - // sendsSda determines whether a SharedDirectoryAnnounce is sent. - sendsSda bool + // sendsAnnounce determines whether a SharedDirectoryAnnounce is sent. + sendsAnnounce bool // sendsReq determines whether a SharedDirectoryReadRequest is sent. sendsReq bool // errCode is the error code in the simulated SharedDirectoryReadResponse @@ -356,20 +306,20 @@ func TestDesktopSharedDirectoryReadEvent(t *testing.T) { }{ { // when everything is working as expected - name: "typical operation", - sendsSda: true, - sendsReq: true, - errCode: tdp.ErrCodeNil, + name: "typical operation", + sendsAnnounce: true, + sendsReq: true, + errCode: tdp.ErrCodeNil, expected: func(baseEvent *events.DesktopSharedDirectoryRead) *events.DesktopSharedDirectoryRead { return baseEvent }, }, { // the read operation failed - name: "read failed", - sendsSda: true, - sendsReq: true, - errCode: tdp.ErrCodeFailed, + name: "read failed", + sendsAnnounce: true, + sendsReq: true, + errCode: tdp.ErrCodeFailed, expected: func(baseEvent *events.DesktopSharedDirectoryRead) *events.DesktopSharedDirectoryRead { baseEvent.Metadata.Code = libevents.DesktopSharedDirectoryWriteFailureCode return baseEvent @@ -377,10 +327,10 @@ func TestDesktopSharedDirectoryReadEvent(t *testing.T) { }, { // should never happen but just in case - name: "directory name unknown", - sendsSda: false, - sendsReq: true, - errCode: tdp.ErrCodeNil, + name: "directory name unknown", + sendsAnnounce: false, + sendsReq: true, + errCode: tdp.ErrCodeNil, expected: func(baseEvent *events.DesktopSharedDirectoryRead) *events.DesktopSharedDirectoryRead { baseEvent.Metadata.Code = libevents.DesktopSharedDirectoryReadFailureCode baseEvent.DirectoryName = "unknown" @@ -389,10 +339,10 @@ func TestDesktopSharedDirectoryReadEvent(t *testing.T) { }, { // should never happen but just in case - name: "request info unknown", - sendsSda: true, - sendsReq: false, - errCode: tdp.ErrCodeNil, + name: "request info unknown", + sendsAnnounce: true, + sendsReq: false, + errCode: tdp.ErrCodeNil, expected: func(baseEvent *events.DesktopSharedDirectoryRead) *events.DesktopSharedDirectoryRead { baseEvent.Metadata.Code = libevents.DesktopSharedDirectoryReadFailureCode @@ -410,10 +360,10 @@ func TestDesktopSharedDirectoryReadEvent(t *testing.T) { }, { // should never happen but just in case - name: "directory name and request info unknown", - sendsSda: false, - sendsReq: false, - errCode: tdp.ErrCodeNil, + name: "directory name and request info unknown", + sendsAnnounce: false, + sendsReq: false, + errCode: tdp.ErrCodeNil, expected: func(baseEvent *events.DesktopSharedDirectoryRead) *events.DesktopSharedDirectoryRead { baseEvent.Metadata.Code = libevents.DesktopSharedDirectoryReadFailureCode @@ -430,75 +380,59 @@ func TestDesktopSharedDirectoryReadEvent(t *testing.T) { }, } { t.Run(test.name, func(t *testing.T) { - s, id, emitter := setup() - recvHandler := s.makeTDPReceiveHandler(context.Background(), - emitter, func() int64 { return 0 }, - id, sid, desktopAddr, &tdp.Conn{}) - sendHandler := s.makeTDPSendHandler(context.Background(), - emitter, func() int64 { return 0 }, - id, sid, desktopAddr, &tdp.Conn{}) - if test.sendsSda { - // SharedDirectoryAnnounce initializes the nameCache. - sda := tdp.SharedDirectoryAnnounce{ - DirectoryID: did, + id, audit := setup(testDesktop) + + if test.sendsAnnounce { + // SharedDirectoryAnnounce initializes the name cache + audit.onSharedDirectoryAnnounce(tdp.SharedDirectoryAnnounce{ + DirectoryID: uint32(testDirectoryID), Name: testDirName, - } - recvHandler(sda) + }) } if test.sendsReq { // SharedDirectoryReadRequest initializes the readRequestCache. - req := tdp.SharedDirectoryReadRequest{ - CompletionID: cid, - DirectoryID: did, - Path: path, - Offset: offset, - Length: length, - } - encoded, err := req.Encode() - require.NoError(t, err) - sendHandler(req, encoded) + audit.onSharedDirectoryReadRequest(tdp.SharedDirectoryReadRequest{ + CompletionID: uint32(testCompletionID), + DirectoryID: uint32(testDirectoryID), + Path: testFilePath, + Offset: testOffset, + Length: testLength, + }) } // SharedDirectoryReadResponse causes the event to be emitted. - res := tdp.SharedDirectoryReadResponse{ - CompletionID: cid, + readEvent := audit.makeSharedDirectoryReadResponse(tdp.SharedDirectoryReadResponse{ + CompletionID: uint32(testCompletionID), ErrCode: test.errCode, - ReadDataLength: length, + ReadDataLength: testLength, ReadData: []byte{}, // irrelevant in this context - } - recvHandler(res) - - event := emitter.LastEvent() - require.NotNil(t, event) - - readEvent, ok := event.(*events.DesktopSharedDirectoryRead) - require.True(t, ok) + }) baseEvent := &events.DesktopSharedDirectoryRead{ Metadata: events.Metadata{ Type: libevents.DesktopSharedDirectoryReadEvent, Code: libevents.DesktopSharedDirectoryReadCode, - ClusterName: s.clusterName, - Time: s.cfg.Clock.Now().UTC(), + ClusterName: audit.clusterName, + Time: audit.clock.Now().UTC(), }, UserMetadata: id.GetUserMetadata(), SessionMetadata: events.SessionMetadata{ - SessionID: sid, + SessionID: audit.sessionID, WithMFA: id.MFAVerified, }, ConnectionMetadata: events.ConnectionMetadata{ LocalAddr: id.LoginIP, - RemoteAddr: desktopAddr, + RemoteAddr: audit.desktop.GetAddr(), Protocol: libevents.EventProtocolTDP, }, Status: statusFromErrCode(test.errCode), - DesktopAddr: desktopAddr, + DesktopAddr: audit.desktop.GetAddr(), DirectoryName: testDirName, - DirectoryID: did, - Path: path, - Length: length, - Offset: offset, + DirectoryID: uint32(testDirectoryID), + Path: testFilePath, + Length: testLength, + Offset: testOffset, } require.Empty(t, cmp.Diff(test.expected(baseEvent), readEvent)) @@ -507,19 +441,10 @@ func TestDesktopSharedDirectoryReadEvent(t *testing.T) { } func TestDesktopSharedDirectoryWriteEvent(t *testing.T) { - sid := "session-0" - desktopAddr := "windows.example.com" - testDirName := "test-dir" - path := "test/path/test-file.txt" - var did uint32 = 2 - var cid uint32 = 999 - var offset uint64 = 500 - var length uint32 = 1000 - for _, test := range []struct { name string - // sendsSda determines whether a SharedDirectoryAnnounce is sent. - sendsSda bool + // sendsAnnounce determines whether a SharedDirectoryAnnounce is sent. + sendsAnnounce bool // sendsReq determines whether a SharedDirectoryWriteRequest is sent. sendsReq bool // errCode is the error code in the simulated SharedDirectoryWriteResponse @@ -530,20 +455,20 @@ func TestDesktopSharedDirectoryWriteEvent(t *testing.T) { }{ { // when everything is working as expected - name: "typical operation", - sendsSda: true, - sendsReq: true, - errCode: tdp.ErrCodeNil, + name: "typical operation", + sendsAnnounce: true, + sendsReq: true, + errCode: tdp.ErrCodeNil, expected: func(baseEvent *events.DesktopSharedDirectoryWrite) *events.DesktopSharedDirectoryWrite { return baseEvent }, }, { // the Write operation failed - name: "write failed", - sendsSda: true, - sendsReq: true, - errCode: tdp.ErrCodeFailed, + name: "write failed", + sendsAnnounce: true, + sendsReq: true, + errCode: tdp.ErrCodeFailed, expected: func(baseEvent *events.DesktopSharedDirectoryWrite) *events.DesktopSharedDirectoryWrite { baseEvent.Metadata.Code = libevents.DesktopSharedDirectoryWriteFailureCode return baseEvent @@ -551,10 +476,10 @@ func TestDesktopSharedDirectoryWriteEvent(t *testing.T) { }, { // should never happen but just in case - name: "directory name unknown", - sendsSda: false, - sendsReq: true, - errCode: tdp.ErrCodeNil, + name: "directory name unknown", + sendsAnnounce: false, + sendsReq: true, + errCode: tdp.ErrCodeNil, expected: func(baseEvent *events.DesktopSharedDirectoryWrite) *events.DesktopSharedDirectoryWrite { baseEvent.Metadata.Code = libevents.DesktopSharedDirectoryWriteFailureCode baseEvent.DirectoryName = "unknown" @@ -563,10 +488,10 @@ func TestDesktopSharedDirectoryWriteEvent(t *testing.T) { }, { // should never happen but just in case - name: "request info unknown", - sendsSda: true, - sendsReq: false, - errCode: tdp.ErrCodeNil, + name: "request info unknown", + sendsAnnounce: true, + sendsReq: false, + errCode: tdp.ErrCodeNil, expected: func(baseEvent *events.DesktopSharedDirectoryWrite) *events.DesktopSharedDirectoryWrite { baseEvent.Metadata.Code = libevents.DesktopSharedDirectoryWriteFailureCode @@ -584,10 +509,10 @@ func TestDesktopSharedDirectoryWriteEvent(t *testing.T) { }, { // should never happen but just in case - name: "directory name and request info unknown", - sendsSda: false, - sendsReq: false, - errCode: tdp.ErrCodeNil, + name: "directory name and request info unknown", + sendsAnnounce: false, + sendsReq: false, + errCode: tdp.ErrCodeNil, expected: func(baseEvent *events.DesktopSharedDirectoryWrite) *events.DesktopSharedDirectoryWrite { baseEvent.Metadata.Code = libevents.DesktopSharedDirectoryWriteFailureCode @@ -604,74 +529,58 @@ func TestDesktopSharedDirectoryWriteEvent(t *testing.T) { }, } { t.Run(test.name, func(t *testing.T) { - s, id, emitter := setup() - recvHandler := s.makeTDPReceiveHandler(context.Background(), - emitter, func() int64 { return 0 }, - id, sid, desktopAddr, &tdp.Conn{}) - sendHandler := s.makeTDPSendHandler(context.Background(), - emitter, func() int64 { return 0 }, - id, sid, desktopAddr, &tdp.Conn{}) - if test.sendsSda { + id, audit := setup(testDesktop) + + if test.sendsAnnounce { // SharedDirectoryAnnounce initializes the nameCache. - sda := tdp.SharedDirectoryAnnounce{ - DirectoryID: did, + audit.onSharedDirectoryAnnounce(tdp.SharedDirectoryAnnounce{ + DirectoryID: uint32(testDirectoryID), Name: testDirName, - } - recvHandler(sda) + }) } if test.sendsReq { // SharedDirectoryWriteRequest initializes the writeRequestCache. - req := tdp.SharedDirectoryWriteRequest{ - CompletionID: cid, - DirectoryID: did, - Path: path, - Offset: offset, - WriteDataLength: length, - } - encoded, err := req.Encode() - require.NoError(t, err) - sendHandler(req, encoded) + audit.onSharedDirectoryWriteRequest(tdp.SharedDirectoryWriteRequest{ + CompletionID: uint32(testCompletionID), + DirectoryID: uint32(testDirectoryID), + Path: testFilePath, + Offset: testOffset, + WriteDataLength: testLength, + }) } // SharedDirectoryWriteResponse causes the event to be emitted. - res := tdp.SharedDirectoryWriteResponse{ - CompletionID: cid, + writeEvent := audit.makeSharedDirectoryWriteResponse(tdp.SharedDirectoryWriteResponse{ + CompletionID: uint32(testCompletionID), ErrCode: test.errCode, - BytesWritten: length, - } - recvHandler(res) - - event := emitter.LastEvent() - require.NotNil(t, event) - - writeEvent, ok := event.(*events.DesktopSharedDirectoryWrite) - require.True(t, ok) + BytesWritten: testLength, + }) baseEvent := &events.DesktopSharedDirectoryWrite{ Metadata: events.Metadata{ Type: libevents.DesktopSharedDirectoryWriteEvent, Code: libevents.DesktopSharedDirectoryWriteCode, - ClusterName: s.clusterName, - Time: s.cfg.Clock.Now().UTC(), + ClusterName: audit.clusterName, + Time: audit.clock.Now().UTC(), }, UserMetadata: id.GetUserMetadata(), SessionMetadata: events.SessionMetadata{ - SessionID: sid, + SessionID: audit.sessionID, WithMFA: id.MFAVerified, }, ConnectionMetadata: events.ConnectionMetadata{ LocalAddr: id.LoginIP, - RemoteAddr: desktopAddr, + RemoteAddr: audit.desktop.GetAddr(), Protocol: libevents.EventProtocolTDP, }, Status: statusFromErrCode(test.errCode), - DesktopAddr: desktopAddr, + DesktopAddr: audit.desktop.GetAddr(), DirectoryName: testDirName, - DirectoryID: did, - Path: path, - Length: length, - Offset: offset, + DirectoryID: uint32(testDirectoryID), + Path: testFilePath, + Length: testLength, + Offset: testOffset, } require.Empty(t, cmp.Diff(test.expected(baseEvent), writeEvent)) @@ -679,272 +588,249 @@ func TestDesktopSharedDirectoryWriteEvent(t *testing.T) { } } -// fillEntry is a helper function that fills an entry's readRequestCache up with entryMaxItems. -func fillEntry(entry *sharedDirectoryAuditCacheEntry, did directoryID) { - for i := 0; i < entryMaxItems; i++ { - entry.readRequestCache[completionID(i)] = readRequestInfo{ +// fillReadRequestCache is a helper function that fills an entry's readRequestCache up with entryMaxItems. +func fillReadRequestCache(cache *sharedDirectoryAuditCache, did directoryID) { + cache.Lock() + defer cache.Unlock() + + for i := 0; i < maxAuditCacheItems; i++ { + cache.readRequestCache[completionID(i)] = readRequestInfo{ directoryID: did, } } } // TestDesktopSharedDirectoryStartEventAuditCacheMax tests that a -// failed DesktopSharedDirectoryStart is emitted and the tdpConn is -// closed when we receive a SharedDirectoryAnnounce whose corresponding -// sharedDirectoryAuditCacheEntry is full. +// failed DesktopSharedDirectoryStart is emitted when the shared +// directory audit cache is full. func TestDesktopSharedDirectoryStartEventAuditCacheMax(t *testing.T) { - sid := "session-0" - desktopAddr := "windows.example.com" - testDirName := "test-dir" - var did uint32 = 2 - - s, id, emitter := setup() - testConn := &testConn{} - tdpConn := tdp.NewConn(testConn) - recvHandler := s.makeTDPReceiveHandler(context.Background(), - emitter, func() int64 { return 0 }, - id, sid, desktopAddr, tdpConn) + + id, audit := setup(testDesktop) // Set the audit cache entry to the maximum allowable size - entry := newSharedDirectoryAuditCacheEntry() - fillEntry(entry, directoryID(did)) - s.auditCache.m[sessionID(sid)] = entry + fillReadRequestCache(&audit.auditCache, testDirectoryID) // Send a SharedDirectoryAnnounce - sda := tdp.SharedDirectoryAnnounce{ - DirectoryID: did, + startEvent := audit.onSharedDirectoryAnnounce(tdp.SharedDirectoryAnnounce{ + DirectoryID: uint32(testDirectoryID), Name: testDirName, - } - recvHandler(sda) + }) + require.NotNil(t, startEvent) // Expect the audit cache to emit a failed DesktopSharedDirectoryStart // with a status detailing the security problem. - event := emitter.LastEvent() - require.NotNil(t, event) - startEvent, ok := event.(*events.DesktopSharedDirectoryStart) - require.True(t, ok) - expected := &events.DesktopSharedDirectoryStart{ Metadata: events.Metadata{ Type: libevents.DesktopSharedDirectoryStartEvent, Code: libevents.DesktopSharedDirectoryStartFailureCode, - ClusterName: s.clusterName, - Time: s.cfg.Clock.Now().UTC(), + ClusterName: audit.clusterName, + Time: audit.clock.Now().UTC(), }, UserMetadata: id.GetUserMetadata(), SessionMetadata: events.SessionMetadata{ - SessionID: sid, + SessionID: audit.sessionID, WithMFA: id.MFAVerified, }, ConnectionMetadata: events.ConnectionMetadata{ LocalAddr: id.LoginIP, - RemoteAddr: desktopAddr, + RemoteAddr: audit.desktop.GetAddr(), Protocol: libevents.EventProtocolTDP, }, Status: events.Status{ Success: false, - Error: fmt.Sprintf("audit cache for sessionID(%v) exceeded maximum size", sid), + Error: "audit cache exceeded maximum size", UserMessage: "Teleport failed the request and terminated the session as a security precaution", }, - DesktopAddr: desktopAddr, + DesktopAddr: audit.desktop.GetAddr(), DirectoryName: testDirName, - DirectoryID: did, + DirectoryID: uint32(testDirectoryID), } require.Empty(t, cmp.Diff(expected, startEvent)) - - // Check that Close was called on the TDP connection - require.True(t, testConn.closeCalled) } // TestDesktopSharedDirectoryReadEventAuditCacheMax tests that a -// failed DesktopSharedDirectoryRead is emitted and the tdpConn is -// closed when we receive a SharedDirectoryReadRequest whose corresponding -// sharedDirectoryAuditCacheEntry is full. +// failed DesktopSharedDirectoryRead is generated when the shared +// directory audit cache is full. func TestDesktopSharedDirectoryReadEventAuditCacheMax(t *testing.T) { - sid := "session-0" - desktopAddr := "windows.example.com" - testDirName := "test-dir" - path := "test/path/test-file.txt" - var did uint32 = 2 - var cid uint32 = 999 - var offset uint64 = 500 - var length uint32 = 1000 - - s, id, emitter := setup() - testConn := &testConn{} - tdpConn := tdp.NewConn(testConn) - recvHandler := s.makeTDPReceiveHandler(context.Background(), - emitter, func() int64 { return 0 }, - id, sid, desktopAddr, tdpConn) - sendHandler := s.makeTDPSendHandler(context.Background(), - emitter, func() int64 { return 0 }, - id, sid, desktopAddr, tdpConn) + + id, audit := setup(testDesktop) // Send a SharedDirectoryAnnounce - sda := tdp.SharedDirectoryAnnounce{ - DirectoryID: did, + audit.onSharedDirectoryAnnounce(tdp.SharedDirectoryAnnounce{ + DirectoryID: uint32(testDirectoryID), Name: testDirName, - } - recvHandler(sda) + }) // Set the audit cache entry to the maximum allowable size - entry, ok := s.auditCache.m[sessionID(sid)] - require.True(t, ok) - fillEntry(entry, directoryID(did)) + fillReadRequestCache(&audit.auditCache, testDirectoryID) // SharedDirectoryReadRequest should cause a failed audit event. - req := tdp.SharedDirectoryReadRequest{ - CompletionID: cid, - DirectoryID: did, - Path: path, - Offset: offset, - Length: length, - } - encoded, err := req.Encode() - require.NoError(t, err) - sendHandler(req, encoded) - - // Expect the audit cache to emit a failed DesktopSharedDirectoryRead - // with a status detailing the security problem. - event := emitter.LastEvent() - require.NotNil(t, event) - readEvent, ok := event.(*events.DesktopSharedDirectoryRead) - require.True(t, ok) + readEvent := audit.onSharedDirectoryReadRequest(tdp.SharedDirectoryReadRequest{ + CompletionID: uint32(testCompletionID), + DirectoryID: uint32(testDirectoryID), + Path: testFilePath, + Offset: testOffset, + Length: testLength, + }) + require.NotNil(t, readEvent) expected := &events.DesktopSharedDirectoryRead{ Metadata: events.Metadata{ Type: libevents.DesktopSharedDirectoryReadEvent, Code: libevents.DesktopSharedDirectoryReadFailureCode, - ClusterName: s.clusterName, - Time: s.cfg.Clock.Now().UTC(), + ClusterName: audit.clusterName, + Time: audit.clock.Now().UTC(), }, UserMetadata: id.GetUserMetadata(), SessionMetadata: events.SessionMetadata{ - SessionID: sid, + SessionID: audit.sessionID, WithMFA: id.MFAVerified, }, ConnectionMetadata: events.ConnectionMetadata{ LocalAddr: id.LoginIP, - RemoteAddr: desktopAddr, + RemoteAddr: audit.desktop.GetAddr(), Protocol: libevents.EventProtocolTDP, }, Status: events.Status{ Success: false, - Error: fmt.Sprintf("audit cache for sessionID(%v) exceeded maximum size", sid), + Error: "audit cache exceeded maximum size", UserMessage: "Teleport failed the request and terminated the session as a security precaution", }, - DesktopAddr: desktopAddr, + DesktopAddr: audit.desktop.GetAddr(), DirectoryName: testDirName, - DirectoryID: did, - Path: path, - Length: length, - Offset: offset, + DirectoryID: uint32(testDirectoryID), + Path: testFilePath, + Length: testLength, + Offset: testOffset, } require.Empty(t, cmp.Diff(expected, readEvent)) - - // Check that Close was called on the TDP connection - require.True(t, testConn.closeCalled) } // TestDesktopSharedDirectoryWriteEventAuditCacheMax tests that a -// failed DesktopSharedDirectoryWrite is emitted and the tdpConn is -// closed when we receive a SharedDirectoryWriteRequest whose corresponding -// sharedDirectoryAuditCacheEntry is full. +// failed DesktopSharedDirectoryWrite is generated when the shared +// directory audit cache is full. func TestDesktopSharedDirectoryWriteEventAuditCacheMax(t *testing.T) { - sid := "session-0" - desktopAddr := "windows.example.com" - testDirName := "test-dir" - path := "test/path/test-file.txt" - var did uint32 = 2 - var cid uint32 = 999 - var offset uint64 = 500 - var length uint32 = 1000 - - s, id, emitter := setup() - testConn := &testConn{} - tdpConn := tdp.NewConn(testConn) - recvHandler := s.makeTDPReceiveHandler(context.Background(), - emitter, func() int64 { return 0 }, - id, sid, desktopAddr, tdpConn) - sendHandler := s.makeTDPSendHandler(context.Background(), - emitter, func() int64 { return 0 }, - id, sid, desktopAddr, tdpConn) - // Send a SharedDirectoryAnnounce - sda := tdp.SharedDirectoryAnnounce{ - DirectoryID: did, + id, audit := setup(testDesktop) + + audit.onSharedDirectoryAnnounce(tdp.SharedDirectoryAnnounce{ + DirectoryID: uint32(testDirectoryID), Name: testDirName, - } - recvHandler(sda) + }) - // Set the audit cache entry to the maximum allowable size - entry, ok := s.auditCache.m[sessionID(sid)] - require.True(t, ok) - fillEntry(entry, directoryID(did)) - - // SharedDirectoryWriteRequest should cause a failed audit event. - req := tdp.SharedDirectoryWriteRequest{ - CompletionID: cid, - DirectoryID: did, - Path: path, - Offset: offset, - WriteDataLength: length, - } - encoded, err := req.Encode() - require.NoError(t, err) - sendHandler(req, encoded) + fillReadRequestCache(&audit.auditCache, testDirectoryID) - // Expect the audit cache to emit a failed DesktopSharedDirectoryWrite - // with a status detailing the security problem. - event := emitter.LastEvent() - require.NotNil(t, event) - writeEvent, ok := event.(*events.DesktopSharedDirectoryWrite) - require.True(t, ok) + writeEvent := audit.onSharedDirectoryWriteRequest(tdp.SharedDirectoryWriteRequest{ + CompletionID: uint32(testCompletionID), + DirectoryID: uint32(testDirectoryID), + Path: testFilePath, + Offset: testOffset, + WriteDataLength: testLength, + }) + require.NotNil(t, writeEvent, "audit event should have been generated") expected := &events.DesktopSharedDirectoryWrite{ Metadata: events.Metadata{ Type: libevents.DesktopSharedDirectoryWriteEvent, Code: libevents.DesktopSharedDirectoryWriteFailureCode, - ClusterName: s.clusterName, - Time: s.cfg.Clock.Now().UTC(), + ClusterName: audit.clusterName, + Time: audit.clock.Now().UTC(), }, UserMetadata: id.GetUserMetadata(), SessionMetadata: events.SessionMetadata{ - SessionID: sid, + SessionID: audit.sessionID, WithMFA: id.MFAVerified, }, ConnectionMetadata: events.ConnectionMetadata{ LocalAddr: id.LoginIP, - RemoteAddr: desktopAddr, + RemoteAddr: audit.desktop.GetAddr(), Protocol: libevents.EventProtocolTDP, }, Status: events.Status{ Success: false, - Error: fmt.Sprintf("audit cache for sessionID(%v) exceeded maximum size", sid), + Error: "audit cache exceeded maximum size", UserMessage: "Teleport failed the request and terminated the session as a security precaution", }, - DesktopAddr: desktopAddr, + DesktopAddr: audit.desktop.GetAddr(), DirectoryName: testDirName, - DirectoryID: did, - Path: path, - Length: length, - Offset: offset, + DirectoryID: uint32(testDirectoryID), + Path: testFilePath, + Length: testLength, + Offset: testOffset, } require.Empty(t, cmp.Diff(expected, writeEvent)) - - // Check that Close was called on the TDP connection - require.True(t, testConn.closeCalled) } -type testConn struct { - *bytes.Buffer - closeCalled bool -} +// TestAuditCacheLifecycle confirms that the audit cache operates correctly +// in response to protocol events. +func TestAuditCacheLifecycle(t *testing.T) { + _, audit := setup(testDesktop) + + // SharedDirectoryAnnounce initializes the nameCache. + audit.onSharedDirectoryAnnounce(tdp.SharedDirectoryAnnounce{ + DirectoryID: uint32(testDirectoryID), + Name: testDirName, + }) -func (t *testConn) Close() error { - t.closeCalled = true - return nil + // Confirm that audit cache is in the expected state. + require.Equal(t, 1, audit.auditCache.totalItems()) + name, ok := audit.auditCache.GetName(testDirectoryID) + require.True(t, ok) + require.Equal(t, directoryName(testDirName), name) + _, ok = audit.auditCache.TakeReadRequestInfo(testCompletionID) + require.False(t, ok) + _, ok = audit.auditCache.TakeWriteRequestInfo(testCompletionID) + require.False(t, ok) + + // A SharedDirectoryReadRequest should add a corresponding entry in the readRequestCache. + audit.onSharedDirectoryReadRequest(tdp.SharedDirectoryReadRequest{ + CompletionID: uint32(testCompletionID), + DirectoryID: uint32(testDirectoryID), + Path: testFilePath, + Offset: testOffset, + Length: testLength, + }) + require.Equal(t, 2, audit.auditCache.totalItems()) + + // A SharedDirectoryWriteRequest should add a corresponding entry in the writeRequestCache. + audit.onSharedDirectoryWriteRequest(tdp.SharedDirectoryWriteRequest{ + CompletionID: uint32(testCompletionID), + DirectoryID: uint32(testDirectoryID), + Path: testFilePath, + Offset: testOffset, + WriteDataLength: testLength, + }) + require.Equal(t, 3, audit.auditCache.totalItems()) + + // Check that the readRequestCache was properly filled out. + require.Contains(t, audit.auditCache.readRequestCache, testCompletionID) + + // Check that the writeRequestCache was properly filled out. + require.Contains(t, audit.auditCache.writeRequestCache, testCompletionID) + + // SharedDirectoryReadResponse should cause the entry in the readRequestCache to be cleaned up. + audit.makeSharedDirectoryReadResponse(tdp.SharedDirectoryReadResponse{ + CompletionID: uint32(testCompletionID), + ErrCode: tdp.ErrCodeNil, + ReadDataLength: testLength, + ReadData: []byte{}, // irrelevant in this context + }) + require.Equal(t, 2, audit.auditCache.totalItems()) + + // SharedDirectoryWriteResponse should cause the entry in the writeRequestCache to be cleaned up. + audit.makeSharedDirectoryWriteResponse(tdp.SharedDirectoryWriteResponse{ + CompletionID: uint32(testCompletionID), + ErrCode: tdp.ErrCodeNil, + BytesWritten: testLength, + }) + require.Equal(t, 1, audit.auditCache.totalItems()) + + // Check that the readRequestCache was properly cleaned up. + require.NotContains(t, audit.auditCache.readRequestCache, testCompletionID) + + // Check that the writeRequestCache was properly cleaned up. + require.NotContains(t, audit.auditCache.writeRequestCache, testCompletionID) } diff --git a/lib/srv/desktop/rdp/rdpclient/client.go b/lib/srv/desktop/rdp/rdpclient/client.go index b993bbaf6f378..6bb0464ba7c23 100644 --- a/lib/srv/desktop/rdp/rdpclient/client.go +++ b/lib/srv/desktop/rdp/rdpclient/client.go @@ -188,10 +188,13 @@ func (c *Client) Run(ctx context.Context) error { // Both goroutines have finished, it's now // safe for the deferred c.cleanup() call to // clean up the memory. - return nil } +func (c *Client) GetClientUsername() string { + return c.username +} + func (c *Client) readClientUsername() error { for { msg, err := c.cfg.Conn.ReadMessage() diff --git a/lib/srv/desktop/rdp/rdpclient/client_nop.go b/lib/srv/desktop/rdp/rdpclient/client_nop.go index 221da22a87ff0..b836c9e08bf4c 100644 --- a/lib/srv/desktop/rdp/rdpclient/client_nop.go +++ b/lib/srv/desktop/rdp/rdpclient/client_nop.go @@ -44,6 +44,10 @@ func (c *Client) Run(ctx context.Context) error { return errors.New("the real rdpclient.Client implementation was not included in this build") } +func (c *Client) GetClientUsername() string { + return "" +} + // GetClientLastActive returns the time of the last recorded activity. func (c *Client) GetClientLastActive() time.Time { return time.Now().UTC() diff --git a/lib/srv/desktop/windows_server.go b/lib/srv/desktop/windows_server.go index 15d64014fc55c..be1dc6739eeca 100644 --- a/lib/srv/desktop/windows_server.go +++ b/lib/srv/desktop/windows_server.go @@ -839,9 +839,7 @@ func (s *WindowsService) connectRDP(ctx context.Context, log logrus.FieldLogger, log.Infof("desktop session %v will not be recorded, user %v's roles disable recording", string(sessionID), authCtx.User.GetName()) } - var windowsUser string authorize := func(login string) error { - windowsUser = login // capture attempted login user state := authCtx.GetAccessState(authPref) return authCtx.Checker.CheckAccess( desktop, @@ -856,12 +854,6 @@ func (s *WindowsService) connectRDP(ctx context.Context, log logrus.FieldLogger, ctx, cancel := context.WithCancel(ctx) defer cancel() - // Create a session tracker so that other services, such as - // the session upload completer, can track the session's lifetime. - if err := s.trackSession(ctx, &identity, windowsUser, string(sessionID), desktop); err != nil { - return trace.Wrap(err) - } - sw, err := s.newStreamWriter(recordSession, string(sessionID)) if err != nil { return trace.Wrap(err) @@ -879,17 +871,23 @@ func (s *WindowsService) connectRDP(ctx context.Context, log logrus.FieldLogger, }() }() - delay := timer() - tdpConn.OnSend = s.makeTDPSendHandler(ctx, sw, delay, &identity, string(sessionID), desktop.GetAddr(), tdpConn) - tdpConn.OnRecv = s.makeTDPReceiveHandler(ctx, sw, delay, &identity, string(sessionID), desktop.GetAddr(), tdpConn) + // We won't have the windows username until we start to read from the websocket, + // but we need to start emitting audit events now. Create an auditor without + // specifying the username (we'll update it soon as we have it). + audit := s.newSessionAuditor(string(sessionID), &identity, "", desktop) - sessionStartTime := s.cfg.Clock.Now().UTC().Round(time.Millisecond) groups, err := authCtx.Checker.DesktopGroups(desktop) if err != nil && !trace.IsAccessDenied(err) { - s.onSessionStart(ctx, sw, &identity, sessionStartTime, windowsUser, string(sessionID), desktop, err) + startEvent := audit.makeSessionStart(err) + s.emit(ctx, sw, startEvent) return trace.Wrap(err) } createUsers := err == nil + + delay := timer() + tdpConn.OnSend = s.makeTDPSendHandler(ctx, sw, delay, tdpConn, audit) + tdpConn.OnRecv = s.makeTDPReceiveHandler(ctx, sw, delay, tdpConn, audit) + rdpc, err := rdpclient.New(rdpclient.Config{ Log: log, GenerateUserCert: func(ctx context.Context, username string, ttl time.Duration) (certDER, keyDER []byte, err error) { @@ -903,8 +901,22 @@ func (s *WindowsService) connectRDP(ctx context.Context, log logrus.FieldLogger, AllowDirectorySharing: authCtx.Checker.DesktopDirectorySharing(), ShowDesktopWallpaper: s.cfg.ShowDesktopWallpaper, }) + // before we check the error above, we grab the windows user so that + // future audit events include the proper username + var windowsUser string + if rdpc != nil { + windowsUser = rdpc.GetClientUsername() + audit.windowsUser = windowsUser + } if err != nil { - s.onSessionStart(ctx, sw, &identity, sessionStartTime, windowsUser, string(sessionID), desktop, err) + startEvent := audit.makeSessionStart(err) + s.emit(ctx, sw, startEvent) + return trace.Wrap(err) + } + + // Create a session tracker so that other services, such as + // the session upload completer, can track the session's lifetime. + if err := s.trackSession(ctx, &identity, windowsUser, string(sessionID), desktop); err != nil { return trace.Wrap(err) } @@ -938,19 +950,30 @@ func (s *WindowsService) connectRDP(ctx context.Context, log logrus.FieldLogger, // if we can't establish a connection monitor then we can't enforce RBAC. // consider this a connection failure and return an error // (in the happy path, rdpc remains open until Wait() completes) - s.onSessionStart(ctx, sw, &identity, sessionStartTime, windowsUser, string(sessionID), desktop, err) + startEvent := audit.makeSessionStart(err) + s.emit(ctx, sw, startEvent) return trace.Wrap(err) } - s.onSessionStart(ctx, sw, &identity, sessionStartTime, windowsUser, string(sessionID), desktop, nil) + startEvent := audit.makeSessionStart(nil) + s.emit(ctx, sw, startEvent) + err = rdpc.Run(ctx) - s.onSessionEnd(ctx, sw, &identity, sessionStartTime, recordSession, windowsUser, string(sessionID), desktop) + + // ctx may have been canceled, so emit with a separate context + endEvent := audit.makeSessionEnd(recordSession) + s.emit(context.Background(), sw, endEvent) return trace.Wrap(err) } -func (s *WindowsService) makeTDPSendHandler(ctx context.Context, emitter events.Emitter, delay func() int64, - id *tlsca.Identity, sessionID, desktopAddr string, tdpConn *tdp.Conn) func(m tdp.Message, b []byte) { +func (s *WindowsService) makeTDPSendHandler( + ctx context.Context, + emitter events.Emitter, + delay func() int64, + tdpConn *tdp.Conn, + audit *desktopSessionAuditor, +) func(m tdp.Message, b []byte) { return func(m tdp.Message, b []byte) { switch b[0] { case byte(tdp.TypePNG2Frame), byte(tdp.TypePNGFrame), byte(tdp.TypeError), byte(tdp.TypeNotification): @@ -977,26 +1000,48 @@ func (s *WindowsService) makeTDPSendHandler(ctx context.Context, emitter events. // the TDP send handler emits a clipboard receive event, because we // received clipboard data from the remote desktop and are sending // it on the TDP connection - s.onClipboardReceive(ctx, emitter, id, sessionID, desktopAddr, int32(len(clip))) + rxEvent := audit.makeClipboardReceive(int32(len(clip))) + s.emit(ctx, emitter, rxEvent) } case byte(tdp.TypeSharedDirectoryAcknowledge): if message, ok := m.(tdp.SharedDirectoryAcknowledge); ok { - s.onSharedDirectoryAcknowledge(ctx, emitter, id, sessionID, desktopAddr, message) + s.emit(ctx, emitter, audit.makeSharedDirectoryStart(message)) } case byte(tdp.TypeSharedDirectoryReadRequest): if message, ok := m.(tdp.SharedDirectoryReadRequest); ok { - s.onSharedDirectoryReadRequest(ctx, emitter, id, sessionID, desktopAddr, message, tdpConn) + errorEvent := audit.onSharedDirectoryReadRequest(message) + if errorEvent != nil { + // if we can't audit due to a full cache, abort the connection + // as a security measure + if err := tdpConn.Close(); err != nil { + s.cfg.Log.WithError(err).Errorf("error when terminating sessionID(%v) for audit cache maximum size violation", audit.sessionID) + } + s.emit(ctx, emitter, errorEvent) + } } case byte(tdp.TypeSharedDirectoryWriteRequest): if message, ok := m.(tdp.SharedDirectoryWriteRequest); ok { - s.onSharedDirectoryWriteRequest(ctx, emitter, id, sessionID, desktopAddr, message, tdpConn) + errorEvent := audit.onSharedDirectoryWriteRequest(message) + if errorEvent != nil { + // if we can't audit due to a full cache, abort the connection + // as a security measure + if err := tdpConn.Close(); err != nil { + s.cfg.Log.WithError(err).Errorf("error when terminating sessionID(%v) for audit cache maximum size violation", audit.sessionID) + } + s.emit(ctx, emitter, errorEvent) + } } } } } -func (s *WindowsService) makeTDPReceiveHandler(ctx context.Context, emitter events.Emitter, delay func() int64, - id *tlsca.Identity, sessionID, desktopAddr string, tdpConn *tdp.Conn) func(m tdp.Message) { +func (s *WindowsService) makeTDPReceiveHandler( + ctx context.Context, + emitter events.Emitter, + delay func() int64, + tdpConn *tdp.Conn, + audit *desktopSessionAuditor, +) func(m tdp.Message) { return func(m tdp.Message) { switch msg := m.(type) { case tdp.ClientScreenSpec, tdp.MouseButton, tdp.MouseMove: @@ -1023,13 +1068,22 @@ func (s *WindowsService) makeTDPReceiveHandler(ctx context.Context, emitter even // the TDP receive handler emits a clipboard send event, because we // received clipboard data from the user (over TDP) and are sending // it to the remote desktop - s.onClipboardSend(ctx, emitter, id, sessionID, desktopAddr, int32(len(msg))) + sendEvent := audit.makeClipboardSend(int32(len(msg))) + s.emit(ctx, emitter, sendEvent) case tdp.SharedDirectoryAnnounce: - s.onSharedDirectoryAnnounce(ctx, emitter, id, sessionID, desktopAddr, m.(tdp.SharedDirectoryAnnounce), tdpConn) + errorEvent := audit.onSharedDirectoryAnnounce(m.(tdp.SharedDirectoryAnnounce)) + if errorEvent != nil { + // if we can't audit due to a full cache, abort the connection + // as a security measure + if err := tdpConn.Close(); err != nil { + s.cfg.Log.WithError(err).Errorf("error when terminating sessionID(%v) for audit cache maximum size violation", audit.sessionID) + } + s.emit(ctx, emitter, errorEvent) + } case tdp.SharedDirectoryReadResponse: - s.onSharedDirectoryReadResponse(ctx, emitter, id, sessionID, desktopAddr, msg) + s.emit(ctx, emitter, audit.makeSharedDirectoryReadResponse(msg)) case tdp.SharedDirectoryWriteResponse: - s.onSharedDirectoryWriteResponse(ctx, emitter, id, sessionID, desktopAddr, msg) + s.emit(ctx, emitter, audit.makeSharedDirectoryWriteResponse(msg)) } } } diff --git a/lib/srv/desktop/windows_server_test.go b/lib/srv/desktop/windows_server_test.go index 10d63ed287128..f8723627ae821 100644 --- a/lib/srv/desktop/windows_server_test.go +++ b/lib/srv/desktop/windows_server_test.go @@ -199,8 +199,7 @@ func TestEmitsRecordingEventsOnSend(t *testing.T) { encoded := []byte{byte(tdp.TypePNGFrame), 0x01, 0x02} delay := func() int64 { return 0 } - handler := s.makeTDPSendHandler(context.Background(), emitter, delay, - nil, "session-1", "windows.example.com", &tdp.Conn{}) + handler := s.makeTDPSendHandler(context.Background(), emitter, delay, nil, nil) // the handler accepts both the message structure and its encoded form, // but our logic only depends on the encoded form, so pass a nil message @@ -230,8 +229,7 @@ func TestSkipsExtremelyLargePNGs(t *testing.T) { maliciousPNG[0] = byte(tdp.TypePNGFrame) delay := func() int64 { return 0 } - handler := s.makeTDPSendHandler(context.Background(), emitter, delay, - nil, "session-1", "windows.example.com", &tdp.Conn{}) + handler := s.makeTDPSendHandler(context.Background(), emitter, delay, nil, nil) // the handler accepts both the message structure and its encoded form, // but our logic only depends on the encoded form, so pass a nil message @@ -251,8 +249,7 @@ func TestEmitsRecordingEventsOnReceive(t *testing.T) { emitter := &eventstest.MockEmitter{} delay := func() int64 { return 0 } - handler := s.makeTDPReceiveHandler(context.Background(), emitter, delay, - nil, "session-1", "windows.example.com", &tdp.Conn{}) + handler := s.makeTDPReceiveHandler(context.Background(), emitter, delay, nil, nil) msg := tdp.MouseButton{ Button: tdp.LeftMouseButton, @@ -270,10 +267,22 @@ func TestEmitsRecordingEventsOnReceive(t *testing.T) { } func TestEmitsClipboardSendEvents(t *testing.T) { - s, id, emitter := setup() - handler := s.makeTDPReceiveHandler(context.Background(), - emitter, func() int64 { return 0 }, - id, "session-0", "windows.example.com", &tdp.Conn{}) + _, audit := setup(testDesktop) + emitter := &eventstest.MockEmitter{} + s := &WindowsService{ + cfg: WindowsServiceConfig{ + Clock: audit.clock, + Emitter: emitter, + }, + } + + handler := s.makeTDPReceiveHandler( + context.Background(), + emitter, + func() int64 { return 0 }, + &tdp.Conn{}, + audit, + ) fakeClipboardData := make([]byte, 1024) rand.Read(fakeClipboardData) @@ -287,17 +296,29 @@ func TestEmitsClipboardSendEvents(t *testing.T) { cs, ok := e.(*events.DesktopClipboardSend) require.True(t, ok) require.Equal(t, int32(len(fakeClipboardData)), cs.Length) - require.Equal(t, "session-0", cs.SessionID) - require.Equal(t, "windows.example.com", cs.DesktopAddr) - require.Equal(t, s.clusterName, cs.ClusterName) + require.Equal(t, audit.sessionID, cs.SessionID) + require.Equal(t, audit.desktop.GetAddr(), cs.DesktopAddr) + require.Equal(t, audit.clusterName, cs.ClusterName) require.Equal(t, start, cs.Time) } func TestEmitsClipboardReceiveEvents(t *testing.T) { - s, id, emitter := setup() - handler := s.makeTDPSendHandler(context.Background(), - emitter, func() int64 { return 0 }, - id, "session-0", "windows.example.com", &tdp.Conn{}) + _, audit := setup(testDesktop) + emitter := &eventstest.MockEmitter{} + s := &WindowsService{ + cfg: WindowsServiceConfig{ + Clock: audit.clock, + Emitter: emitter, + }, + } + + handler := s.makeTDPSendHandler( + context.Background(), + emitter, + func() int64 { return 0 }, + &tdp.Conn{}, + audit, + ) fakeClipboardData := make([]byte, 512) rand.Read(fakeClipboardData) @@ -313,133 +334,8 @@ func TestEmitsClipboardReceiveEvents(t *testing.T) { cs, ok := e.(*events.DesktopClipboardReceive) require.True(t, ok) require.Equal(t, int32(len(fakeClipboardData)), cs.Length) - require.Equal(t, "session-0", cs.SessionID) - require.Equal(t, "windows.example.com", cs.DesktopAddr) - require.Equal(t, s.clusterName, cs.ClusterName) + require.Equal(t, audit.sessionID, cs.SessionID) + require.Equal(t, audit.desktop.GetAddr(), cs.DesktopAddr) + require.Equal(t, audit.clusterName, cs.ClusterName) require.Equal(t, start, cs.Time) } - -// TestAuditCacheLifecycle confirms that the audit cache is properly -// initialized upon receipt of a tdp.SharedDirectoryAnnounce message, -// and properly cleaned up upon session end. -func TestAuditCacheLifecycle(t *testing.T) { - s, id, emitter := setup() - sid := "session-0" - desktopAddr := "windows.example.com" - testDirName := "test-dir" - path := "test/path/test-file.txt" - var did uint32 = 2 - var cid uint32 = 999 - var offset uint64 = 500 - var length uint32 = 1000 - recvHandler := s.makeTDPReceiveHandler(context.Background(), - emitter, func() int64 { return 0 }, - id, sid, desktopAddr, &tdp.Conn{}) - sendHandler := s.makeTDPSendHandler(context.Background(), - emitter, func() int64 { return 0 }, - id, sid, desktopAddr, &tdp.Conn{}) - - // SharedDirectoryAnnounce initializes the nameCache. - msg := tdp.SharedDirectoryAnnounce{ - DirectoryID: 2, - Name: testDirName, - } - recvHandler(msg) - - // Check than an initialized audit cache entry is created - // for sessionID upon receipt of a tdp.SharedDirectoryAnnounce. - entry, ok := s.auditCache.m[sessionID(sid)] - require.True(t, ok) - require.NotNil(t, entry.nameCache) - require.NotNil(t, entry.readRequestCache) - require.NotNil(t, entry.writeRequestCache) - - // Confirm that audit cache entry for sid - // is in the expected state. - require.Equal(t, 1, entry.totalItems()) - name, ok := s.auditCache.GetName(sessionID(sid), directoryID(did)) - require.True(t, ok) - require.Equal(t, directoryName(testDirName), name) - _, ok = s.auditCache.TakeReadRequestInfo(sessionID(sid), completionID(cid)) - require.False(t, ok) - _, ok = s.auditCache.TakeWriteRequestInfo(sessionID(sid), completionID(cid)) - require.False(t, ok) - - // A SharedDirectoryReadRequest should add a corresponding entry in the readRequestCache. - readReq := tdp.SharedDirectoryReadRequest{ - CompletionID: cid, - DirectoryID: did, - Path: path, - Offset: offset, - Length: length, - } - encoded, err := readReq.Encode() - require.NoError(t, err) - sendHandler(readReq, encoded) - require.Equal(t, 2, entry.totalItems()) - - // A SharedDirectoryWriteRequest should add a corresponding entry in the writeRequestCache. - writeReq := tdp.SharedDirectoryWriteRequest{ - CompletionID: cid, - DirectoryID: did, - Path: path, - Offset: offset, - WriteDataLength: length, - } - encoded, err = writeReq.Encode() - require.NoError(t, err) - sendHandler(writeReq, encoded) - require.Equal(t, 3, entry.totalItems()) - - // Check that the readRequestCache was properly filled out. - require.Contains(t, entry.readRequestCache, completionID(cid)) - - // Check that the writeRequestCache was properly filled out. - require.Contains(t, entry.writeRequestCache, completionID(cid)) - - // SharedDirectoryReadResponse should cause the entry in the readRequestCache to be cleaned up. - readRes := tdp.SharedDirectoryReadResponse{ - CompletionID: cid, - ErrCode: tdp.ErrCodeNil, - ReadDataLength: length, - ReadData: []byte{}, // irrelevant in this context - } - recvHandler(readRes) - require.Equal(t, 2, entry.totalItems()) - - // SharedDirectoryWriteResponse should cause the entry in the writeRequestCache to be cleaned up. - writeRes := tdp.SharedDirectoryWriteResponse{ - CompletionID: cid, - ErrCode: tdp.ErrCodeNil, - BytesWritten: length, - } - recvHandler(writeRes) - require.Equal(t, 1, entry.totalItems()) - - // Check that the readRequestCache was properly cleaned up. - require.NotContains(t, entry.readRequestCache, completionID(cid)) - - // Check that the writeRequestCache was properly cleaned up. - require.NotContains(t, entry.writeRequestCache, completionID(cid)) - - // Simulate a session end event, which should clean up the cache for sessionID(sid) entirely. - s.onSessionEnd( - context.Background(), - s.cfg.Emitter, - id, - s.cfg.Clock.Now().UTC().Round(time.Millisecond), - true, - "Administrator", - sid, - &types.WindowsDesktopV3{}, - ) - - // Confirm that the audit cache at sessionID(sid) was cleaned up. - _, ok = s.auditCache.GetName(sessionID(sid), directoryID(did)) - require.False(t, ok) - _, ok = s.auditCache.TakeReadRequestInfo(sessionID(sid), completionID(cid)) - require.False(t, ok) - _, ok = s.auditCache.TakeWriteRequestInfo(sessionID(sid), completionID(cid)) - require.False(t, ok) - require.NotContains(t, s.auditCache.m, sessionID(sid)) -}