Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions lib/auth/auth_with_roles_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1781,10 +1781,12 @@ func TestStreamSessionEventsRBAC(t *testing.T) {
clt, err := srv.NewClient(identity)
require.NoError(t, err)

_, errC := clt.StreamSessionEvents(context.Background(), "foo", 0)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, errC := clt.StreamSessionEvents(ctx, "foo", 0)
select {
case err := <-errC:
require.True(t, trace.IsAccessDenied(err), "expected access denied error, got %v", err)
require.ErrorAs(t, err, new(*trace.AccessDeniedError))
case <-time.After(5 * time.Second):
require.FailNow(t, "expected access denied error but stream succeeded")
}
Expand All @@ -1794,7 +1796,8 @@ func TestStreamSessionEventsRBAC(t *testing.T) {
func TestStreamSessionEvents_User(t *testing.T) {
t.Parallel()

ctx := context.Background()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
srv := newTestTLSServer(t)

username := "user"
Expand Down Expand Up @@ -1829,7 +1832,8 @@ func TestStreamSessionEvents_User(t *testing.T) {
func TestStreamSessionEvents_Builtin(t *testing.T) {
t.Parallel()

ctx := context.Background()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
srv := newTestTLSServer(t)

identity := TestBuiltin(types.RoleProxy)
Expand Down Expand Up @@ -7275,7 +7279,8 @@ func TestAccessRequestNonGreedyAnnotations(t *testing.T) {
roles: []string{
"identity-requester", "payments-requester",
"identity-resource-requester", "payments-resource-requester",
"any-requester"},
"any-requester",
},
requestedRoles: []string{"payments-access"},
expectedAnnotations: map[string][]string{
"requesting": {"role"},
Expand All @@ -7289,7 +7294,8 @@ func TestAccessRequestNonGreedyAnnotations(t *testing.T) {
roles: []string{
"identity-requester", "payments-requester",
"identity-resource-requester", "payments-resource-requester",
"any-requester"},
"any-requester",
},
requestedRoles: []string{"payments-access"},
requestedResourceIDs: []string{"server-payments"},
expectedAnnotations: map[string][]string{
Expand Down Expand Up @@ -7388,10 +7394,8 @@ func TestAccessRequestNonGreedyAnnotations(t *testing.T) {
} else {
tc.errfn(t, err)
}

})
}

}

func mustAccessRequest(t *testing.T, user string, state types.RequestState, created, expires time.Time, roles []string, resourceIDs []types.ResourceID) types.AccessRequest {
Expand Down
18 changes: 8 additions & 10 deletions lib/auth/grpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,16 +178,14 @@ func (g *GRPCServer) EmitAuditEvent(ctx context.Context, req *apievents.OneOf) (
return &emptypb.Empty{}, nil
}

var (
connectedResourceGauges = map[string]prometheus.Gauge{
constants.KeepAliveNode: connectedResources.WithLabelValues(constants.KeepAliveNode),
constants.KeepAliveKube: connectedResources.WithLabelValues(constants.KeepAliveKube),
constants.KeepAliveApp: connectedResources.WithLabelValues(constants.KeepAliveApp),
constants.KeepAliveDatabase: connectedResources.WithLabelValues(constants.KeepAliveDatabase),
constants.KeepAliveDatabaseService: connectedResources.WithLabelValues(constants.KeepAliveDatabaseService),
constants.KeepAliveWindowsDesktopService: connectedResources.WithLabelValues(constants.KeepAliveWindowsDesktopService),
}
)
var connectedResourceGauges = map[string]prometheus.Gauge{
constants.KeepAliveNode: connectedResources.WithLabelValues(constants.KeepAliveNode),
constants.KeepAliveKube: connectedResources.WithLabelValues(constants.KeepAliveKube),
constants.KeepAliveApp: connectedResources.WithLabelValues(constants.KeepAliveApp),
constants.KeepAliveDatabase: connectedResources.WithLabelValues(constants.KeepAliveDatabase),
constants.KeepAliveDatabaseService: connectedResources.WithLabelValues(constants.KeepAliveDatabaseService),
constants.KeepAliveWindowsDesktopService: connectedResources.WithLabelValues(constants.KeepAliveWindowsDesktopService),
}

// SendKeepAlives allows node to send a stream of keep alive requests
func (g *GRPCServer) SendKeepAlives(stream authpb.AuthService_SendKeepAlivesServer) error {
Expand Down
1 change: 0 additions & 1 deletion lib/client/player.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,6 @@ func (p *sessionPlayer) playRange(from, to int) {
var i int

defer func() {

p.Lock()
endRequested := p.state == stateEnding
p.setState(stateStopped)
Expand Down
10 changes: 7 additions & 3 deletions lib/events/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -889,9 +889,13 @@ type SessionStreamer interface {
// after is used to return events after a specified cursor ID
GetSessionEvents(namespace string, sid session.ID, after int) ([]EventFields, error)

// StreamSessionEvents streams all events from a given session recording. An error is returned on the first
// channel if one is encountered. Otherwise the event channel is closed when the stream ends.
// The event channel is not closed on error to prevent race conditions in downstream select statements.
// StreamSessionEvents streams all events from a given session recording. An
// error is returned on the first channel if one is encountered. Otherwise
// the event channel is closed when the stream ends. The event channel is
// not closed on error to prevent race conditions in downstream select
// statements. Both returned channels must be driven until the event channel
// is exhausted or the error channel reports an error, or until the context
// is canceled.
StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error)
}

Expand Down
4 changes: 1 addition & 3 deletions lib/events/auditlog.go
Original file line number Diff line number Diff line change
Expand Up @@ -950,9 +950,7 @@ func (l *AuditLog) SearchSessionEvents(ctx context.Context, req SearchSessionEve
return l.localLog.SearchSessionEvents(ctx, req)
}

// StreamSessionEvents streams all events from a given session recording. An error is returned on the first
// channel if one is encountered. Otherwise the event channel is closed when the stream ends.
// The event channel is not closed on error to prevent race conditions in downstream select statements.
// StreamSessionEvents implements [SessionStreamer].
func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) {
l.log.Debugf("StreamSessionEvents(%v)", sessionID)
e := make(chan error, 1)
Expand Down
2 changes: 2 additions & 0 deletions lib/events/complete.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,8 @@ func (u *UploadCompleter) ensureSessionEndEvent(ctx context.Context, uploadData
// for both Desktop and SSH sessions, where as the GetSessionEvents API relies on downloading
// a copy of the session and using the SSH-specific index to iterate through events.
var lastEvent events.AuditEvent
ctx, cancel := context.WithCancel(ctx)
defer cancel()
evts, errors := u.cfg.AuditLog.StreamSessionEvents(ctx, uploadData.SessionID, 0)

loop:
Expand Down
5 changes: 4 additions & 1 deletion tool/tsh/common/recording_export.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ func makeAVIFileName(prefix string, currentFile int) string {
// writeMovie writes the events for the specified session into one or more movie files
// beginning with the specified prefix. It returns the number of frames that were written and an error.
func writeMovie(ctx context.Context, ss events.SessionStreamer, sid session.ID, prefix string,
write func(format string, args ...any) (int, error)) (frames int, err error) {
write func(format string, args ...any) (int, error),
) (frames int, err error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

var screen *image.NRGBA
var movie mjpeg.AviWriter
Expand Down