diff --git a/lib/auth/auth.go b/lib/auth/auth.go index a388122a2b61b..183aea8fa61c1 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -95,6 +95,7 @@ import ( "github.com/gravitational/teleport/lib/auth/machineid/workloadidentityv1" "github.com/gravitational/teleport/lib/auth/okta" "github.com/gravitational/teleport/lib/auth/recordingencryption" + "github.com/gravitational/teleport/lib/auth/recordingmetadata" "github.com/gravitational/teleport/lib/auth/summarizer" "github.com/gravitational/teleport/lib/auth/userloginstate" wanlib "github.com/gravitational/teleport/lib/auth/webauthn" @@ -498,6 +499,9 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (as *Server, err error) { if cfg.SessionSummarizerProvider == nil { cfg.SessionSummarizerProvider = summarizer.NewSessionSummarizerProvider() } + if cfg.RecordingMetadataProvider == nil { + cfg.RecordingMetadataProvider = recordingmetadata.NewProvider() + } if cfg.WorkloadIdentityX509Revocations == nil { cfg.WorkloadIdentityX509Revocations, err = local.NewWorkloadIdentityX509RevocationService(cfg.Backend) if err != nil { @@ -662,6 +666,7 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (as *Server, err error) { accessMonitoringEnabled: cfg.AccessMonitoringEnabled, logger: cfg.Logger, sessionSummarizerProvider: cfg.SessionSummarizerProvider, + recordingMetadataProvider: cfg.RecordingMetadataProvider, } as.inventory = inventory.NewController(as, services, inventory.WithAuthServerID(cfg.HostUUID), @@ -1343,9 +1348,16 @@ type Server struct { // plugin. The summarizer itself summarizes session recordings. sessionSummarizerProvider *summarizer.SessionSummarizerProvider + // recordingMetadataProvider provides recording metadata for session recordings. + recordingMetadataProvider *recordingmetadata.Provider + // BotInstanceVersionReporter is called periodically to generate a report of // the number of bot instances by version and update group. BotInstanceVersionReporter *machineidv1.AutoUpdateVersionReporter + + // EncryptedIO provides encryption for session related data such as + // recordings, thumbnails, and metadata. + EncryptedIO *recordingencryption.EncryptedIO } // SetSAMLService registers svc as the SAMLService that provides the SAML diff --git a/lib/auth/authtest/authtest.go b/lib/auth/authtest/authtest.go index a7524daf00b92..5e07a5f072f74 100644 --- a/lib/auth/authtest/authtest.go +++ b/lib/auth/authtest/authtest.go @@ -48,6 +48,7 @@ import ( "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/accesspoint" "github.com/gravitational/teleport/lib/auth/authclient" + "github.com/gravitational/teleport/lib/auth/recordingmetadata" "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/auth/summarizer" authority "github.com/gravitational/teleport/lib/auth/testauthority" @@ -96,6 +97,8 @@ type AuthServerConfig struct { // SessionSummarizerProvider allows a test to configure its own session // summarizer provider. SessionSummarizerProvider *summarizer.SessionSummarizerProvider + // RecordingMetadataProvider allows a test to configure its own recording + RecordingMetadataProvider *recordingmetadata.Provider // TraceClient allows a test to configure the trace client TraceClient otlptrace.Client // AuthPreferenceSpec is custom initial AuthPreference spec for the test. @@ -338,6 +341,7 @@ func NewAuthServer(cfg AuthServerConfig) (*AuthServer, error) { KeyStoreConfig: cfg.KeystoreConfig, MultipartHandler: cfg.UploadHandler, SessionSummarizerProvider: cfg.SessionSummarizerProvider, + RecordingMetadataProvider: cfg.RecordingMetadataProvider, }, WithClock(cfg.Clock), // Reduce auth.Server bcrypt costs when testing. diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index 0f3ce31b677f5..89b8950d0990a 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -5170,7 +5170,6 @@ func (g *GRPCServer) rangeDefaultInstallers(ctx context.Context, start, end stri return } } - } } @@ -5190,7 +5189,6 @@ func (g *GRPCServer) ListInstallers(ctx context.Context, req *authpb.ListInstall int(req.PageSize), types.Installer.GetName, ) - if err != nil { return nil, trace.Wrap(err) } @@ -6201,10 +6199,13 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) { userloginstatev1pb.RegisterUserLoginStateServiceServer(server, userLoginStateServer) recordingEncryptionService, err := recordingencryptionv1.NewService(recordingencryptionv1.ServiceConfig{ - Authorizer: cfg.Authorizer, - Uploader: cfg.AuthServer.Services, - KeyRotater: cfg.AuthServer.Services, - Logger: cfg.AuthServer.logger.With(teleport.ComponentKey, teleport.ComponentRecordingEncryption), + Authorizer: cfg.Authorizer, + Uploader: cfg.AuthServer.Services, + KeyRotater: cfg.AuthServer.Services, + Logger: cfg.AuthServer.logger.With(teleport.ComponentKey, teleport.ComponentRecordingEncryption), + SessionSummarizerProvider: cfg.APIConfig.AuthServer.sessionSummarizerProvider, + RecordingMetadataProvider: cfg.AuthServer.recordingMetadataProvider, + SessionStreamer: cfg.AuthServer, }) if err != nil { return nil, trace.Wrap(err) @@ -6335,6 +6336,7 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) { ), Streamer: cfg.AuthServer, DownloadHandler: cfg.AuthServer, + Decrypter: cfg.AuthServer.EncryptedIO, }) if err != nil { return nil, trace.Wrap(err, "creating recording metadata service") diff --git a/lib/auth/init.go b/lib/auth/init.go index 401df3e49ef42..095ddf4624c5b 100644 --- a/lib/auth/init.go +++ b/lib/auth/init.go @@ -59,6 +59,7 @@ import ( "github.com/gravitational/teleport/lib/auth/migration" "github.com/gravitational/teleport/lib/auth/recordingencryption" "github.com/gravitational/teleport/lib/auth/recordingencryption/recordingencryptionv1" + "github.com/gravitational/teleport/lib/auth/recordingmetadata" "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/auth/summarizer" "github.com/gravitational/teleport/lib/backend" @@ -403,6 +404,9 @@ type InitConfig struct { // plugin. The summarizer itself summarizes session recordings. SessionSummarizerProvider *summarizer.SessionSummarizerProvider + // RecordingMetadataProvider provides recording metadata for session recordings. + RecordingMetadataProvider *recordingmetadata.Provider + // RunWhileLockedRetryInterval defines the interval at which the auth server retries // a locking operation for backend objects. // This setting is particularly useful in test environments, diff --git a/lib/auth/recordingencryption/buf_decrypter.go b/lib/auth/recordingencryption/buf_decrypter.go new file mode 100644 index 0000000000000..7cbb6bf568501 --- /dev/null +++ b/lib/auth/recordingencryption/buf_decrypter.go @@ -0,0 +1,70 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package recordingencryption + +import ( + "bytes" + "context" + "io" + + "github.com/gravitational/trace" +) + +// ageEncryptionPrefix is the prefix used to identify age-encrypted data. +// age always uses "age-encryption.org/v1" as the prefix for its encrypted files. +const ageEncryptionPrefix = "age-encryption.org" + +var ageEncryptionPrefixBytes = []byte(ageEncryptionPrefix) + +// Decrypter wraps an io.Reader with decryption if the data is age-encrypted. +type Decrypter interface { + WithDecryption(ctx context.Context, reader io.Reader) (io.Reader, error) +} + +// DecryptBufferIfEncrypted checks whether the provided buffer contains +// age-encrypted data and decrypts it if necessary. +// +// The function looks for the standard age encryption header prefix to determine +// whether the data is encrypted. If the buffer is not age-encrypted, it returns +// the original data unchanged. +// +// If the buffer is encrypted, a Decrypter must be provided. The function uses +// the Decrypter to read and decrypt the buffer, returning the plaintext bytes. +// If no Decrypter is configured when encrypted data is detected, an error is +// returned. +func DecryptBufferIfEncrypted(ctx context.Context, buf []byte, decrypter Decrypter) ([]byte, error) { + if !bytes.HasPrefix(buf, ageEncryptionPrefixBytes) { + return buf, nil + } + + if decrypter == nil { + return nil, trace.BadParameter("recording metadata decrypter is not configured") + } + + decryptedReader, err := decrypter.WithDecryption(ctx, bytes.NewReader(buf)) + if err != nil { + return nil, trace.Wrap(err, "decrypting recording metadata") + } + + decryptedBuf := bytes.NewBuffer(make([]byte, 0, len(buf))) + _, err = io.Copy(decryptedBuf, decryptedReader) + if err != nil { + return nil, trace.Wrap(err, "reading decrypted recording metadata") + } + + return decryptedBuf.Bytes(), nil +} diff --git a/lib/auth/recordingencryption/encryptedio_test.go b/lib/auth/recordingencryption/encryptedio_test.go index 3ace3a06b1868..4498ef7f625b8 100644 --- a/lib/auth/recordingencryption/encryptedio_test.go +++ b/lib/auth/recordingencryption/encryptedio_test.go @@ -57,6 +57,8 @@ func TestEncryptedIO(t *testing.T) { err = writer.Close() require.NoError(t, err) + encryptedBytes := out.Bytes() + reader, err := encryptedIO.WithDecryption(ctx, out) require.NoError(t, err) @@ -65,6 +67,11 @@ func TestEncryptedIO(t *testing.T) { require.Equal(t, msg, plaintext) + // Test buffer decryption directly + decryptedBuf, err := recordingencryption.DecryptBufferIfEncrypted(ctx, encryptedBytes, encryptedIO) + require.NoError(t, err) + require.Equal(t, msg, decryptedBuf) + // creating an EncryptedIO without a SessionRecordingConfigGetter or keyfinder should be an error _, err = recordingencryption.NewEncryptedIO(nil, nil) require.Error(t, err) @@ -79,6 +86,12 @@ func TestEncryptedIO(t *testing.T) { _, err = encryptedIO.WithEncryption(ctx, &writeCloser{Writer: out}) require.ErrorIs(t, err, recordingencryption.ErrEncryptionDisabled) + + // Decrypting an unencrypted buffer should return the original buffer + origBuf := []byte("this is not encrypted") + decryptedBuf, err = recordingencryption.DecryptBufferIfEncrypted(ctx, origBuf, encryptedIO) + require.NoError(t, err) + require.Equal(t, origBuf, decryptedBuf) } type fakeSRCGetter struct { diff --git a/lib/auth/recordingencryption/recordingencryptionv1/service.go b/lib/auth/recordingencryption/recordingencryptionv1/service.go index 695857ae5328e..6a24d3d0a71b4 100644 --- a/lib/auth/recordingencryption/recordingencryptionv1/service.go +++ b/lib/auth/recordingencryption/recordingencryptionv1/service.go @@ -27,8 +27,11 @@ import ( "github.com/gravitational/teleport" recordingencryptionv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/recordingencryption/v1" + "github.com/gravitational/teleport/lib/auth/recordingmetadata" + "github.com/gravitational/teleport/lib/auth/summarizer" "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/events" + sessionpostprocessing "github.com/gravitational/teleport/lib/events/sessionpostprocessing" "github.com/gravitational/teleport/lib/session" ) @@ -46,6 +49,14 @@ type ServiceConfig struct { Logger *slog.Logger Uploader events.MultipartUploader KeyRotater KeyRotater + // SessionSummarizerProvider is a provider of the session summarizer service. + // It can be nil or provide a nil summarizer if summarization is not needed. + // The summarizer itself summarizes session recordings. + SessionSummarizerProvider *summarizer.SessionSummarizerProvider + // RecordingMetadataProvider is a provider of the recording metadata service. + RecordingMetadataProvider *recordingmetadata.Provider + // SessionStreamer is a streamer for session events. + SessionStreamer events.SessionStreamer } // NewService returns a new [Service] based on the given [ServiceConfig]. @@ -57,6 +68,12 @@ func NewService(cfg ServiceConfig) (*Service, error) { return nil, trace.BadParameter("uploader is required") case cfg.KeyRotater == nil: return nil, trace.BadParameter("key rotater is required") + case cfg.SessionStreamer == nil: + return nil, trace.BadParameter("session streamer is required") + case cfg.RecordingMetadataProvider == nil: + return nil, trace.BadParameter("recording metadata provider is required") + case cfg.SessionSummarizerProvider == nil: + return nil, trace.BadParameter("session summarizer provider is required") } if cfg.Logger == nil { @@ -64,10 +81,13 @@ func NewService(cfg ServiceConfig) (*Service, error) { } return &Service{ - logger: cfg.Logger, - uploader: cfg.Uploader, - auth: cfg.Authorizer, - rotater: cfg.KeyRotater, + logger: cfg.Logger, + uploader: cfg.Uploader, + auth: cfg.Authorizer, + rotater: cfg.KeyRotater, + sessionSummarizerProvider: cfg.SessionSummarizerProvider, + recordingMetadataProvider: cfg.RecordingMetadataProvider, + streamer: cfg.SessionStreamer, }, nil } @@ -79,6 +99,13 @@ type Service struct { logger *slog.Logger uploader events.MultipartUploader rotater KeyRotater + // SessionSummarizerProvider is a provider of the session summarizer service. + // It can be nil or provide a nil summarizer if summarization is not needed. + // The summarizer itself summarizes session recordings. + sessionSummarizerProvider *summarizer.SessionSummarizerProvider + // RecordingMetadataProvider is a provider of the recording metadata service. + recordingMetadataProvider *recordingmetadata.Provider + streamer events.SessionStreamer } func streamUploadAsProto(upload events.StreamUpload) *recordingencryptionv1.Upload { @@ -201,6 +228,23 @@ func (s *Service) CompleteUpload(ctx context.Context, req *recordingencryptionv1 return nil, trace.Wrap(err) } + sessionEnd, err := events.FindSessionEndEvent(ctx, s.streamer, upload.SessionID) + if err != nil || sessionEnd == nil { + return &recordingencryptionv1.CompleteUploadResponse{}, nil + } + + if err := sessionpostprocessing.Process( + ctx, + sessionpostprocessing.Config{ + SessionEnd: sessionEnd, + SessionID: upload.SessionID, + SessionSummarizerProvider: s.sessionSummarizerProvider, + RecordingMetadataProvider: s.recordingMetadataProvider, + }, + ); err != nil { + s.logger.WarnContext(ctx, "session post-processing failed", "error", err) + } + return &recordingencryptionv1.CompleteUploadResponse{}, nil } diff --git a/lib/auth/recordingencryption/recordingencryptionv1/service_test.go b/lib/auth/recordingencryption/recordingencryptionv1/service_test.go index e233fb40f0cdf..5f248913ff169 100644 --- a/lib/auth/recordingencryption/recordingencryptionv1/service_test.go +++ b/lib/auth/recordingencryption/recordingencryptionv1/service_test.go @@ -21,15 +21,24 @@ import ( "errors" "fmt" "testing" + "time" "github.com/google/uuid" "github.com/gravitational/trace" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" recordingencryptionv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/recordingencryption/v1" + "github.com/gravitational/teleport/api/types" + apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/auth/recordingencryption/recordingencryptionv1" + "github.com/gravitational/teleport/lib/auth/recordingmetadata" + "github.com/gravitational/teleport/lib/auth/summarizer" "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/events/eventstest" + "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/utils/log/logtest" ) @@ -60,10 +69,13 @@ func TestRotateKey(t *testing.T) { ctx := withAuthCtx(t.Context(), c.ctx) rotater := newFakeKeyRotater() cfg := recordingencryptionv1.ServiceConfig{ - Authorizer: &fakeAuthorizer{}, - Logger: logtest.NewLogger(), - Uploader: fakeUploader{}, - KeyRotater: rotater, + Authorizer: &fakeAuthorizer{}, + Logger: logtest.NewLogger(), + Uploader: fakeUploader{}, + KeyRotater: rotater, + SessionStreamer: &fakeSessionStreamer{}, + RecordingMetadataProvider: recordingmetadata.NewProvider(), + SessionSummarizerProvider: summarizer.NewSessionSummarizerProvider(), } service, err := recordingencryptionv1.NewService(cfg) @@ -78,7 +90,6 @@ func TestRotateKey(t *testing.T) { require.NoError(t, err) require.Len(t, rotater.keys, 2) } - }) } } @@ -105,10 +116,13 @@ func TestCompleteRotation(t *testing.T) { ctx := withAuthCtx(t.Context(), c.ctx) rotater := newFakeKeyRotater() cfg := recordingencryptionv1.ServiceConfig{ - Authorizer: &fakeAuthorizer{}, - Logger: logtest.NewLogger(), - Uploader: fakeUploader{}, - KeyRotater: rotater, + Authorizer: &fakeAuthorizer{}, + Logger: logtest.NewLogger(), + Uploader: fakeUploader{}, + KeyRotater: rotater, + SessionStreamer: &fakeSessionStreamer{}, + RecordingMetadataProvider: recordingmetadata.NewProvider(), + SessionSummarizerProvider: summarizer.NewSessionSummarizerProvider(), } service, err := recordingencryptionv1.NewService(cfg) @@ -152,10 +166,13 @@ func TestRollbackRotation(t *testing.T) { ctx := withAuthCtx(t.Context(), c.ctx) rotater := newFakeKeyRotater() cfg := recordingencryptionv1.ServiceConfig{ - Authorizer: &fakeAuthorizer{}, - Logger: logtest.NewLogger(), - Uploader: fakeUploader{}, - KeyRotater: rotater, + Authorizer: &fakeAuthorizer{}, + Logger: logtest.NewLogger(), + Uploader: fakeUploader{}, + KeyRotater: rotater, + SessionStreamer: &fakeSessionStreamer{}, + RecordingMetadataProvider: recordingmetadata.NewProvider(), + SessionSummarizerProvider: summarizer.NewSessionSummarizerProvider(), } service, err := recordingencryptionv1.NewService(cfg) @@ -198,10 +215,13 @@ func TestGetRotationState(t *testing.T) { ctx := withAuthCtx(t.Context(), c.ctx) rotater := newFakeKeyRotater() cfg := recordingencryptionv1.ServiceConfig{ - Authorizer: &fakeAuthorizer{}, - Logger: logtest.NewLogger(), - Uploader: fakeUploader{}, - KeyRotater: rotater, + Authorizer: &fakeAuthorizer{}, + Logger: logtest.NewLogger(), + Uploader: fakeUploader{}, + KeyRotater: rotater, + SessionStreamer: &fakeSessionStreamer{}, + RecordingMetadataProvider: recordingmetadata.NewProvider(), + SessionSummarizerProvider: summarizer.NewSessionSummarizerProvider(), } service, err := recordingencryptionv1.NewService(cfg) @@ -229,6 +249,10 @@ type fakeUploader struct { events.MultipartUploader } +func (f fakeUploader) CompleteUpload(ctx context.Context, upload events.StreamUpload, parts []events.StreamPart) error { + return nil +} + type fakeAuthorizer struct{} func (f *fakeAuthorizer) Authorize(ctx context.Context) (*authz.Context, error) { @@ -300,3 +324,100 @@ func (f *fakeKeyRotater) RollbackRotation(ctx context.Context) error { func (f *fakeKeyRotater) GetRotationState(ctx context.Context) ([]*recordingencryptionv1pb.FingerprintWithState, error) { return f.keys, nil } + +type fakeSessionStreamer struct{} + +func (f *fakeSessionStreamer) StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) { + returnChan := make(chan apievents.AuditEvent, 1) + errChan := make(chan error, 1) + close(errChan) + events := eventstest.GenerateTestSession(eventstest.SessionParams{ + UserName: "alice", + SessionID: string(sessionID), + ServerID: "testcluster", + PrintData: []string{"net", "stat"}, + }) + returnChan <- events[len(events)-1] + return returnChan, nil +} + +func TestSessionCompleter(t *testing.T) { + sessionID := session.ID(uuid.NewString()) + + metadataProvider := recordingmetadata.NewProvider() + recorderMetadata := &fakeRecordingMetadata{} + recorderMetadata.On("ProcessSessionRecording", mock.Anything, sessionID, mock.Anything). + Return(nil).Once() + metadataProvider.SetService(recorderMetadata) + + summarizerProvider := summarizer.NewSessionSummarizerProvider() + sessionSummarizer := &fakeSummarizer{} + sessionSummarizer.On("SummarizeSSH", mock.Anything, mock.Anything). + Return(nil).Once() + + summarizerProvider.SetSummarizer(sessionSummarizer) + cfg := recordingencryptionv1.ServiceConfig{ + Authorizer: &fakeAuthorizer{}, + Logger: logtest.NewLogger(), + Uploader: fakeUploader{}, + KeyRotater: newFakeKeyRotater(), + SessionStreamer: &fakeSessionStreamer{}, + RecordingMetadataProvider: metadataProvider, + SessionSummarizerProvider: summarizerProvider, + } + + service, err := recordingencryptionv1.NewService(cfg) + require.NoError(t, err) + + ctx := withAuthCtx(t.Context(), newServiceAuthCtx()) + _, err = service.CompleteUpload(ctx, &recordingencryptionv1pb.CompleteUploadRequest{ + Upload: &recordingencryptionv1pb.Upload{ + SessionId: string(sessionID), + InitiatedAt: timestamppb.Now(), + UploadId: uuid.NewString(), + }, + }) + require.NoError(t, err) + + recorderMetadata.AssertExpectations(t) + sessionSummarizer.AssertExpectations(t) +} + +type fakeRecordingMetadata struct { + mock.Mock +} + +func (f *fakeRecordingMetadata) ProcessSessionRecording(ctx context.Context, sessionID session.ID, duration time.Duration) error { + args := f.Called(ctx, sessionID, duration) + return args.Error(0) +} + +type fakeSummarizer struct { + mock.Mock +} + +func (f *fakeSummarizer) SummarizeSSH(ctx context.Context, sessionEndEvent *apievents.SessionEnd) error { + args := f.Called(ctx, sessionEndEvent) + return args.Error(0) +} + +func (f *fakeSummarizer) SummarizeDatabase(ctx context.Context, sessionEndEvent *apievents.DatabaseSessionEnd) error { + args := f.Called(ctx, sessionEndEvent) + return args.Error(0) +} + +func (f *fakeSummarizer) SummarizeWithoutEndEvent(ctx context.Context, sessionID session.ID) error { + args := f.Called(ctx, sessionID) + return args.Error(0) +} + +func newServiceAuthCtx() authz.Context { + return authz.Context{ + Identity: authz.BuiltinRole{ + Role: types.RoleProxy, + }, + UnmappedIdentity: authz.BuiltinRole{ + Role: types.RoleProxy, + }, + } +} diff --git a/lib/auth/recordingmetadata/recordingmetadatav1/recordingmetadata.go b/lib/auth/recordingmetadata/recordingmetadatav1/recordingmetadata.go index d6f682a5acb99..cdf0d8715ba3b 100644 --- a/lib/auth/recordingmetadata/recordingmetadatav1/recordingmetadata.go +++ b/lib/auth/recordingmetadata/recordingmetadatav1/recordingmetadata.go @@ -21,6 +21,7 @@ package recordingmetadatav1 import ( "bytes" "context" + "errors" "io" "log/slog" "math" @@ -39,6 +40,7 @@ import ( "github.com/gravitational/teleport" pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/recordingmetadata/v1" apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/auth/recordingencryption" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/player" "github.com/gravitational/teleport/lib/session" @@ -65,6 +67,7 @@ type RecordingMetadataService struct { streamer player.Streamer uploadHandler UploadHandler concurrencyLimiter *semaphore.Weighted + encrypter events.EncryptionWrapper } // RecordingMetadataServiceConfig defines the configuration for the RecordingMetadataService. @@ -73,6 +76,8 @@ type RecordingMetadataServiceConfig struct { Streamer player.Streamer // UploadHandler is used to upload session metadata and thumbnails. UploadHandler UploadHandler + // Encrypter is used to encrypt session metadata and thumbnails. + Encrypter events.EncryptionWrapper } const ( @@ -100,6 +105,7 @@ func NewRecordingMetadataService(cfg RecordingMetadataServiceConfig) (*Recording uploadHandler: cfg.UploadHandler, logger: slog.With(teleport.ComponentKey, "recording_metadata"), concurrencyLimiter: semaphore.NewWeighted(concurrencyLimit), + encrypter: cfg.Encrypter, }, nil } @@ -152,7 +158,11 @@ func (s *RecordingMetadataService) ProcessSessionRecording(ctx context.Context, var finish sync.Once w, cancelUpload, uploadErrs := s.startUpload(ctx, sessionID) - + select { + case err := <-uploadErrs: + return trace.Wrap(err) + default: + } defer func() { finish.Do(func() { cancelUpload() @@ -396,7 +406,27 @@ func (s *RecordingMetadataService) startUpload(ctx context.Context, sessionID se uploadCtx, cancel := context.WithCancel(ctx) r, w := io.Pipe() errs := make(chan error, 1) - + var writer io.WriteCloser = w + if s.encrypter != nil { + // wrap the pipe writer with encryption + // WithEncryption will never close the underlying writer when the returned + // WriteCloser is closed, so we need to create a multiCloser to close both + // the encrypted writer and the pipe writer. + encrypted, err := s.encrypter.WithEncryption(uploadCtx, w) + switch { + case err == nil: + writer = &multiCloser{ + WriteCloser: encrypted, + pipeCloser: w, + } + case errors.Is(err, recordingencryption.ErrEncryptionDisabled): + // if encryption isn't enabled, do nothing + default: + cancel() + errs <- trace.Wrap(err, "starting recording encrypter") + return nil, nil, errs + } + } go func() { defer r.Close() @@ -410,7 +440,21 @@ func (s *RecordingMetadataService) startUpload(ctx context.Context, sessionID se errs <- nil }() - return w, cancel, errs + return writer, cancel, errs +} + +// multiCloser is an io.WriteCloser that closes the underlying +// WriteCloser and an additional Closer. +type multiCloser struct { + io.WriteCloser + pipeCloser io.Closer +} + +func (m *multiCloser) Close() error { + // flush the encryption writer and close the pipe + errEncryption := m.WriteCloser.Close() + errPipe := m.pipeCloser.Close() + return trace.NewAggregate(errEncryption, errPipe) } func (s *RecordingMetadataService) uploadThumbnail(ctx context.Context, sessionID session.ID, thumbnail *pb.SessionRecordingThumbnail) error { @@ -423,7 +467,28 @@ func (s *RecordingMetadataService) uploadThumbnail(ctx context.Context, sessionI return trace.Wrap(err) } - path, err := s.uploadHandler.UploadThumbnail(ctx, sessionID, bytes.NewReader(b)) + var buf io.Reader = bytes.NewReader(b) + if s.encrypter != nil { + writeBuffer := bytes.NewBuffer(nil) + encryptedWriter, err := s.encrypter.WithEncryption(ctx, &nopCloser{writeBuffer}) + switch { + case err == nil: + if _, err := io.Copy(encryptedWriter, buf); err != nil { + encryptedWriter.Close() + return trace.Wrap(err) + } + if err := encryptedWriter.Close(); err != nil { + return trace.Wrap(err) + } + buf = writeBuffer + case errors.Is(err, recordingencryption.ErrEncryptionDisabled): + // if encryption isn't enabled, do nothing + default: + return trace.Wrap(err, "starting recording encrypter") + } + } + + path, err := s.uploadHandler.UploadThumbnail(ctx, sessionID, buf) if err != nil { return trace.Wrap(err) } @@ -433,6 +498,14 @@ func (s *RecordingMetadataService) uploadThumbnail(ctx context.Context, sessionI return nil } +type nopCloser struct { + io.Writer +} + +func (n nopCloser) Close() error { + return nil +} + // getRandomThumbnailTime returns the ideal time offset for capturing a thumbnail // within the session duration based on the provided interval. // It avoids the first and last 20% of the session recording to increase the chances of diff --git a/lib/auth/recordingmetadata/recordingmetadatav1/service.go b/lib/auth/recordingmetadata/recordingmetadatav1/service.go index cee15ff2f3dda..5237da4c01610 100644 --- a/lib/auth/recordingmetadata/recordingmetadatav1/service.go +++ b/lib/auth/recordingmetadata/recordingmetadatav1/service.go @@ -32,6 +32,7 @@ import ( "github.com/gravitational/teleport" pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/recordingmetadata/v1" + "github.com/gravitational/teleport/lib/auth/recordingencryption" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/player" "github.com/gravitational/teleport/lib/session" @@ -45,6 +46,7 @@ type Service struct { authorizer Authorizer streamer player.Streamer downloadHandler DownloadHandler + decrypter events.DecryptionWrapper logger *slog.Logger } @@ -70,6 +72,8 @@ type ServiceConfig struct { Streamer player.Streamer // DownloadHandler is used to handle uploads and downloads of session recording metadata and thumbnails. DownloadHandler DownloadHandler + // Decrypter is used to decrypt session metadata and thumbnails. + Decrypter events.DecryptionWrapper } // NewService creates a new instance of the recording metadata service. @@ -86,6 +90,7 @@ func NewService(cfg ServiceConfig) (*Service, error) { streamer: cfg.Streamer, downloadHandler: cfg.DownloadHandler, logger: slog.With(teleport.ComponentKey, "recording_metadata"), + decrypter: cfg.Decrypter, }, nil } @@ -102,8 +107,13 @@ func (r *Service) GetThumbnail(ctx context.Context, req *pb.GetThumbnailRequest) return nil, trace.Wrap(err) } + payload, err := r.decryptIfNeeded(ctx, buf.Bytes()) + if err != nil { + return nil, trace.Wrap(err, "decrypting session recording thumbnail") + } + thumbnail := &pb.SessionRecordingThumbnail{} - err = proto.Unmarshal(buf.Bytes(), thumbnail) + err = proto.Unmarshal(payload, thumbnail) if err != nil { return nil, trace.Wrap(err) } @@ -125,7 +135,11 @@ func (r *Service) GetMetadata(req *pb.GetMetadataRequest, stream grpc.ServerStre return trace.Wrap(err) } - reader := bufio.NewReader(bytes.NewReader(buf.Bytes())) + payload, err := r.decryptIfNeeded(stream.Context(), buf.Bytes()) + if err != nil { + return trace.Wrap(err, "decrypting session recording thumbnail") + } + reader := bufio.NewReader(bytes.NewReader(payload)) for { msgBytes, err := readDelimitedMessage(reader) @@ -218,3 +232,9 @@ func readDelimitedMessage(r *bufio.Reader) ([]byte, error) { return msgBytes, nil } + +// decryptIfNeeded decrypts the data if it is encrypted. If the data is not encrypted, it is returned as-is. +func (r *Service) decryptIfNeeded(ctx context.Context, data []byte) ([]byte, error) { + unencrypted, err := recordingencryption.DecryptBufferIfEncrypted(ctx, data, r.decrypter) + return unencrypted, trace.Wrap(err) +} diff --git a/lib/events/auditlog.go b/lib/events/auditlog.go index 2f19850ee4953..c04a82eccce1d 100644 --- a/lib/events/auditlog.go +++ b/lib/events/auditlog.go @@ -40,7 +40,10 @@ import ( auditlogpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/auditlog/v1" "github.com/gravitational/teleport/api/internalutils/stream" apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/auth/recordingmetadata" + "github.com/gravitational/teleport/lib/auth/summarizer" "github.com/gravitational/teleport/lib/defaults" + sessionpostprocessing "github.com/gravitational/teleport/lib/events/sessionpostprocessing" "github.com/gravitational/teleport/lib/observability/metrics" "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/utils" @@ -255,6 +258,11 @@ type AuditLogConfig struct { // Decrypter wraps session replay with decryption Decrypter DecryptionWrapper + + // SessionSummarizerProvider provides session summarizers + SessionSummarizerProvider *summarizer.SessionSummarizerProvider + // RecordingMetadataProvider provides recording metadata service + RecordingMetadataProvider *recordingmetadata.Provider } // CheckAndSetDefaults checks and sets defaults @@ -281,6 +289,12 @@ func (a *AuditLogConfig) CheckAndSetDefaults() error { mask := os.FileMode(teleport.DirMaskSharedGroup) a.DirMask = &mask } + if a.SessionSummarizerProvider == nil { + a.SessionSummarizerProvider = summarizer.NewSessionSummarizerProvider() + } + if a.RecordingMetadataProvider == nil { + a.RecordingMetadataProvider = recordingmetadata.NewProvider() + } if (a.GID != nil && a.UID == nil) || (a.UID != nil && a.GID == nil) { return trace.BadParameter("if UID or GID is set, both should be specified") } @@ -674,8 +688,27 @@ func (l *AuditLog) UploadEncryptedRecording(ctx context.Context, sessionID strin part = nextPart partNumber++ } + err = l.UploadHandler.CompleteUpload(ctx, *upload, streamParts) + if err != nil { + return trace.Wrap(err, "completing upload") + } - return trace.Wrap(l.UploadHandler.CompleteUpload(ctx, *upload, streamParts), "completing upload") + sessionEnd, err := FindSessionEndEvent(ctx, l, session.ID(sessionID)) + if err != nil || sessionEnd == nil { + return nil + } + if err := sessionpostprocessing.Process( + ctx, + sessionpostprocessing.Config{ + SessionEnd: sessionEnd, + SessionID: upload.SessionID, + SessionSummarizerProvider: l.SessionSummarizerProvider, + RecordingMetadataProvider: l.RecordingMetadataProvider, + }, + ); err != nil { + l.log.WarnContext(ctx, "session post-processing failed", "error", err) + } + return nil } // getLocalLog returns the local (file based) AuditLogger. diff --git a/lib/events/auditlog_test.go b/lib/events/auditlog_test.go index a1e2d4c65daa7..d5a97c73dbb23 100644 --- a/lib/events/auditlog_test.go +++ b/lib/events/auditlog_test.go @@ -32,10 +32,13 @@ import ( "github.com/google/uuid" "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/auth/recordingmetadata" + "github.com/gravitational/teleport/lib/auth/summarizer" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/eventstest" "github.com/gravitational/teleport/lib/session" @@ -421,3 +424,113 @@ func TestPadUploadPart(t *testing.T) { _, err = r.Read(nil) require.ErrorIs(t, err, io.EOF) } + +func TestCallingSummarizerMetadata(t *testing.T) { + ctx := t.Context() + + parts := generateParts(t) + sessionID, err := uuid.NewV7() + require.NoError(t, err) + metadataProvider := recordingmetadata.NewProvider() + recorderMetadata := &fakeRecordingMetadata{} + recorderMetadata.On("ProcessSessionRecording", mock.Anything, session.ID(sessionID.String()), mock.Anything). + Return(nil).Once() + metadataProvider.SetService(recorderMetadata) + + summarizerProvider := summarizer.NewSessionSummarizerProvider() + sessionSummarizer := &fakeSummarizer{} + sessionSummarizer.On("SummarizeSSH", mock.Anything, mock.Anything). + Return(nil).Once() + summarizerProvider.SetSummarizer(sessionSummarizer) + + uploader := eventstest.NewMemoryUploader() + alog, err := events.NewAuditLog(events.AuditLogConfig{ + DataDir: t.TempDir(), + ServerID: "server1", + UploadHandler: uploader, + SessionSummarizerProvider: summarizerProvider, + RecordingMetadataProvider: metadataProvider, + }) + require.NoError(t, err) + defer alog.Close() + + partIter := func(yield func([]byte, error) bool) { + for _, part := range parts { + if part == nil { + if !yield(nil, errors.New("invalid part")) { + return + } + } else { + if !yield(part, nil) { + return + } + } + } + } + + err = alog.UploadEncryptedRecording(ctx, sessionID.String(), partIter) + require.NoError(t, err) + + recorderMetadata.AssertExpectations(t) + sessionSummarizer.AssertExpectations(t) +} + +type fakeRecordingMetadata struct { + mock.Mock +} + +func (f *fakeRecordingMetadata) ProcessSessionRecording(ctx context.Context, sessionID session.ID, duration time.Duration) error { + args := f.Called(ctx, sessionID, duration) + return args.Error(0) +} + +type fakeSummarizer struct { + mock.Mock +} + +func (f *fakeSummarizer) SummarizeSSH(ctx context.Context, sessionEndEvent *apievents.SessionEnd) error { + args := f.Called(ctx, sessionEndEvent) + return args.Error(0) +} + +func (f *fakeSummarizer) SummarizeDatabase(ctx context.Context, sessionEndEvent *apievents.DatabaseSessionEnd) error { + args := f.Called(ctx, sessionEndEvent) + return args.Error(0) +} + +func (f *fakeSummarizer) SummarizeWithoutEndEvent(ctx context.Context, sessionID session.ID) error { + args := f.Called(ctx, sessionID) + return args.Error(0) +} + +func generateParts(t *testing.T) [][]byte { + uploader := eventstest.NewMemoryUploader() + + ctx := t.Context() + sid := session.NewID() + sessionEvents := eventstest.GenerateTestSession(eventstest.SessionParams{ + PrintEvents: 1000, + UserName: "alice", + SessionID: string(sid), + ServerID: "testcluster", + PrintData: []string{"net", "stat"}, + }) + + streamer, err := events.NewProtoStreamer(events.ProtoStreamerConfig{ + Uploader: uploader, + }) + require.NoError(t, err) + stream, err := streamer.CreateAuditStream(ctx, sid) + require.NoError(t, err) + for _, event := range sessionEvents { + require.NoError(t, stream.RecordEvent(ctx, eventstest.PrepareEvent(event))) + } + require.NoError(t, stream.Complete(ctx)) + + uploads, err := uploader.ListUploads(ctx) + require.NoError(t, err) + require.Len(t, uploads, 1) + parts, err := uploader.GetParts(uploads[0].ID) + require.NoError(t, err) + return parts +} diff --git a/lib/events/complete.go b/lib/events/complete.go index b65f283cf70c5..a33cf6d27526d 100644 --- a/lib/events/complete.go +++ b/lib/events/complete.go @@ -36,6 +36,8 @@ import ( "github.com/gravitational/teleport/api/types/events" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/retryutils" + "github.com/gravitational/teleport/lib/auth/recordingmetadata" + "github.com/gravitational/teleport/lib/auth/summarizer" "github.com/gravitational/teleport/lib/observability/metrics" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" @@ -69,6 +71,12 @@ type UploadCompleterConfig struct { Clock clockwork.Clock // ClusterName identifies the originating teleport cluster ClusterName string + // SessionSummarizerProvider is a provider of the session summarizer service. + // It can be nil or provide a nil summarizer if summarization is not needed. + // The summarizer itself summarizes session recordings. + SessionSummarizerProvider *summarizer.SessionSummarizerProvider + // RecordingMetadataProvider is a provider of the recording metadata service. + RecordingMetadataProvider *recordingmetadata.Provider } // CheckAndSetDefaults checks and sets default values @@ -94,6 +102,12 @@ func (cfg *UploadCompleterConfig) CheckAndSetDefaults() error { if cfg.Clock == nil { cfg.Clock = clockwork.NewRealClock() } + if cfg.RecordingMetadataProvider == nil { + cfg.RecordingMetadataProvider = &recordingmetadata.Provider{} + } + if cfg.SessionSummarizerProvider == nil { + cfg.SessionSummarizerProvider = &summarizer.SessionSummarizerProvider{} + } return nil } @@ -348,6 +362,8 @@ func (u *UploadCompleter) ensureSessionEndEvent(ctx context.Context, uploadData // We use the streaming events API to search through the session events, because it works // for both Desktop and SSH sessions var lastEvent events.AuditEvent + var startTime time.Time + var isPTYSession bool ctx, cancel := context.WithCancel(ctx) defer cancel() evts, errors := u.cfg.AuditLog.StreamSessionEvents(ctx, uploadData.SessionID, 0) @@ -368,6 +384,7 @@ loop: return nil case *events.WindowsDesktopSessionStart: + startTime = e.Time desktopSessionEnd.Type = WindowsDesktopSessionEndEvent desktopSessionEnd.Code = DesktopSessionEndCode desktopSessionEnd.ClusterName = e.ClusterName @@ -383,6 +400,8 @@ loop: desktopSessionEnd.DesktopName = fmt.Sprintf("%v (recovered)", e.DesktopName) case *events.SessionStart: + isPTYSession = true + startTime = e.Time sshSessionEnd.Type = SessionEndEvent sshSessionEnd.Code = SessionEndCode sshSessionEnd.ClusterName = e.ClusterName @@ -442,6 +461,28 @@ loop: if err := u.cfg.AuditLog.EmitAuditEvent(ctx, sessionEndEvent); err != nil { return trace.Wrap(err) } + + if !isPTYSession { + return nil + } + + // For PTY sessions, process recording metadata and summarization. + recordingMetadata := u.cfg.RecordingMetadataProvider.Service() + if !startTime.IsZero() && !sessionEndEvent.GetTime().IsZero() { + duration := sessionEndEvent.GetTime().Sub(startTime) + if err := recordingMetadata.ProcessSessionRecording(ctx, uploadData.SessionID, duration); err != nil { + slog.WarnContext(ctx, "Failed to process session recording metadata", "error", err) + } + } else { + slog.WarnContext(ctx, "Session start or end time is not set, skipping recording metadata processing") + } + + summarizer := u.cfg.SessionSummarizerProvider.SessionSummarizer() + if err := summarizer.SummarizeSSH(ctx, &sshSessionEnd); err != nil { + slog.WarnContext(ctx, "Failed to summarize upload", "error", err) + return trace.Wrap(err) + } + return nil } diff --git a/lib/events/eventstest/generate.go b/lib/events/eventstest/generate.go index abafb943174c1..f6b59f44ee2ce 100644 --- a/lib/events/eventstest/generate.go +++ b/lib/events/eventstest/generate.go @@ -171,6 +171,116 @@ func GenerateTestSession(params SessionParams) []apievents.AuditEvent { return genEvents } +// GenerateTestKubeSession generates Kubernetes test session events starting +// with session start event, adds printEvents events and returns the result. +func GenerateTestKubeSession(params SessionParams) []apievents.AuditEvent { + params.SetDefaults() + connectionMetadata := apievents.ConnectionMetadata{ + LocalAddr: "127.0.0.1:3022", + RemoteAddr: "[::1]:37718", + Protocol: events.EventProtocolKube, + } + kubernetesClusterMetadata := apievents.KubernetesClusterMetadata{ + KubernetesCluster: "my-kube-cluster", + KubernetesUsers: []string{"admin"}, + KubernetesGroups: []string{"viewers"}, + KubernetesLabels: map[string]string{ + "teleport.internal/resource-id": "ed910b7b-fe3b-4959-bf2e-ac45f4648f2a", + }, + } + kubernetesPodMetadata := apievents.KubernetesPodMetadata{ + KubernetesPodName: "simple-shell-pod", + KubernetesPodNamespace: "default", + KubernetesContainerName: "shell-container", + KubernetesContainerImage: "busybox", + KubernetesNodeName: "docker-desktop", + } + sessionStart := apievents.SessionStart{ + Metadata: apievents.Metadata{ + Index: 0, + Type: events.SessionStartEvent, + ID: "36cee9e9-9a80-4c32-9163-3d9241cdac7a", + Code: events.SessionStartCode, + Time: params.Clock.Now().UTC(), + ClusterName: params.ClusterName, + }, + ServerMetadata: apievents.ServerMetadata{ + ServerVersion: teleport.Version, + ServerID: params.ServerID, + ServerLabels: map[string]string{ + "teleport.internal/resource-id": "ed910b7b-fe3b-4959-bf2e-ac45f4648f2a", + }, + ServerHostname: "planet", + ServerNamespace: "default", + }, + SessionMetadata: apievents.SessionMetadata{ + SessionID: params.SessionID, + }, + UserMetadata: apievents.UserMetadata{ + User: params.UserName, + Login: "bob", + }, + ConnectionMetadata: connectionMetadata, + TerminalSize: "80:25", + KubernetesClusterMetadata: kubernetesClusterMetadata, + KubernetesPodMetadata: kubernetesPodMetadata, + } + + sessionEnd := apievents.SessionEnd{ + Metadata: apievents.Metadata{ + Index: 20, + Type: events.SessionEndEvent, + ID: "da455e0f-c27d-459f-a218-4e83b3db9426", + Code: events.SessionEndCode, + Time: params.Clock.Now().UTC().Add(time.Hour + time.Second + 7*time.Millisecond), + ClusterName: params.ClusterName, + }, + ServerMetadata: apievents.ServerMetadata{ + ServerVersion: teleport.Version, + ServerID: params.ServerID, + ServerNamespace: "default", + }, + SessionMetadata: apievents.SessionMetadata{ + SessionID: params.SessionID, + }, + UserMetadata: apievents.UserMetadata{ + User: params.UserName, + }, + ConnectionMetadata: connectionMetadata, + EnhancedRecording: true, + Interactive: true, + Participants: []string{params.UserName}, + StartTime: params.Clock.Now().UTC(), + EndTime: params.Clock.Now().UTC().Add(3*time.Hour + time.Second + 7*time.Millisecond), + KubernetesClusterMetadata: kubernetesClusterMetadata, + KubernetesPodMetadata: kubernetesPodMetadata, + } + + genEvents := []apievents.AuditEvent{&sessionStart} + for i, data := range params.PrintData { + event := &apievents.SessionPrint{ + Metadata: apievents.Metadata{ + Index: int64(i) + 1, + Type: events.SessionPrintEvent, + Time: params.Clock.Now().UTC().Add(time.Minute + time.Duration(i)*time.Millisecond), + }, + ChunkIndex: int64(i), + DelayMilliseconds: int64(i), + Offset: int64(i), + Data: []byte(data), + } + event.Bytes = int64(len(event.Data)) + event.Time = event.Time.Add(time.Duration(i) * time.Millisecond) + + genEvents = append(genEvents, event) + } + + sessionEnd.Metadata.Index = int64(len(genEvents)) + genEvents = append(genEvents, &sessionEnd) + + return genEvents +} + // DBSessionParams specifies optional parameters // for a generated database session. type DBSessionParams struct { diff --git a/lib/events/sessionend.go b/lib/events/sessionend.go new file mode 100644 index 0000000000000..f8533dd6798e6 --- /dev/null +++ b/lib/events/sessionend.go @@ -0,0 +1,75 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package events + +import ( + "context" + + "github.com/gravitational/trace" + + apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/session" +) + +// FindSessionEndEvent streams session events to find the session end event for the given session ID. +// It returns: +// - SessionEnd +// - DatabaseSessionEnd +// - WindowsDesktopSessionEnd +// - AppSessionEnd +// - MCPSessionEnd, +// or nil if none is found. +// TODO(tigrato): Revisit this approach for large sessions, as it's highly inefficient. +// Instead, consider downloading the last few parts of the recording to find the session end event +// instead of streaming all events. +func FindSessionEndEvent(ctx context.Context, streamer SessionStreamer, sessionID session.ID) (apievents.AuditEvent, error) { + switch { + case streamer == nil: + return nil, trace.BadParameter("session streamer is required") + case sessionID == "": + return nil, trace.BadParameter("session ID is required") + } + ctx, cancel := context.WithCancel(ctx) + defer cancel() + eventsCh, errCh := streamer.StreamSessionEvents(ctx, sessionID, 0) + for { + select { + case event, ok := <-eventsCh: + if !ok { + return nil, trace.NotFound("session end event not found") + } + switch e := event.(type) { + case *apievents.WindowsDesktopSessionEnd: + return e, nil + case *apievents.SessionEnd: + return e, nil + case *apievents.DatabaseSessionEnd: + return e, nil + case *apievents.AppSessionEnd: + return e, nil + case *apievents.MCPSessionEnd: + return e, nil + } + case err := <-errCh: + return nil, trace.Wrap(err) + case <-ctx.Done(): + return nil, trace.Wrap(ctx.Err()) + } + } +} diff --git a/lib/events/sessionend_test.go b/lib/events/sessionend_test.go new file mode 100644 index 0000000000000..f83de86f2e04f --- /dev/null +++ b/lib/events/sessionend_test.go @@ -0,0 +1,131 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package events_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/events/eventstest" + "github.com/gravitational/teleport/lib/session" +) + +func TestFindSessionEndEvent(t *testing.T) { + uploader := eventstest.NewMemoryUploader() + alog, err := events.NewAuditLog(events.AuditLogConfig{ + DataDir: t.TempDir(), + ServerID: "server1", + UploadHandler: uploader, + }) + require.NoError(t, err) + t.Cleanup(func() { alog.Close() }) + + sshSessionID := session.NewID() + sshSessionEvents := eventstest.GenerateTestSession(eventstest.SessionParams{ + PrintEvents: 1000, + UserName: "bob", + SessionID: string(sshSessionID), + ServerID: "testcluster", + PrintData: []string{"ls", "whoami"}, + }) + + kubernetesSessionID := session.NewID() + kubernetesSessionEvents := eventstest.GenerateTestKubeSession(eventstest.SessionParams{ + UserName: "carol", + SessionID: string(kubernetesSessionID), + ServerID: "testcluster", + PrintData: []string{"get pods", "describe pod"}, + PrintEvents: 1000, + }) + + databaseSessionID := session.NewID() + databaseSessionEvents := eventstest.GenerateTestDBSession(eventstest.DBSessionParams{ + UserName: "dave", + SessionID: string(databaseSessionID), + ServerID: "testcluster", + Queries: 1000, + }) + + tests := []struct { + name string // description of this test case + sessionID session.ID + auditEvents []apievents.AuditEvent + want apievents.AuditEvent + assertErr require.ErrorAssertionFunc + }{ + { + name: "SSH session with SessionEnd event", + sessionID: sshSessionID, + auditEvents: sshSessionEvents, + want: sshSessionEvents[len(sshSessionEvents)-1], + assertErr: require.NoError, + }, + { + name: "Kubernetes session with KubeSessionEnd event", + sessionID: kubernetesSessionID, + auditEvents: kubernetesSessionEvents, + want: kubernetesSessionEvents[len(kubernetesSessionEvents)-1], + assertErr: require.NoError, + }, + { + name: "Database session with DBSessionEnd event", + sessionID: databaseSessionID, + auditEvents: databaseSessionEvents, + want: databaseSessionEvents[len(databaseSessionEvents)-1], + assertErr: require.NoError, + }, + { + name: "No session end event", + sessionID: session.NewID(), + auditEvents: eventstest.GenerateTestSession(eventstest.SessionParams{PrintEvents: 10})[:9], + assertErr: require.Error, + }, + { + name: "missing session ID", + sessionID: session.NewID(), + want: nil, + assertErr: require.Error, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if len(tt.auditEvents) > 0 { + streamer, err := events.NewProtoStreamer( + events.ProtoStreamerConfig{ + Uploader: uploader, + }, + ) + require.NoError(t, err) + + stream, err := streamer.CreateAuditStream(t.Context(), tt.sessionID) + require.NoError(t, err) + for _, event := range tt.auditEvents { + require.NoError(t, stream.RecordEvent(t.Context(), eventstest.PrepareEvent(event))) + } + require.NoError(t, stream.Complete(t.Context())) + } + got, gotErr := events.FindSessionEndEvent(t.Context(), alog, tt.sessionID) + tt.assertErr(t, gotErr) + require.Equal(t, tt.want, got) + }) + } +} diff --git a/lib/events/sessionpostprocessing/postprocessing.go b/lib/events/sessionpostprocessing/postprocessing.go new file mode 100644 index 0000000000000..e7f2e3cd0d1d5 --- /dev/null +++ b/lib/events/sessionpostprocessing/postprocessing.go @@ -0,0 +1,80 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package sessionpostprocessing + +import ( + "context" + + "github.com/gravitational/trace" + + apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/auth/recordingmetadata" + "github.com/gravitational/teleport/lib/auth/summarizer" + "github.com/gravitational/teleport/lib/session" +) + +// Config is the configuration for the session post-processor. +type Config struct { + // SessionSummarizerProvider is a provider of the session summarizer service. + // It can be nil or provide a nil summarizer if summarization is not needed. + // The summarizer itself summarizes session recordings. + SessionSummarizerProvider *summarizer.SessionSummarizerProvider + // RecordingMetadataProvider is a provider of the recording metadata service. + RecordingMetadataProvider *recordingmetadata.Provider + // SessionEnd is the session end event to process. + SessionEnd apievents.AuditEvent + // SessionID is the ID of the session being processed. + SessionID session.ID +} + +// Process processes session end events after the session recording upload is complete. +// It summarizes the session recording and processes the recording metadata. +func Process(ctx context.Context, cfg Config) error { + switch { + case cfg.SessionSummarizerProvider == nil: + return trace.BadParameter("session summarizer provider is not set") + case cfg.RecordingMetadataProvider == nil: + return trace.BadParameter("recording metadata provider is not set") + case cfg.SessionEnd == nil: + return trace.BadParameter("session end event is not set") + case cfg.SessionID == "": + return trace.BadParameter("session ID is not set") + } + + var summarizerErr error + var metadataErr error + summarizer := cfg.SessionSummarizerProvider.SessionSummarizer() + switch o := cfg.SessionEnd.(type) { + case *apievents.SessionEnd: + if err := summarizer.SummarizeSSH(ctx, o); err != nil { + summarizerErr = trace.Wrap(err, "failed to summarize upload") + } + metadataSvc := cfg.RecordingMetadataProvider.Service() + if !o.EndTime.IsZero() && !o.StartTime.IsZero() { + duration := o.EndTime.Sub(o.StartTime) + if err := metadataSvc.ProcessSessionRecording(ctx, cfg.SessionID, duration); err != nil { + metadataErr = trace.Wrap(err, "failed to process session recording metadata") + } + } + case *apievents.DatabaseSessionEnd: + if err := summarizer.SummarizeDatabase(ctx, o); err != nil { + summarizerErr = trace.Wrap(err, "failed to summarize upload") + } + } + return trace.NewAggregate(summarizerErr, metadataErr) +} diff --git a/lib/events/sessionpostprocessing/postprocessing_test.go b/lib/events/sessionpostprocessing/postprocessing_test.go new file mode 100644 index 0000000000000..16dcfa5de5385 --- /dev/null +++ b/lib/events/sessionpostprocessing/postprocessing_test.go @@ -0,0 +1,107 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package sessionpostprocessing_test + +import ( + "context" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/auth/recordingmetadata" + "github.com/gravitational/teleport/lib/auth/summarizer" + "github.com/gravitational/teleport/lib/events/eventstest" + "github.com/gravitational/teleport/lib/events/sessionpostprocessing" + "github.com/gravitational/teleport/lib/session" +) + +func TestSessionPostProcessor(t *testing.T) { + sessionID := session.ID(uuid.NewString()) + + metadataProvider := recordingmetadata.NewProvider() + recorderMetadata := &fakeRecordingMetadata{} + recorderMetadata.On( + "ProcessSessionRecording", + mock.Anything, + sessionID, + mock.Anything, + ). + Return(nil).Once() + metadataProvider.SetService(recorderMetadata) + + summarizerProvider := summarizer.NewSessionSummarizerProvider() + sessionSummarizer := &fakeSummarizer{} + sessionSummarizer.On( + "SummarizeSSH", + mock.Anything, + mock.Anything, + ).Return(nil).Once() + summarizerProvider.SetSummarizer(sessionSummarizer) + + events := eventstest.GenerateTestSession(eventstest.SessionParams{ + UserName: "alice", + SessionID: string(sessionID), + ServerID: "testcluster", + PrintData: []string{"net", "stat"}, + }) + + cfg := sessionpostprocessing.Config{ + SessionEnd: events[len(events)-1], + RecordingMetadataProvider: metadataProvider, + SessionSummarizerProvider: summarizerProvider, + SessionID: sessionID, + } + + err := sessionpostprocessing.Process(t.Context(), cfg) + require.NoError(t, err) + + recorderMetadata.AssertExpectations(t) + sessionSummarizer.AssertExpectations(t) +} + +type fakeRecordingMetadata struct { + mock.Mock +} + +func (f *fakeRecordingMetadata) ProcessSessionRecording(ctx context.Context, sessionID session.ID, duration time.Duration) error { + args := f.Called(ctx, sessionID, duration) + return args.Error(0) +} + +type fakeSummarizer struct { + mock.Mock +} + +func (f *fakeSummarizer) SummarizeSSH(ctx context.Context, sessionEndEvent *apievents.SessionEnd) error { + args := f.Called(ctx, sessionEndEvent) + return args.Error(0) +} + +func (f *fakeSummarizer) SummarizeDatabase(ctx context.Context, sessionEndEvent *apievents.DatabaseSessionEnd) error { + args := f.Called(ctx, sessionEndEvent) + return args.Error(0) +} + +func (f *fakeSummarizer) SummarizeWithoutEndEvent(ctx context.Context, sessionID session.ID) error { + args := f.Called(ctx, sessionID) + return args.Error(0) +} diff --git a/lib/service/service.go b/lib/service/service.go index bfab886d48b41..87c00e0c29da7 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -2274,12 +2274,14 @@ func (process *TeleportProcess) initAuthService() error { } auditServiceConfig := events.AuditLogConfig{ - Context: process.ExitContext(), - DataDir: filepath.Join(cfg.DataDir, teleport.LogsDir), - ServerID: hostUUID, - UploadHandler: uploadHandler, - ExternalLog: externalLog, - Decrypter: encryptedIO, + Context: process.ExitContext(), + DataDir: filepath.Join(cfg.DataDir, teleport.LogsDir), + ServerID: hostUUID, + UploadHandler: uploadHandler, + ExternalLog: externalLog, + Decrypter: encryptedIO, + SessionSummarizerProvider: sessionSummarizerProvider, + RecordingMetadataProvider: recordingMetadataProvider, } auditServiceConfig.UID, auditServiceConfig.GID, err = adminCreds() if err != nil { @@ -2379,6 +2381,7 @@ func (process *TeleportProcess) initAuthService() error { Logger: logger, SessionSummarizerProvider: sessionSummarizerProvider, RecordingEncryption: recordingEncryptionManager, + RecordingMetadataProvider: recordingMetadataProvider, }, func(as *auth.Server) error { if !process.Config.CachePolicy.Enabled { return nil @@ -2401,6 +2404,7 @@ func (process *TeleportProcess) initAuthService() error { if err != nil { return trace.Wrap(err) } + authServer.EncryptedIO = encryptedIO lockWatcher, err := services.NewLockWatcher(process.ExitContext(), services.LockWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ @@ -2471,6 +2475,7 @@ func (process *TeleportProcess) initAuthService() error { recordingMetadataService, err := recordingmetadatav1.NewRecordingMetadataService(recordingmetadatav1.RecordingMetadataServiceConfig{ Streamer: authServer, UploadHandler: authServer, + Encrypter: encryptedIO, }) if err != nil { return trace.Wrap(err) @@ -2489,13 +2494,15 @@ func (process *TeleportProcess) initAuthService() error { logger.WarnContext(process.ExitContext(), "auth service's upload completer is disabled, abandoned uploads may accumulate in external storage") case uploadHandler != nil: err = events.StartNewUploadCompleter(process.ExitContext(), events.UploadCompleterConfig{ - Uploader: uploadHandler, - Component: teleport.ComponentAuth, - ClusterName: clusterName, - AuditLog: process.auditLog, - SessionTracker: authServer.Services, - Semaphores: authServer.Services, - ServerID: hostUUID, + Uploader: uploadHandler, + Component: teleport.ComponentAuth, + ClusterName: clusterName, + AuditLog: process.auditLog, + SessionTracker: authServer.Services, + Semaphores: authServer.Services, + ServerID: hostUUID, + SessionSummarizerProvider: sessionSummarizerProvider, + RecordingMetadataProvider: recordingMetadataProvider, }) if err != nil { return trace.Wrap(err, "starting upload completer")