diff --git a/api/client/client.go b/api/client/client.go index e66cd0e23c695..71a8edd00c882 100644 --- a/api/client/client.go +++ b/api/client/client.go @@ -2739,11 +2739,22 @@ func (c *Client) UploadEncryptedRecording(ctx context.Context, sessionID string, return trace.Wrap(err) } + next, stop := iter.Pull2(parts) + defer stop() + + part, err, ok := next() + if err != nil { + return trace.Wrap(err) + } else if !ok { + return trace.BadParameter("unexpected empty upload") + } + var uploadedParts []*recordingencryptionv1pb.Part // S3 requires that part numbers start at 1, so we do that by default regardless of which uploader is // configured for the auth service var partNumber int64 = 1 - for part, err := range parts { + for { + nextPart, err, hasNext := next() if err != nil { return trace.Wrap(err) } @@ -2752,11 +2763,18 @@ func (c *Client) UploadEncryptedRecording(ctx context.Context, sessionID string, Upload: createRes.Upload, PartNumber: partNumber, Part: part, + IsLast: !hasNext, }) if err != nil { return trace.Wrap(err) } uploadedParts = append(uploadedParts, uploadRes.Part) + + if !hasNext { + break + } + + part = nextPart partNumber++ } diff --git a/api/gen/proto/go/teleport/recordingencryption/v1/recording_encryption_service.pb.go b/api/gen/proto/go/teleport/recordingencryption/v1/recording_encryption_service.pb.go index 6a97aff5ac45a..a86889f0002de 100644 --- a/api/gen/proto/go/teleport/recordingencryption/v1/recording_encryption_service.pb.go +++ b/api/gen/proto/go/teleport/recordingencryption/v1/recording_encryption_service.pb.go @@ -200,7 +200,9 @@ type UploadPartRequest struct { // The ordered index applied to the part. PartNumber int64 `protobuf:"varint,2,opt,name=part_number,json=partNumber,proto3" json:"part_number,omitempty"` // The encrypted part of session recording data being uploaded. - Part []byte `protobuf:"bytes,3,opt,name=part,proto3" json:"part,omitempty"` + Part []byte `protobuf:"bytes,3,opt,name=part,proto3" json:"part,omitempty"` + // Whether this is the last upload part in the upload. + IsLast bool `protobuf:"varint,4,opt,name=is_last,json=isLast,proto3" json:"is_last,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -256,6 +258,13 @@ func (x *UploadPartRequest) GetPart() []byte { return nil } +func (x *UploadPartRequest) GetIsLast() bool { + if x != nil { + return x.IsLast + } + return false +} + // The resulting metadata about an uploaded part. 
type Part struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -848,12 +857,13 @@ const file_teleport_recordingencryption_v1_recording_encryption_service_proto_ra "\n" + "session_id\x18\x01 \x01(\tR\tsessionId\"W\n" + "\x14CreateUploadResponse\x12?\n" + - "\x06upload\x18\x01 \x01(\v2'.teleport.recordingencryption.v1.UploadR\x06upload\"\x89\x01\n" + + "\x06upload\x18\x01 \x01(\v2'.teleport.recordingencryption.v1.UploadR\x06upload\"\xa2\x01\n" + "\x11UploadPartRequest\x12?\n" + "\x06upload\x18\x01 \x01(\v2'.teleport.recordingencryption.v1.UploadR\x06upload\x12\x1f\n" + "\vpart_number\x18\x02 \x01(\x03R\n" + "partNumber\x12\x12\n" + - "\x04part\x18\x03 \x01(\fR\x04part\";\n" + + "\x04part\x18\x03 \x01(\fR\x04part\x12\x17\n" + + "\ais_last\x18\x04 \x01(\bR\x06isLast\";\n" + "\x04Part\x12\x1f\n" + "\vpart_number\x18\x01 \x01(\x03R\n" + "partNumber\x12\x12\n" + diff --git a/api/proto/teleport/recordingencryption/v1/recording_encryption_service.proto b/api/proto/teleport/recordingencryption/v1/recording_encryption_service.proto index 130eb2eba0144..a4590271265f3 100644 --- a/api/proto/teleport/recordingencryption/v1/recording_encryption_service.proto +++ b/api/proto/teleport/recordingencryption/v1/recording_encryption_service.proto @@ -71,6 +71,8 @@ message UploadPartRequest { int64 part_number = 2; // The encrypted part of session recording data being uploaded. bytes part = 3; + // Whether this is the last upload part in the upload. + bool is_last = 4; } // The resulting metadata about an uploaded part. diff --git a/lib/auth/recordingencryption/recordingencryptionv1/service.go b/lib/auth/recordingencryption/recordingencryptionv1/service.go index 9cf6fa6235553..695857ae5328e 100644 --- a/lib/auth/recordingencryption/recordingencryptionv1/service.go +++ b/lib/auth/recordingencryption/recordingencryptionv1/service.go @@ -156,8 +156,14 @@ func (s *Service) UploadPart(ctx context.Context, req *recordingencryptionv1.Upl return nil, trace.Wrap(err) } - part := bytes.NewReader(req.Part) - streamPart, err := s.uploader.UploadPart(ctx, upload, req.PartNumber, part) + // If upload part is not at least the minimum upload part size, append an empty part + // to pad up to the minimum upload size. + part := req.Part + if !req.IsLast && len(part) < events.MinUploadPartSizeBytes { + part = events.PadUploadPart(part, events.MinUploadPartSizeBytes) + } + + streamPart, err := s.uploader.UploadPart(ctx, upload, req.PartNumber, bytes.NewReader(part)) if err != nil { return nil, trace.Wrap(err, "uploading encrypted recording part") } diff --git a/lib/events/api.go b/lib/events/api.go index 6779b76de5a42..223fd25850a06 100644 --- a/lib/events/api.go +++ b/lib/events/api.go @@ -1087,7 +1087,11 @@ type MultipartUploader interface { // ReserveUploadPart reserves an upload part. Reserve is used to identify // upload errors beforehand. ReserveUploadPart(ctx context.Context, upload StreamUpload, partNumber int64) error - // UploadPart uploads part and returns the part + // UploadPart uploads part and returns the part. + // + // The part must be greater than [MinUploadPartSizeBytes]. It is the responsibility + // of the caller to add padding if needed, or else the upload may fail depending on + // storage provider. 
UploadPart(ctx context.Context, upload StreamUpload, partNumber int64, partBody io.ReadSeeker) (*StreamPart, error) // ListParts returns all uploaded parts for the completed upload in sorted order ListParts(ctx context.Context, upload StreamUpload) ([]StreamPart, error) diff --git a/lib/events/auditlog.go b/lib/events/auditlog.go index 67fdef73ebc00..2f19850ee4953 100644 --- a/lib/events/auditlog.go +++ b/lib/events/auditlog.go @@ -631,17 +631,34 @@ func (l *AuditLog) UploadEncryptedRecording(ctx context.Context, sessionID strin return trace.Wrap(err, "creating upload") } + next, stop := iter.Pull2(parts) + defer stop() + + part, err, ok := next() + if err != nil { + return trace.Wrap(err) + } else if !ok { + return trace.BadParameter("unexpected empty upload") + } + var streamParts []StreamPart // S3 requires that part numbers start at 1, so we do that by default regardless of which uploader is // configured for the auth service var partNumber int64 = 1 - for part, err := range parts { + for { + if err := l.UploadHandler.ReserveUploadPart(ctx, *upload, partNumber); err != nil { + return trace.Wrap(err, "reserving upload part") + } + + nextPart, err, hasNext := next() if err != nil { return trace.Wrap(err) } - if err := l.UploadHandler.ReserveUploadPart(ctx, *upload, partNumber); err != nil { - return trace.Wrap(err, "reserving upload part") + // If the upload part is not at least the minimum upload part size, and this isn't + // the last part, add padding to meet the minimum upload size. + if hasNext && len(part) < MinUploadPartSizeBytes { + part = PadUploadPart(part, MinUploadPartSizeBytes) } streamPart, err := l.UploadHandler.UploadPart(ctx, *upload, partNumber, bytes.NewReader(part)) @@ -649,6 +666,12 @@ func (l *AuditLog) UploadEncryptedRecording(ctx context.Context, sessionID strin return trace.Wrap(err, "uploading part") } streamParts = append(streamParts, *streamPart) + + if !hasNext { + break + } + + part = nextPart partNumber++ } @@ -771,3 +794,20 @@ func sessionStartCallbackFromContext(ctx context.Context) (SessionStartCallback, return cb, nil } + +// PadUploadPart adds padding to the given upload part to reach the minimum size. +func PadUploadPart(uploadPart []byte, minSize int) []byte { + // Create padding to reach the target size. Note that the padding cannot + // be shorter than the header size. + paddingBytes := max(minSize-len(uploadPart), ProtoStreamV2PartHeaderSize) + paddedPart := make([]byte, paddingBytes) + + paddedPartHeader := PartHeader{ + ProtoVersion: ProtoStreamV2, + PaddingSize: uint64(paddingBytes - ProtoStreamV2PartHeaderSize), + PartSize: 0, + } + copy(paddedPart, paddedPartHeader.Bytes()) + + return append(uploadPart, paddedPart...) +} diff --git a/lib/events/auditlog_test.go b/lib/events/auditlog_test.go index a9a60933a4fd6..a1e2d4c65daa7 100644 --- a/lib/events/auditlog_test.go +++ b/lib/events/auditlog_test.go @@ -19,6 +19,7 @@ package events_test import ( + "bytes" "context" "encoding/json" "errors" @@ -372,3 +373,51 @@ func makeLog(t *testing.T, clock clockwork.Clock) *events.AuditLog { return alog } + +func TestPadUploadPart(t *testing.T) { + partData := bytes.Repeat([]byte{1, 2, 3}, 10) + partHeader := events.PartHeader{ + ProtoVersion: events.V2, + PartSize: uint64(len(partData)), + PaddingSize: 0, + } + headerBytes := partHeader.Bytes() + part := append(headerBytes, partData...) + + // Pad the upload part to double the size. 
+ minSize := len(part) * 2 + paddedPart := events.PadUploadPart(part, minSize) + require.Len(t, paddedPart, minSize) + + // Padding the upload part again with the same minimum should add a single header in size. + paddedPart = events.PadUploadPart(paddedPart, minSize) + require.Len(t, paddedPart, minSize+events.ProtoStreamV2PartHeaderSize) + + // Ensure we can read out each part. + r := bytes.NewReader(paddedPart) + h1, err := events.ParsePartHeader(r) + require.NoError(t, err) + require.Equal(t, partHeader, h1) + gotData, err := io.ReadAll(io.LimitReader(r, int64(h1.PartSize))) + require.NoError(t, err) + require.Equal(t, partData, gotData) + io.Copy(io.Discard, io.LimitReader(r, int64(h1.PaddingSize))) + + h2, err := events.ParsePartHeader(r) + require.NoError(t, err) + require.Equal(t, events.PartHeader{ + ProtoVersion: events.V2, + PaddingSize: uint64(len(part) - events.ProtoStreamV2PartHeaderSize), + }, h2) + io.Copy(io.Discard, io.LimitReader(r, int64(h2.PaddingSize))) + + h3, err := events.ParsePartHeader(r) + require.NoError(t, err) + require.Equal(t, events.PartHeader{ + ProtoVersion: events.V2, + PaddingSize: 0, + }, h3) + + _, err = r.Read(nil) + require.ErrorIs(t, err, io.EOF) +} diff --git a/lib/events/azsessions/azsessions.go b/lib/events/azsessions/azsessions.go index f7b2a206664a4..d8d65a5975245 100644 --- a/lib/events/azsessions/azsessions.go +++ b/lib/events/azsessions/azsessions.go @@ -523,7 +523,7 @@ func (*Handler) ReserveUploadPart(ctx context.Context, upload events.StreamUploa func (h *Handler) UploadPart(ctx context.Context, upload events.StreamUpload, partNumber int64, partBody io.ReadSeeker) (*events.StreamPart, error) { partBlob := h.partBlob(upload, partNumber) - // our parts are just over 5 MiB (events.MinUploadPartSizeBytes) so we can + // our parts are just over 5 MiB [events.MinUploadPartSizeBytes] so we can // upload them in one shot response, err := cErr(partBlob.Upload(ctx, streaming.NopCloser(partBody), nil)) if err != nil { diff --git a/lib/events/eventstest/generate.go b/lib/events/eventstest/generate.go index 60058a8b3a32b..abafb943174c1 100644 --- a/lib/events/eventstest/generate.go +++ b/lib/events/eventstest/generate.go @@ -37,6 +37,8 @@ import ( // for generated session type SessionParams struct { // PrintEvents sets up print events count. Ignored if PrintData is set. + // The size of the resulting event stream varies due to compression, but with + // a sufficiently large number of events results in approximately 64 bytes per event. PrintEvents int64 // PrintData is optional data to use for print events. Each element of the // slice represents data for one print event. diff --git a/lib/events/eventstest/uploader.go b/lib/events/eventstest/uploader.go index 5164d9601f7b0..ead8d42087b0b 100644 --- a/lib/events/eventstest/uploader.go +++ b/lib/events/eventstest/uploader.go @@ -22,6 +22,7 @@ import ( "bytes" "context" "io" + "iter" "sort" "sync" "time" @@ -34,9 +35,18 @@ import ( "github.com/gravitational/teleport/lib/session" ) +// MemoryUploaderConfig optional configuration for MemoryUploader. +type MemoryUploaderConfig struct { + // EventsC is used by some tests to receive signal for completed uploads. + EventsC chan events.UploadEvent + // MinimumUploadBytes sets the minimum upload part size. The uploader will + // add padding to smaller uploads to reach this minimum size. 
+ MinimumUploadBytes int +} + // NewMemoryUploader returns a new memory uploader implementing multipart // upload -func NewMemoryUploader(eventsC ...chan events.UploadEvent) *MemoryUploader { +func NewMemoryUploader(cfg ...MemoryUploaderConfig) *MemoryUploader { up := &MemoryUploader{ mtx: &sync.RWMutex{}, uploads: make(map[string]*MemoryUpload), @@ -45,21 +55,22 @@ func NewMemoryUploader(eventsC ...chan events.UploadEvent) *MemoryUploader { metadata: make(map[session.ID][]byte), thumbnails: make(map[session.ID][]byte), } - if len(eventsC) != 0 { - up.eventsC = eventsC[0] + if len(cfg) != 0 { + up.cfg = cfg[0] } return up } // MemoryUploader uploads all bytes to memory, used in tests type MemoryUploader struct { + cfg MemoryUploaderConfig + mtx *sync.RWMutex uploads map[string]*MemoryUpload sessions map[session.ID][]byte summaries map[session.ID][]byte metadata map[session.ID][]byte thumbnails map[session.ID][]byte - eventsC chan events.UploadEvent // Clock is an optional [clockwork.Clock] to determine the time to associate // with uploads and parts. @@ -87,11 +98,11 @@ type part struct { } func (m *MemoryUploader) trySendEvent(event events.UploadEvent) { - if m.eventsC == nil { + if m.cfg.EventsC == nil { return } select { - case m.eventsC <- event: + case m.cfg.EventsC <- event: default: } } @@ -404,6 +415,64 @@ func (m *MemoryUploader) ReserveUploadPart(ctx context.Context, upload events.St return nil } +// UploadEncryptedRecording uploads encrypted recordings. +func (m *MemoryUploader) UploadEncryptedRecording(ctx context.Context, sessionID string, parts iter.Seq2[[]byte, error]) error { + sessID, err := session.ParseID(sessionID) + if err != nil { + return trace.Wrap(err) + } + upload, err := m.CreateUpload(ctx, *sessID) + if err != nil { + return trace.Wrap(err, "creating upload") + } + + next, stop := iter.Pull2(parts) + defer stop() + + part, err, ok := next() + if err != nil { + return trace.Wrap(err) + } else if !ok { + return trace.BadParameter("unexpected empty upload") + } + + var streamParts []events.StreamPart + // S3 requires that part numbers start at 1, so we do that by default regardless of which uploader is + // configured for the auth service + var partNumber int64 = 1 + for { + if err := m.ReserveUploadPart(ctx, *upload, partNumber); err != nil { + return trace.Wrap(err, "reserving upload part") + } + + nextPart, err, hasNext := next() + if err != nil { + return trace.Wrap(err) + } + + // If the upload part is not at least the minimum upload part size, and this isn't + // the last part, append an empty part to pad up to the minimum upload size. + if hasNext && len(part) < m.cfg.MinimumUploadBytes { + part = events.PadUploadPart(part, m.cfg.MinimumUploadBytes) + } + + streamPart, err := m.UploadPart(ctx, *upload, partNumber, bytes.NewReader(part)) + if err != nil { + return trace.Wrap(err, "uploading part") + } + streamParts = append(streamParts, *streamPart) + + if !hasNext { + break + } + + part = nextPart + partNumber++ + } + + return trace.Wrap(m.CompleteUpload(ctx, *upload, streamParts), "completing upload") +} + // MockUploader is a limited implementation of [events.MultipartUploader] that // allows injecting errors for testing purposes. [MemoryUploader] is a more // complete implementation and should be preferred for testing the happy path. 
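The look-ahead consumption pattern shared by Client.UploadEncryptedRecording, AuditLog.UploadEncryptedRecording, and MemoryUploader.UploadEncryptedRecording above can be reduced to the minimal standalone sketch below. It assumes Go 1.23+ for iter.Pull2; consumeParts and uploadFn are hypothetical names used only for illustration and are not part of this change.

// A minimal sketch of the look-ahead loop used above. iter.Pull2 converts the
// push-style part iterator into a pull-style one so the consumer can peek at
// the next part and decide whether the current part is the last one (and, in
// the real implementations, whether it still needs padding).
package main

import (
	"errors"
	"fmt"
	"iter"
)

// consumeParts walks the parts sequence, numbering parts from 1 (the S3
// multipart convention) and telling uploadFn whether each part is the last.
func consumeParts(parts iter.Seq2[[]byte, error], uploadFn func(partNumber int64, part []byte, isLast bool) error) error {
	next, stop := iter.Pull2(parts)
	defer stop()

	part, err, ok := next()
	if err != nil {
		return err
	} else if !ok {
		return errors.New("unexpected empty upload")
	}

	var partNumber int64 = 1
	for {
		// Peek at the next part before uploading the current one.
		nextPart, err, hasNext := next()
		if err != nil {
			return err
		}

		if err := uploadFn(partNumber, part, !hasNext); err != nil {
			return err
		}

		if !hasNext {
			return nil
		}
		part = nextPart
		partNumber++
	}
}

func main() {
	// A toy three-part sequence standing in for the encrypted recording parts.
	parts := func(yield func([]byte, error) bool) {
		for _, p := range [][]byte{[]byte("first"), []byte("second"), []byte("last")} {
			if !yield(p, nil) {
				return
			}
		}
	}

	err := consumeParts(parts, func(n int64, p []byte, last bool) error {
		fmt.Printf("part %d: %d bytes, last=%v\n", n, len(p), last)
		return nil
	})
	if err != nil {
		fmt.Println("upload failed:", err)
	}
}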
diff --git a/lib/events/filesessions/fileasync.go b/lib/events/filesessions/fileasync.go index e67d03a59fe14..091e2dca739a1 100644 --- a/lib/events/filesessions/fileasync.go +++ b/lib/events/filesessions/fileasync.go @@ -24,6 +24,7 @@ import ( "errors" "fmt" "io" + "iter" "log/slog" "os" "path/filepath" @@ -66,6 +67,14 @@ type UploaderConfig struct { Component string // EncryptedRecordingUploader uploads encrypted session recordings EncryptedRecordingUploader events.EncryptedRecordingUploader + // EncryptedRecordingUploadTargetSize is the target size used when aggregating + // encrypted recording parts before sending them to EncryptedRecordingUploader. + // Encrypted uploads should slightly exceed this target size unless limited by the maximum. + EncryptedRecordingUploadTargetSize int + // EncryptedRecordingUploadTargetSize is the maximum size used when aggregating + // encrypted recording parts before sending them to EncryptedRecordingUploader. + // If set to 0, then no maximum is enforced. + EncryptedRecordingUploadMaxSize int } // CheckAndSetDefaults checks and sets default values of UploaderConfig @@ -91,6 +100,9 @@ func (cfg *UploaderConfig) CheckAndSetDefaults() error { if cfg.Component == "" { cfg.Component = teleport.ComponentUpload } + if cfg.EncryptedRecordingUploadTargetSize == 0 { + cfg.EncryptedRecordingUploadTargetSize = events.MinUploadPartSizeBytes + } return nil } @@ -385,55 +397,126 @@ func (u *Uploader) uploadEncryptedRecording(ctx context.Context, sessionID strin return trace.Wrap(errSkipEncryptedUpload, "no encrypted uploader configured") } - partIter := func(yield func([]byte, error) bool) { - var buf bytes.Buffer - for { + // The upload parts in the given reader are each ~128KB. Usually these parts are consumed and reconstructed + // by Auth in 5MB chunks to meet the minimum upload size of upload providers like S3. Since these uploads + // are proxied directly to the uploader from the agent here (see link below), this agent needs to combine + // these upload parts into larger, aggregated upload parts. + // + // https://github.com/gravitational/teleport/blob/master/rfd/0127-encrypted-session-recordings.md#session-recording-modes + partIter := encryptedUploadAggregateIter(in, u.cfg.EncryptedRecordingUploadTargetSize, u.cfg.EncryptedRecordingUploadMaxSize) + + u.log.DebugContext(ctx, "uploading encrypted recording", "session_id", sessionID) + if err := u.cfg.EncryptedRecordingUploader.UploadEncryptedRecording(ctx, sessionID, partIter); err != nil { + return trace.Wrap(err) + } + + return nil +} + +// encryptedUploadAggregateIter returns an iterator that aggregates upload parts from the given reader +// into larger upload with size greater than targetSize, or as near targetSize as possible without exceeding +// the maxSize. +func encryptedUploadAggregateIter(in io.Reader, targetSize int, maxSize int) iter.Seq2[[]byte, error] { + // buf holds aggregated upload parts that will be uploaded as a single upload part. + var buf bytes.Buffer + + readNextPartHeader := func() (events.PartHeader, error) { + header, err := events.ParsePartHeader(in) + if err != nil { + return events.PartHeader{}, trace.Wrap(err) + } + + if header.Flags&events.ProtoStreamFlagEncrypted == 0 { + return events.PartHeader{}, trace.Wrap(errSkipEncryptedUpload, "recording not encrypted") + } + + // Ensure that the individual file upload parts are not larger than the max size allowed here (e.g. 4MB gRPC max message size). 
+ // This error case should never be hit outside of tests, but we want to ensure we fail fast in case a bug ever arises here. + totalPartSize := len(header.Bytes()) + int(header.PartSize) + if maxSize != 0 && totalPartSize > maxSize { + return events.PartHeader{}, trace.BadParameter("encrypted upload part is larger than the maximum size, so it cannot be uploaded. This is a bug.") + } + + return header, nil + } + + writePartToBuffer := func(header events.PartHeader) error { + // We are going to discard any padding as it isn't necessary within the individual parts. + originalPaddingSize := header.PaddingSize + header.PaddingSize = 0 + + if _, err := buf.Write(header.Bytes()); err != nil { + return trace.Wrap(err) + } + + // Copy the part into the buffer. + reader := io.LimitReader(in, int64(header.PartSize)) + copied, err := io.Copy(&buf, reader) + if err != nil && !errors.Is(err, io.EOF) { + return trace.Wrap(err) + } + + if copied != int64(header.PartSize) { + return trace.Errorf("copied %d bytes from recording part instead of expected %d", copied, int64(header.PartSize)) + } + + // Discard the padding. + discarded, err := io.Copy(io.Discard, io.LimitReader(in, int64(originalPaddingSize))) + if err != nil && !errors.Is(err, io.EOF) { + return trace.Wrap(err) + } + + if discarded != int64(originalPaddingSize) { + return trace.Errorf("discarded %d padding bytes from recording part instead of expected %d", copied, int64(originalPaddingSize)) + } + + return nil + } + + return func(yield func([]byte, error) bool) { + // yield the current aggregated upload part and reset the buffer. + yieldCurrent := func() bool { + // Copy the buffer to a new []byte so that the next + // iteration doesn't wipe the previous yielded bytes. + bytes := make([]byte, buf.Len()) + copy(bytes, buf.Bytes()) buf.Reset() - header, err := events.ParsePartHeader(in) + return yield(bytes, nil) + } + + for { + partHeader, err := readNextPartHeader() if err != nil { if errors.Is(err, io.EOF) { - break + // No parts remaining, yield the current part and return. + yieldCurrent() + return } yield(nil, trace.Wrap(err)) return } - if header.Flags&events.ProtoStreamFlagEncrypted == 0 { - yield(nil, trace.Wrap(errSkipEncryptedUpload, "recording not encrypted")) - return - } - - if _, err := buf.Write(header.Bytes()); err != nil { - yield(nil, trace.Wrap(err)) - return - } - - totalPartSize := int64(header.PartSize + header.PaddingSize) - reader := io.LimitReader(in, totalPartSize) - copied, err := io.Copy(&buf, reader) - if err != nil && !errors.Is(err, io.EOF) { - yield(nil, trace.Wrap(err)) - return + // If a max size is configured and the aggregate buffer is not empty, check if there is + // room to add this upload part. If not, yield the current aggregate before continuing. + totalPartSize := len(partHeader.Bytes()) + int(partHeader.PartSize) + if maxSize != 0 && buf.Len() > 0 && buf.Len()+totalPartSize > maxSize { + if !yieldCurrent() { + return + } } - if copied != totalPartSize { - yield(nil, trace.Errorf("copied %d bytes of recording part instead of expected %d", copied, totalPartSize)) - return - } + writePartToBuffer(partHeader) - if !yield(buf.Bytes(), nil) { - return + // If we've reached the target upload size, yield the current + // aggregated upload part before continuing. 
+ if buf.Len() > targetSize { + if !yieldCurrent() { + return + } } } } - - u.log.DebugContext(ctx, "uploading encrypted recording", "session_id", sessionID) - if err := u.cfg.EncryptedRecordingUploader.UploadEncryptedRecording(ctx, sessionID, partIter); err != nil { - return trace.Wrap(err) - } - - return nil } func (u *Uploader) startUpload(ctx context.Context, fileName string) (err error) { @@ -502,6 +585,13 @@ func (u *Uploader) startUpload(ctx context.Context, fileName string) (err error) if err := u.uploadEncryptedRecording(ctx, sessionID.String(), sessionFile); !errors.Is(err, errSkipEncryptedUpload) { if err != nil { + log.WarnContext(ctx, "Encrypted upload failed.", "error", err) + u.emitEvent(events.UploadEvent{ + SessionID: sessionID.String(), + Error: sessionError{err}, + Created: u.cfg.Clock.Now().UTC(), + }) + return trace.Wrap(err) } diff --git a/lib/events/filesessions/fileasync_chaos_test.go b/lib/events/filesessions/fileasync_chaos_test.go index 17afa3609ce94..a9a7ef18890b8 100644 --- a/lib/events/filesessions/fileasync_chaos_test.go +++ b/lib/events/filesessions/fileasync_chaos_test.go @@ -54,7 +54,9 @@ func TestChaosUpload(t *testing.T) { defer cancel() eventsC := make(chan events.UploadEvent, 100) - memUploader := eventstest.NewMemoryUploader(eventsC) + memUploader := eventstest.NewMemoryUploader(eventstest.MemoryUploaderConfig{ + EventsC: eventsC, + }) streamer, err := events.NewProtoStreamer(events.ProtoStreamerConfig{ Uploader: memUploader, MinUploadBytes: 1024, @@ -122,7 +124,9 @@ func TestChaosUpload(t *testing.T) { go uploader.Serve(ctx) defer uploader.Close() - fileStreamer, err := NewStreamer(scanDir, nil) + fileStreamer, err := NewStreamer(StreamerConfig{ + Dir: scanDir, + }) require.NoError(t, err) parallelStreams := 20 diff --git a/lib/events/filesessions/fileasync_test.go b/lib/events/filesessions/fileasync_test.go index f26e585fc33c8..4d95abbe729ed 100644 --- a/lib/events/filesessions/fileasync_test.go +++ b/lib/events/filesessions/fileasync_test.go @@ -22,9 +22,17 @@ import ( "bytes" "context" "crypto/rand" + "encoding/hex" "errors" + "io" + "io/fs" + "iter" + mathrand "math/rand/v2" "os" "path/filepath" + "slices" + "strconv" + "strings" "sync/atomic" "testing" "time" @@ -41,19 +49,18 @@ import ( // TestUploadOK tests async file uploads scenarios func TestUploadOK(t *testing.T) { - p := newUploaderPack(t, nil) - defer p.Close(t) + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) + defer cancel() - // wait until uploader blocks on the clock - p.clock.BlockUntil(1) + p := newUploaderPack(ctx, t, uploaderPackConfig{}) - fileStreamer, err := NewStreamer(p.scanDir, nil) - require.NoError(t, err) + // wait until uploader blocks on the clock + p.clock.BlockUntilContext(ctx, 1) inEvents := eventstest.GenerateTestSession(eventstest.SessionParams{PrintEvents: 1024}) sid := inEvents[0].(events.SessionMetadataGetter).GetSessionID() - emitStream(p.ctx, t, fileStreamer, inEvents) + p.emitEvents(ctx, t, inEvents) // initiate the scan by advancing clock past // block period @@ -64,34 +71,33 @@ func TestUploadOK(t *testing.T) { case event = <-p.memEventsC: require.Equal(t, event.SessionID, sid) require.NoError(t, event.Error) - case <-p.ctx.Done(): + case <-ctx.Done(): t.Fatalf("Timeout waiting for async upload, try `go test -v` to get more logs for details") } // read the upload and make sure the data is equal - outEvents := readStream(p.ctx, t, event.UploadID, p.memUploader) + outEvents := p.readEvents(ctx, t, event.UploadID) require.Equal(t, inEvents, 
outEvents) } // TestUploadParallel verifies several parallel uploads that have to wait // for semaphore func TestUploadParallel(t *testing.T) { - p := newUploaderPack(t, nil) - defer p.Close(t) + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) + defer cancel() + + p := newUploaderPack(ctx, t, uploaderPackConfig{}) // wait until uploader blocks on the clock - p.clock.BlockUntil(1) + p.clock.BlockUntilContext(ctx, 1) sessions := make(map[string][]apievents.AuditEvent) for range 5 { - fileStreamer, err := NewStreamer(p.scanDir, nil) - require.NoError(t, err) - sessionEvents := eventstest.GenerateTestSession(eventstest.SessionParams{PrintEvents: 1024}) sid := sessionEvents[0].(events.SessionMetadataGetter).GetSessionID() - emitStream(p.ctx, t, fileStreamer, sessionEvents) + p.emitEvents(ctx, t, sessionEvents) sessions[sid] = sessionEvents } @@ -108,12 +114,12 @@ func TestUploadParallel(t *testing.T) { require.NoError(t, event.Error) sessionEvents, found = sessions[event.SessionID] require.True(t, found, "session %q is not expected, possible duplicate event", event.SessionID) - case <-p.ctx.Done(): + case <-ctx.Done(): t.Fatalf("Timeout waiting for async upload, try `go test -v` to get more logs for details") } // read the upload and make sure the data is equal - outEvents := readStream(p.ctx, t, event.UploadID, p.memUploader) + outEvents := p.readEvents(ctx, t, event.UploadID) require.Equal(t, sessionEvents, outEvents) @@ -150,7 +156,7 @@ func TestMovesCorruptedUploads(t *testing.T) { require.NoError(t, err) require.NoError(t, badFile.Close()) - stats, err := uploader.Scan(context.Background()) + stats, err := uploader.Scan(t.Context()) require.NoError(t, err) require.Equal(t, 2, stats.Scanned) require.Equal(t, 2, stats.Corrupted) @@ -165,7 +171,7 @@ func TestMovesCorruptedUploads(t *testing.T) { // run a second scan to verify that: // 1. the corrupted file is no longer processed // 2. 
the file with the bad name was still flagged as corrupted - stats, err = uploader.Scan(context.Background()) + stats, err = uploader.Scan(t.Context()) require.NoError(t, err) require.Equal(t, 1, stats.Scanned) require.Equal(t, 1, stats.Corrupted) @@ -363,39 +369,40 @@ func TestUploadBackoff(t *testing.T) { var terminateConnectionAt atomic.Int64 terminateConnectionAt.Store(700) - p := newUploaderPack(t, func(streamer events.Streamer) (events.Streamer, error) { - return events.NewCallbackStreamer(events.CallbackStreamerConfig{ - Inner: streamer, - OnRecordEvent: func(ctx context.Context, sid session.ID, pe apievents.PreparedSessionEvent) error { - event := pe.GetAuditEvent() - terminateAt := terminateConnectionAt.Load() - if terminateAt > 0 && event.GetIndex() >= terminateAt { - t.Logf("Terminating connection at event %v", event.GetIndex()) - return trace.ConnectionProblem(nil, "connection terminated at event index %v", terminateAt) - } - return nil - }, - }) + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) + defer cancel() + + p := newUploaderPack(ctx, t, uploaderPackConfig{ + wrapProtoStreamer: func(streamer events.Streamer) (events.Streamer, error) { + return events.NewCallbackStreamer(events.CallbackStreamerConfig{ + Inner: streamer, + OnRecordEvent: func(ctx context.Context, sid session.ID, pe apievents.PreparedSessionEvent) error { + event := pe.GetAuditEvent() + terminateAt := terminateConnectionAt.Load() + if terminateAt > 0 && event.GetIndex() >= terminateAt { + t.Logf("Terminating connection at event %v", event.GetIndex()) + return trace.ConnectionProblem(nil, "connection terminated at event index %v", terminateAt) + } + return nil + }, + }) + }, }) - defer p.Close(t) // wait until uploader blocks on the clock before creating the stream - p.clock.BlockUntil(1) - - fileStreamer, err := NewStreamer(p.scanDir, nil) - require.NoError(t, err) + p.clock.BlockUntilContext(ctx, 1) inEvents := eventstest.GenerateTestSession(eventstest.SessionParams{PrintEvents: 4096}) sid := inEvents[0].(events.SessionMetadataGetter).GetSessionID() - stream, err := fileStreamer.CreateAuditStream(p.ctx, session.ID(sid)) + stream, err := p.fileStreamer.CreateAuditStream(ctx, session.ID(sid)) require.NoError(t, err) for _, event := range inEvents { - err := stream.RecordEvent(p.ctx, eventstest.PrepareEvent(event)) + err := stream.RecordEvent(ctx, eventstest.PrepareEvent(event)) require.NoError(t, err) } - err = stream.Complete(p.ctx) + err = stream.Complete(ctx) require.NoError(t, err) // initiate the scan by advancing clock past @@ -418,14 +425,14 @@ func TestUploadBackoff(t *testing.T) { diffs = append(diffs, event.Created.Sub(prev)) prev = event.Created } - case <-p.ctx.Done(): + case <-ctx.Done(): t.Fatalf("Timeout waiting for async upload %v, try `go test -v` to get more logs for details", i) } // Block until Scan has been called two times, // first time after doing the scan, and second // on receiving the event to <- eventsCh - p.clock.BlockUntil(2) + p.clock.BlockUntilContext(ctx, 2) p.clock.Advance(p.scanPeriod*time.Duration(i+2) + time.Second) } @@ -438,12 +445,12 @@ func TestUploadBackoff(t *testing.T) { // Fix the streamer, make sure the upload succeeds terminateConnectionAt.Store(0) - p.clock.BlockUntil(2) + p.clock.BlockUntilContext(ctx, 2) p.clock.Advance(time.Hour) select { case event := <-p.eventsC: require.NoError(t, event.Error) - case <-p.ctx.Done(): + case <-ctx.Done(): t.Fatalf("Timeout waiting for async upload, try `go test -v` to get more logs for details") } } @@ 
-451,13 +458,13 @@ func TestUploadBackoff(t *testing.T) { // TestUploadBadSession creates a corrupted session file // and makes sure the uploader marks it as faulty func TestUploadBadSession(t *testing.T) { - ctx := context.Background() + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) + defer cancel() - p := newUploaderPack(t, nil) - defer p.Close(t) + p := newUploaderPack(ctx, t, uploaderPackConfig{}) // wait until uploader blocks on the clock - p.clock.BlockUntil(1) + p.clock.BlockUntilContext(ctx, 1) sessionID := session.NewID() fileName := filepath.Join(p.scanDir, string(sessionID)+tarExt) @@ -476,7 +483,7 @@ func TestUploadBadSession(t *testing.T) { case event = <-p.eventsC: require.Error(t, event.Error) require.True(t, isSessionError(event.Error)) - case <-p.ctx.Done(): + case <-ctx.Done(): t.Fatalf("Timeout waiting for async upload, try `go test -v` to get more logs for details") } @@ -487,80 +494,417 @@ func TestUploadBadSession(t *testing.T) { require.Equal(t, 0, stats.Started) } +// TestMinimumUpload tests that the minimum upload values for files and final uploads are respected. +func TestMinimumUpload(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) + defer cancel() + + // The gzip writer ensures that upload parts exceed the minimum upload size, and imprecisely + // constrains the maximum size to 133% of the minimum. At small values, this is especially + // imprecise, so we give a full 200% of the minimum as overhead. + // + // Usually, we use 128KB for file parts and 5KB for final upload parts. + minFileBytes := 8192 + maxFileBytes := 2 * minFileBytes + minUploadBytes := 33768 + maxUploadBytes := 2 * minUploadBytes + + p := newUploaderPack(ctx, t, uploaderPackConfig{ + minimumFileUploadBytes: int64(minFileBytes), + minimumUploadBytes: int64(minUploadBytes), + }) + + // wait until uploader blocks on the clock + p.clock.BlockUntilContext(ctx, 1) + + inEvents := eventstest.GenerateTestSession(eventstest.SessionParams{PrintEvents: int64(mathrand.IntN(5000) + 5000)}) + sid := inEvents[0].(events.SessionMetadataGetter).GetSessionID() + stream, err := p.fileStreamer.CreateAuditStream(ctx, session.ID(sid)) + require.NoError(t, err) + for _, event := range inEvents { + err := stream.RecordEvent(ctx, eventstest.PrepareEvent(event)) + require.NoError(t, err) + } + + // Before completing the file stream, which writes upload part files into a .tar, check that + // each file part was within specific size expectations. + var partFiles []fs.FileInfo + err = filepath.Walk(p.scanDir, func(path string, info fs.FileInfo, err error) error { + // skip non-part files. 
+ if ext := filepath.Ext(info.Name()); ext == ".part" { + partFiles = append(partFiles, info) + } + return nil + }) + require.NoError(t, err) + + slices.SortFunc(partFiles, func(a fs.FileInfo, b fs.FileInfo) int { + aPartNumberStr, _ := strings.CutSuffix(a.Name(), ".part") + aPartNumber, err := strconv.Atoi(aPartNumberStr) + require.NoError(t, err) + bPartNumberStr, _ := strings.CutSuffix(b.Name(), ".part") + bPartNumber, err := strconv.Atoi(bPartNumberStr) + require.NoError(t, err) + return aPartNumber - bPartNumber + }) + + for i, partFile := range partFiles { + partSize := int(partFile.Size()) + require.GreaterOrEqual(t, partSize, minFileBytes, "expected upload part %v to be between %v and %v bytes, but was %v bytes", i, minFileBytes, maxFileBytes, partSize) + require.LessOrEqual(t, partSize, maxFileBytes, "expected upload part %v to be between %v and %v bytes, but was %v bytes", i, minFileBytes, maxFileBytes, partSize) + } + + // Complete the file stream and advance the clock to unblock the uploader scanner. + err = stream.Complete(ctx) + require.NoError(t, err) + p.clock.Advance(p.scanPeriod + time.Second) + + var event events.UploadEvent + select { + case event = <-p.memEventsC: + require.Equal(t, event.SessionID, sid) + require.NoError(t, event.Error) + case <-ctx.Done(): + t.Fatalf("Timeout waiting for async upload, try `go test -v` to get more logs for details") + } + + // read the upload and make sure the data is equal + outEvents := p.readEvents(ctx, t, event.UploadID) + require.Equal(t, inEvents, outEvents) + + uploadParts, err := p.memUploader.GetParts(event.UploadID) + require.NoError(t, err) + + // minimumProtoUploadBytes should ensure each upload part is within specific size expectations. + for i, part := range uploadParts { + if i == len(uploadParts)-1 { + // The last part is not required to meet the minimum size. + require.LessOrEqual(t, len(part), maxUploadBytes, "expected last upload part to be smaller than %v bytes, but was %v bytes", maxUploadBytes, len(part)) + } else { + require.GreaterOrEqual(t, len(part), minUploadBytes, "expected upload part %v to be between %v and %v bytes, but was %v bytes", i, minUploadBytes, maxUploadBytes, len(part)) + require.LessOrEqual(t, len(part), maxUploadBytes, "expected upload part %v to be between %v and %v bytes, but was %v bytes", i, minUploadBytes, maxUploadBytes, len(part)) + } + } + + // There should be at least 1 final upload part for every 4 file upload parts. + minFactor := minUploadBytes / minFileBytes + require.LessOrEqual(t, len(partFiles)/minFactor, len(uploadParts), "expected there to be 1 final upload part for every 4 transient file parts, but got %v and %v respectively", len(uploadParts), len(partFiles)) +} + +func TestUploadEncryptedRecording(t *testing.T) { + for _, tc := range []struct { + name string + minFileSize int + // encrypted upload files should be aggregated by the encrypted uploader to reach the target size. + encryptedTargetSize int + encryptedMaxSize int + expectEncryptedUploadErr bool + // uploads from the encrypted uploader should be padded further by the final uploader to + // reach the minimum upload size. e.g. pad encrypted uploads of 4MB (max from GRPC) to reach + // S3 minimum upload of 5MB. 
+ minUploadBytes int + }{ + { + name: "target size larger than max encrypted upload size", + minFileSize: 8192, + encryptedTargetSize: 8192 * 4 * 2, + encryptedMaxSize: 8192 * 4, + }, { + name: "target size equals max encrypted upload size", + minFileSize: 8192, + encryptedTargetSize: 8192 * 4, + encryptedMaxSize: 8192 * 4, + }, { + name: "target size smaller than max encrypted upload size", + minFileSize: 8192, + encryptedTargetSize: 8192 * 4, + encryptedMaxSize: 8192 * 4 * 2, + }, { + name: "target size larger than min upload size", + minFileSize: 8192, + encryptedTargetSize: 8192 * 4 * 2, + minUploadBytes: 8192 * 4, + }, { + name: "target size equals min upload size", + minFileSize: 8192, + encryptedTargetSize: 8192 * 4, + minUploadBytes: 8192 * 4, + }, { + name: "target size smaller than min upload size", + minFileSize: 8192, + encryptedTargetSize: 8192 * 4, + minUploadBytes: 8192 * 4 * 2, + }, { + name: "min file size larger than max encrypted size", + minFileSize: 8192 * 4, + encryptedMaxSize: 8192, + expectEncryptedUploadErr: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) + defer cancel() + + // First we calculate some expected size values for different steps in the event stream. + // It is easier to calculate them here rather than in the test case's parameters. + + // the gzip writer is imprecise and may exceed the min file size. With a min of 8192, + // an extra 8192 bytes should be plenty of overhead. If this test becomes flaky, consider + // raising the minimum in the test cases above. + maxFileSize := tc.minFileSize * 2 + + // We expect encrypted uploads to exceed the target size, unless the maximum prevents it, + // in which case we will be short one file part. + expectEncryptedSizeFloor := tc.encryptedTargetSize + if tc.encryptedMaxSize != 0 && tc.encryptedTargetSize+maxFileSize > tc.encryptedMaxSize { + expectEncryptedSizeFloor = tc.encryptedMaxSize - maxFileSize + } + + // We expect encrypted uploads to be below the max size if set. Otherwise, we expect it to + // be the target size + one additional file part. + expectEncryptedSizeCeil := tc.encryptedMaxSize + if tc.encryptedMaxSize == 0 { + expectEncryptedSizeCeil = tc.encryptedTargetSize + maxFileSize + } + + // the final upload is never smaller than the encrypted size, but always above the minimum. + expectFinalSizeFloor := max(tc.minUploadBytes, expectEncryptedSizeFloor) + + // the final upload is only larger than the minimum (+header_size) if the encrypted size is larger. + expectFinalSizeCeil := max(expectEncryptedSizeCeil, tc.minUploadBytes+events.ProtoStreamV2PartHeaderSize) + + // Create a wrapper around the encrypted uploader to ensure the caller is yielding + // correctly sized uploads. 
+ var recollectParts [][]byte + encryptedUploadWrapper := wrapEncryptedUploaderFn(func(u events.EncryptedRecordingUploader) events.EncryptedRecordingUploader { + return encryptedUploaderFn(func(ctx context.Context, sessionID string, parts iter.Seq2[[]byte, error]) error { + next, stop := iter.Pull2(parts) + defer stop() + + part, err, ok := next() + if err != nil { + return trace.Wrap(err) + } else if !ok { + return trace.BadParameter("unexpected empty upload") + } + + for { + recollectParts = append(recollectParts, part) + + nextPart, err, hasNext := next() + if err != nil { + return trace.Wrap(err) + } + + if !hasNext { + break + } + + if hasNext { + require.GreaterOrEqual(t, len(part), expectEncryptedSizeFloor, "expected encrypted upload to be between %v and %v bytes, but was %v bytes", expectEncryptedSizeFloor, expectEncryptedSizeCeil, len(part)) + require.LessOrEqual(t, len(part), expectEncryptedSizeCeil, "expected encrypted upload to be between %v and %v bytes, but was %v bytes", expectEncryptedSizeFloor, expectEncryptedSizeCeil, len(part)) + } else { + // The last part is not expected to meet the target size. + require.LessOrEqual(t, len(part), expectEncryptedSizeCeil, "expected last encrypted upload to be smaller than %v bytes, but was %v bytes", expectEncryptedSizeCeil, len(part)) + } + + part = nextPart + } + + partReIter := func(yield func([]byte, error) bool) { + for _, part := range recollectParts { + if !yield(part, nil) { + return + } + } + } + + return u.UploadEncryptedRecording(ctx, sessionID, partReIter) + }) + }) + + p := newUploaderPack(ctx, t, uploaderPackConfig{ + minimumFileUploadBytes: int64(tc.minFileSize), + minimumUploadBytes: int64(tc.minUploadBytes), + encrypter: &fakeEncryptedIO{}, + wrapEncryptedUploader: encryptedUploadWrapper, + encryptedRecordingUploadTargetSize: tc.encryptedTargetSize, + encryptedRecordingUploadMaxSize: tc.encryptedMaxSize, + }) + + // wait until uploader blocks on the clock + err := p.clock.BlockUntilContext(ctx, 1) + require.NoError(t, err) + + // Here we ensure at least 5 final upload parts so that we amply test the test case values, + some variance. + eventsCount := expectFinalSizeCeil*5/64 + mathrand.IntN(1000) + inEvents := eventstest.GenerateTestSession(eventstest.SessionParams{PrintEvents: int64(eventsCount)}) + sid := inEvents[0].(events.SessionMetadataGetter).GetSessionID() + p.emitEvents(ctx, t, inEvents) + + // initiate the scan by advancing clock past + // block period + p.clock.Advance(p.scanPeriod + time.Second) + + var event events.UploadEvent + select { + case event = <-p.memEventsC: + if tc.expectEncryptedUploadErr { + t.Fatalf("Unexpected upload event") + } + require.Equal(t, event.SessionID, sid) + require.NoError(t, event.Error) + case event = <-p.eventsC: + if !tc.expectEncryptedUploadErr { + t.Fatalf("Unexpected upload event") + } + require.Error(t, event.Error) + require.True(t, isSessionError(event.Error)) + return + case <-ctx.Done(): + t.Fatalf("Timeout waiting for async upload, try `go test -v` to get more logs for details") + } + + // read the upload and make sure the data is equal + outEvents := p.readEvents(ctx, t, event.UploadID) + require.Equal(t, inEvents, outEvents) + + uploadParts, err := p.memUploader.GetParts(event.UploadID) + require.NoError(t, err) + + // final uploads should be above the minimum upload size. + for i, part := range uploadParts { + if i == len(uploadParts)-1 { + // The last part is not required to meet the minimum size, so it shouldn't exceed the original encrypted recording size. 
+ require.LessOrEqual(t, len(part), expectEncryptedSizeCeil, "expected last upload to be smaller than %v bytes, but was %v bytes", expectEncryptedSizeCeil, len(part)) + } else { + require.GreaterOrEqual(t, len(part), expectFinalSizeFloor, "expected upload to be between %v and %v bytes, but was %v bytes", expectFinalSizeFloor, expectFinalSizeCeil, len(part)) + require.LessOrEqual(t, len(part), expectFinalSizeCeil, "expected upload to be between %v and %v bytes, but was %v bytes", expectFinalSizeFloor, expectFinalSizeCeil, len(part)) + } + } + + // There should be one final upload for each upload part from the encrypted uploader. + require.Len(t, recollectParts, len(uploadParts), "expected there to be an equal amount of final upload parts and transient upload parts, but got %v and %v respectively", len(uploadParts), len(recollectParts)) + }) + } +} + +type uploaderPackConfig struct { + minimumFileUploadBytes int64 + minimumUploadBytes int64 + wrapProtoStreamer wrapStreamerFn + encrypter events.EncryptionWrapper + wrapEncryptedUploader wrapEncryptedUploaderFn + encryptedRecordingUploadTargetSize int + encryptedRecordingUploadMaxSize int +} + // uploaderPack reduces boilerplate required // to create a test type uploaderPack struct { scanPeriod time.Duration initialScanDelay time.Duration clock *clockwork.FakeClock - eventsC chan events.UploadEvent - memEventsC chan events.UploadEvent - memUploader *eventstest.MemoryUploader - streamer events.Streamer - scanDir string - uploader *Uploader - ctx context.Context - cancel context.CancelFunc -} - -func (u *uploaderPack) Close(t *testing.T) { - u.cancel() + // fileStreamer streams events to upload parts on disk. + scanDir string + fileStreamer events.Streamer + // uploader scans upload parts from disk and streams them through the protoStreamer. + eventsC chan events.UploadEvent + uploader *Uploader + // protoStreamer streams events to upload parts in-memory (represents final audit log storage). + protoStreamer events.Streamer + memEventsC chan events.UploadEvent + memUploader *eventstest.MemoryUploader } -func newUploaderPack(t *testing.T, wrapStreamer wrapStreamerFn) uploaderPack { +func newUploaderPack(ctx context.Context, t *testing.T, cfg uploaderPackConfig) uploaderPack { scanDir := t.TempDir() corruptedDir := t.TempDir() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithCancel(ctx) + t.Cleanup(cancel) + pack := uploaderPack{ clock: clockwork.NewFakeClock(), eventsC: make(chan events.UploadEvent, 100), memEventsC: make(chan events.UploadEvent, 100), - ctx: ctx, - cancel: cancel, scanDir: scanDir, scanPeriod: 10 * time.Second, initialScanDelay: 10 * time.Millisecond, } - pack.memUploader = eventstest.NewMemoryUploader(pack.memEventsC) - streamer, err := events.NewProtoStreamer(events.ProtoStreamerConfig{ + pack.memUploader = eventstest.NewMemoryUploader(eventstest.MemoryUploaderConfig{ + EventsC: pack.memEventsC, + MinimumUploadBytes: int(cfg.minimumUploadBytes), + }) + + var err error + pack.fileStreamer, err = NewStreamer(StreamerConfig{ + Dir: scanDir, + MinUploadBytes: cfg.minimumFileUploadBytes, + Encrypter: cfg.encrypter, + }) + require.NoError(t, err) + + pack.protoStreamer, err = events.NewProtoStreamer(events.ProtoStreamerConfig{ Uploader: pack.memUploader, - MinUploadBytes: 1024, + MinUploadBytes: cfg.minimumUploadBytes, + // Skip encrypter in proto streamer, encryption skips the proto stream flow. 
}) require.NoError(t, err) - pack.streamer = streamer - if wrapStreamer != nil { - pack.streamer, err = wrapStreamer(pack.streamer) + + if cfg.wrapProtoStreamer != nil { + pack.protoStreamer, err = cfg.wrapProtoStreamer(pack.protoStreamer) require.NoError(t, err) } - uploader, err := NewUploader(UploaderConfig{ - ScanDir: pack.scanDir, - CorruptedDir: corruptedDir, - InitialScanDelay: pack.initialScanDelay, - ScanPeriod: pack.scanPeriod, - Streamer: pack.streamer, - Clock: pack.clock, - EventsC: pack.eventsC, - }) + uploaderCfg := UploaderConfig{ + ScanDir: pack.scanDir, + CorruptedDir: corruptedDir, + InitialScanDelay: pack.initialScanDelay, + ScanPeriod: pack.scanPeriod, + Streamer: pack.protoStreamer, + Clock: pack.clock, + EventsC: pack.eventsC, + EncryptedRecordingUploader: pack.memUploader, + EncryptedRecordingUploadTargetSize: cfg.encryptedRecordingUploadTargetSize, + EncryptedRecordingUploadMaxSize: cfg.encryptedRecordingUploadMaxSize, + } + if cfg.wrapEncryptedUploader != nil { + uploaderCfg.EncryptedRecordingUploader = cfg.wrapEncryptedUploader(pack.memUploader) + } + + pack.uploader, err = NewUploader(uploaderCfg) require.NoError(t, err) - pack.uploader = uploader - go pack.uploader.Serve(pack.ctx) + + go pack.uploader.Serve(ctx) + return pack } +func (p *uploaderPack) emitEvents(ctx context.Context, t *testing.T, inEvents []apievents.AuditEvent) { + emitStream(ctx, t, p.fileStreamer, inEvents) +} + +func (p *uploaderPack) readEvents(ctx context.Context, t *testing.T, uploadID string) []apievents.AuditEvent { + return readStream(ctx, t, uploadID, p.memUploader) +} + type wrapStreamerFn func(streamer events.Streamer) (events.Streamer, error) +type wrapEncryptedUploaderFn func(u events.EncryptedRecordingUploader) events.EncryptedRecordingUploader + // runResume runs resume scenario based on the test case specification func runResume(t *testing.T, testCase resumeTestCase) { t.Logf("Running test %q.", testCase.name) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) defer cancel() clock := clockwork.NewFakeClock() eventsC := make(chan events.UploadEvent, 100) - memUploader := eventstest.NewMemoryUploader(eventsC) + memUploader := eventstest.NewMemoryUploader(eventstest.MemoryUploaderConfig{ + EventsC: eventsC, + }) streamer, err := events.NewProtoStreamer(events.ProtoStreamerConfig{ Uploader: memUploader, MinUploadBytes: 1024, @@ -585,11 +929,13 @@ func runResume(t *testing.T, testCase resumeTestCase) { require.NoError(t, err) go uploader.Serve(ctx) // wait until uploader blocks on the clock - clock.BlockUntil(1) + clock.BlockUntilContext(ctx, 1) defer uploader.Close() - fileStreamer, err := NewStreamer(scanDir, nil) + fileStreamer, err := NewStreamer(StreamerConfig{ + Dir: scanDir, + }) require.NoError(t, err) inEvents := eventstest.GenerateTestSession(eventstest.SessionParams{PrintEvents: 1024}) @@ -618,7 +964,7 @@ func runResume(t *testing.T, testCase resumeTestCase) { // Block until Scan has been called two times, // first time after doing the scan, and second // on receiving the event to <- eventsCh - clock.BlockUntil(2) + clock.BlockUntilContext(ctx, 2) clock.Advance(scanPeriod*time.Duration(i+2) + time.Second) // wait for upload success @@ -646,6 +992,8 @@ func runResume(t *testing.T, testCase resumeTestCase) { // emitStream creates and sends the session stream func emitStream(ctx context.Context, t *testing.T, streamer events.Streamer, inEvents []apievents.AuditEvent) { + t.Helper() + 
sid := inEvents[0].(events.SessionMetadataGetter).GetSessionID() stream, err := streamer.CreateAuditStream(ctx, session.ID(sid)) @@ -669,7 +1017,7 @@ func readStream(ctx context.Context, t *testing.T, uploadID string, uploader *ev var reader *events.ProtoReader for i, part := range parts { if i == 0 { - reader = events.NewProtoReader(bytes.NewReader(part), nil) + reader = events.NewProtoReader(bytes.NewReader(part), &fakeEncryptedIO{}) } else { err := reader.Reset(bytes.NewReader(part)) require.NoError(t, err) @@ -681,3 +1029,41 @@ func readStream(ctx context.Context, t *testing.T, uploadID string, uploader *ev } return outEvents } + +// encryptedIO is really just a reversible transform, so we fake encryption by encoding/decoding as hex +type fakeEncryptedIO struct { + err error +} + +type fakeEncrypter struct { + inner io.WriteCloser + writer io.Writer +} + +func (f *fakeEncrypter) Write(out []byte) (int, error) { + return f.writer.Write(out) +} + +func (f *fakeEncrypter) Close() error { + return f.inner.Close() +} + +func (f *fakeEncryptedIO) WithEncryption(ctx context.Context, writer io.WriteCloser) (io.WriteCloser, error) { + hexWriter := hex.NewEncoder(writer) + encrypter := &fakeEncrypter{ + inner: writer, + writer: hexWriter, + } + + return encrypter, f.err +} + +func (f *fakeEncryptedIO) WithDecryption(ctx context.Context, reader io.Reader) (io.Reader, error) { + return hex.NewDecoder(reader), f.err +} + +type encryptedUploaderFn func(ctx context.Context, sessionID string, parts iter.Seq2[[]byte, error]) error + +func (e encryptedUploaderFn) UploadEncryptedRecording(ctx context.Context, sessionID string, parts iter.Seq2[[]byte, error]) error { + return e(ctx, sessionID, parts) +} diff --git a/lib/events/filesessions/filestream.go b/lib/events/filesessions/filestream.go index 49ce1b64113fc..2983b21f677ca 100644 --- a/lib/events/filesessions/filestream.go +++ b/lib/events/filesessions/filestream.go @@ -77,16 +77,43 @@ const ( reservationSize = minUploadBytes + events.MaxProtoMessageSizeBytes ) +type StreamerConfig struct { + // Dir is the dir to stream session events to. + Dir string + // MinUploadBytes is the minimum size at which upload parts are submitted. + // Due to the nature of the gzip writer, each upload part maybe be marginally + // larger, but not smaller, than the minimum size. Defaults to 128KB. + MinUploadBytes int64 + // Encrypter wraps the final gzip writer with encryption. 
+ Encrypter events.EncryptionWrapper +} + +// CheckAndSetDefaults checks and sets streamer defaults +func (cfg *StreamerConfig) CheckAndSetDefaults() error { + if cfg.Dir == "" { + return trace.BadParameter("missing parameter Dir") + } + if cfg.MinUploadBytes == 0 { + cfg.MinUploadBytes = minUploadBytes + } + return nil +} + // NewStreamer creates a streamer sending uploads to disk -func NewStreamer(dir string, encrypter events.EncryptionWrapper) (*events.ProtoStreamer, error) { - handler, err := NewHandler(Config{Directory: dir, OpenFile: GetOpenFileFunc()}) +func NewStreamer(cfg StreamerConfig) (*events.ProtoStreamer, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + handler, err := NewHandler(Config{Directory: cfg.Dir, OpenFile: GetOpenFileFunc()}) if err != nil { return nil, trace.Wrap(err) } + return events.NewProtoStreamer(events.ProtoStreamerConfig{ Uploader: handler, - MinUploadBytes: minUploadBytes, - Encrypter: encrypter, + MinUploadBytes: cfg.MinUploadBytes, + Encrypter: cfg.Encrypter, }) } diff --git a/lib/events/recorder/recorder.go b/lib/events/recorder/recorder.go index 916704dae5a38..4d8a221958c21 100644 --- a/lib/events/recorder/recorder.go +++ b/lib/events/recorder/recorder.go @@ -139,7 +139,10 @@ func New(cfg Config) (events.SessionPreparerRecorder, error) { if cfg.Encrypter == nil { cfg.Encrypter = recordingencryption.NewEncryptionWrapper(cfg.RecordingCfg) } - fileStreamer, err := filesessions.NewStreamer(uploadDir, cfg.Encrypter) + fileStreamer, err := filesessions.NewStreamer(filesessions.StreamerConfig{ + Dir: uploadDir, + Encrypter: cfg.Encrypter, + }) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/events/session_writer_test.go b/lib/events/session_writer_test.go index 6ad4b7464ffcf..bf06dad808206 100644 --- a/lib/events/session_writer_test.go +++ b/lib/events/session_writer_test.go @@ -352,7 +352,9 @@ func withBackoff(timeout, dur time.Duration) sessionWriterOption { func newSessionWriterTest(t *testing.T, newStreamer newStreamerFn, opts ...sessionWriterOption) *sessionWriterTest { eventsCh := make(chan events.UploadEvent, 1) - uploader := eventstest.NewMemoryUploader(eventsCh) + uploader := eventstest.NewMemoryUploader(eventstest.MemoryUploaderConfig{ + EventsC: eventsCh, + }) protoStreamer, err := events.NewProtoStreamer(events.ProtoStreamerConfig{ Uploader: uploader, }) diff --git a/lib/events/stream.go b/lib/events/stream.go index 38e85af6c8baa..dd96617848432 100644 --- a/lib/events/stream.go +++ b/lib/events/stream.go @@ -68,8 +68,14 @@ const ( // MaxProtoMessageSizeBytes is maximum protobuf marshaled message size MaxProtoMessageSizeBytes = 64 * 1024 - // MinUploadPartSizeBytes is the minimum allowed part size when uploading a part to - // Amazon S3. + // MinUploadPartSizeBytes is the minimum upload part size when uploading session recordings + // through a [MultipartUploader]. All uploaded parts are expected to meet this minimum size. + // The actual minimum enforced by th external audit storage depends on the provider: + // - S3 (AWS): 5MiB + // - GCloud: 5MiB + // - Azure: None + // - File: None + // - Mem (tests): Configurable MinUploadPartSizeBytes = 1024 * 1024 * 5 // ProtoStreamV1 is a version of the binary protocol @@ -1302,6 +1308,21 @@ func (r *ProtoReader) Read(ctx context.Context) (apievents.AuditEvent, error) { return nil, r.setError(trace.ConvertSystemError(err)) } + // Empty parts may be created for padding. Just skip them and discard any padding. 
+ if header.PartSize == 0 { + if header.PaddingSize != 0 { + skipped, err := io.CopyBuffer(io.Discard, io.LimitReader(r.reader, int64(header.PaddingSize)), r.messageBytes[:]) + if err != nil { + return nil, r.setError(trace.ConvertSystemError(err)) + } + if skipped != int64(header.PaddingSize) { + return nil, r.setError(trace.BadParameter( + "data truncated, expected to read %v bytes, but got %v", r.padding, skipped)) + } + } + continue + } + r.padding = int64(header.PaddingSize) reader := io.LimitReader(r.reader, int64(header.PartSize)) if header.Flags&ProtoStreamFlagEncrypted != 0 { diff --git a/lib/events/stream_test.go b/lib/events/stream_test.go index 048b8aafdfbf0..06c1d7c8e60dd 100644 --- a/lib/events/stream_test.go +++ b/lib/events/stream_test.go @@ -197,7 +197,7 @@ func TestProtoStreamLargeEvent(t *testing.T) { ctx := context.Background() streamer, err := events.NewProtoStreamer(events.ProtoStreamerConfig{ - Uploader: eventstest.NewMemoryUploader(nil), + Uploader: eventstest.NewMemoryUploader(), }) require.NoError(t, err) diff --git a/lib/service/service.go b/lib/service/service.go index ed1a12dc63064..7363d899524e8 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -3655,6 +3655,7 @@ func (process *TeleportProcess) initUploaderService() error { // use the local auth server for uploads if auth happens to be // running in this process, otherwise wait for the instance client var uploaderClient procUploader + var encryptedRecordingMaxUploadSize int if la := process.getLocalAuth(); la != nil { // The auth service's upload completer is initialized separately, // so as a special case we can stop early if auth happens to be @@ -3670,6 +3671,7 @@ func (process *TeleportProcess) initUploaderService() error { return trace.Wrap(err, "cannot get cluster name") } clusterName = cn.GetClusterName() + } else { logger.DebugContext(process.ExitContext(), "auth is not running in-process, waiting for instance connector") conn, err := waitForInstanceConnector(process, logger) @@ -3681,6 +3683,10 @@ func (process *TeleportProcess) initUploaderService() error { } uploaderClient = conn.Client clusterName = conn.ClusterName() + + // encrypted uploads are aggregated and uploaded directly rather than with an event stream. + // Since we are using the gRPC client, we must set this maximum for the aggregation step. + encryptedRecordingMaxUploadSize = 4 * 1024 * 1024 // 4MiB, default gRPC max msg recv size. } logger.InfoContext(process.ExitContext(), "starting upload completer service") @@ -3719,12 +3725,13 @@ func (process *TeleportProcess) initUploaderService() error { corruptedDir := filepath.Join(paths[1]...) fileUploader, err := filesessions.NewUploader(filesessions.UploaderConfig{ - Streamer: uploaderClient, - ScanDir: uploadsDir, - CorruptedDir: corruptedDir, - EventsC: process.Config.Testing.UploadEventsC, - InitialScanDelay: 15 * time.Second, - EncryptedRecordingUploader: uploaderClient, + Streamer: uploaderClient, + ScanDir: uploadsDir, + CorruptedDir: corruptedDir, + EventsC: process.Config.Testing.UploadEventsC, + InitialScanDelay: 15 * time.Second, + EncryptedRecordingUploader: uploaderClient, + EncryptedRecordingUploadMaxSize: encryptedRecordingMaxUploadSize, }) if err != nil { return trace.Wrap(err)
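For context on how the agent-side aggregation in encryptedUploadAggregateIter interacts with the 4MiB gRPC ceiling configured here, the size policy it implements (ignoring part headers and padding, which the real iterator also accounts for) reduces to the following standalone sketch. aggregateParts is a hypothetical name introduced only for this illustration.

// Minimal sketch of the aggregation policy: buffer incoming parts until the
// aggregate exceeds targetSize, flushing early when adding the next part
// would push the aggregate past maxSize (if a non-zero maxSize is set), and
// yield whatever remains as the possibly undersized final part.
package main

import (
	"bytes"
	"fmt"
	"iter"
)

func aggregateParts(parts [][]byte, targetSize, maxSize int) iter.Seq[[]byte] {
	return func(yield func([]byte) bool) {
		var buf bytes.Buffer
		flush := func() bool {
			// Copy the buffer so a later reset does not clobber the yielded slice.
			out := make([]byte, buf.Len())
			copy(out, buf.Bytes())
			buf.Reset()
			return yield(out)
		}
		for _, part := range parts {
			// Flush early if this part would push the aggregate past maxSize.
			if maxSize != 0 && buf.Len() > 0 && buf.Len()+len(part) > maxSize {
				if !flush() {
					return
				}
			}
			buf.Write(part)
			// Flush once the target size has been exceeded.
			if buf.Len() > targetSize {
				if !flush() {
					return
				}
			}
		}
		// Yield whatever remains as the final, possibly undersized, aggregate.
		if buf.Len() > 0 {
			flush()
		}
	}
}

func main() {
	// Four small parts standing in for the ~128KB file parts written by the agent.
	parts := [][]byte{make([]byte, 100), make([]byte, 100), make([]byte, 100), make([]byte, 50)}
	for agg := range aggregateParts(parts, 150, 250) {
		fmt.Println("aggregated part of", len(agg), "bytes")
	}
}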