Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion api/metadata/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ import (
)

const (
VersionKey = "version"
VersionKey = "version"
SessionRecordingFormatContextKey = "session-recording-format"
)

// defaultMetadata returns the default metadata which will be added to all outgoing calls.
Expand Down Expand Up @@ -133,3 +134,20 @@ func UserAgentFromContext(ctx context.Context) string {
}
return strings.Join(values, " ")
}

// WithSessionRecordingFormatContext returns a context.Context containing the
// format of the accessed session recording.
func WithSessionRecordingFormatContext(ctx context.Context, format string) context.Context {
return metadata.AppendToOutgoingContext(ctx, SessionRecordingFormatContextKey, format)
}

// SessionRecordingFormatFromContext returns the format of the accessed session
// recording (if present).
func SessionRecordingFormatFromContext(ctx context.Context) string {
values := metadata.ValueFromIncomingContext(ctx, SessionRecordingFormatContextKey)
if len(values) == 0 {
return ""
}

return values[0]
}
4 changes: 4 additions & 0 deletions api/proto/teleport/legacy/types/events/events.proto
Original file line number Diff line number Diff line change
Expand Up @@ -5637,6 +5637,10 @@ message SessionRecordingAccess {
(gogoproto.embed) = true,
(gogoproto.jsontag) = ""
];
// SessionType is type of the session.
string SessionType = 4 [(gogoproto.jsontag) = "session_type,omitempty"];
// Format is the format the session recording was accessed.
string Format = 5 [(gogoproto.jsontag) = "format,omitempty"];
}

// KubeClusterMetadata contains common kubernetes cluster information.
Expand Down
1,384 changes: 738 additions & 646 deletions api/types/events/events.pb.go

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions api/types/session_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ const (
DatabaseSessionKind SessionKind = "db"
AppSessionKind SessionKind = "app"
WindowsDesktopSessionKind SessionKind = "desktop"
UnknownSessionKind SessionKind = ""
)

// SessionParticipantMode is the mode that determines what you can do when you join a session.
Expand Down
59 changes: 41 additions & 18 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import (
auditlogpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/auditlog/v1"
mfav1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/mfa/v1"
"github.com/gravitational/teleport/api/internalutils/stream"
"github.com/gravitational/teleport/api/metadata"
"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/api/types/wrappers"
Expand Down Expand Up @@ -193,14 +194,15 @@ func (a *ServerWithRoles) actionWithExtendedContext(namespace, kind, verb string
}

// actionForKindSession is a special checker that grants access to session
// recordings. It can allow access to a specific recording based on the
// recordings. It can allow access to a specific recording based on the
// `where` section of the user's access rule for kind `session`.
func (a *ServerWithRoles) actionForKindSession(namespace string, sid session.ID) error {
extendContext := func(ctx *services.Context) error {
sessionEnd, err := a.findSessionEndEvent(namespace, sid)
ctx.Session = sessionEnd
return trace.Wrap(err)
}

return trace.Wrap(a.actionWithExtendedContext(namespace, types.KindSession, types.VerbRead, extendContext))
}

Expand Down Expand Up @@ -6226,39 +6228,60 @@ func (a *ServerWithRoles) ReplaceRemoteLocks(ctx context.Context, clusterName st
// channel if one is encountered. Otherwise the event channel is closed when the stream ends.
// The event channel is not closed on error to prevent race conditions in downstream select statements.
func (a *ServerWithRoles) StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) {
createErrorChannel := func(err error) (chan apievents.AuditEvent, chan error) {
e := make(chan error, 1)
e <- trace.Wrap(err)
return nil, e
}

err := a.localServerAction()
isTeleportServer := err == nil

if !isTeleportServer {
if err := a.actionForKindSession(apidefaults.Namespace, sessionID); err != nil {
c, e := make(chan apievents.AuditEvent), make(chan error, 1)
e <- trace.Wrap(err)
return c, e
}
// StreamSessionEvents can be called internally, and when that
// happens we don't want to emit an event or check for permissions.
if isTeleportServer {
return a.alog.StreamSessionEvents(ctx, sessionID, startIndex)
}

// StreamSessionEvents can be called internally, and when that happens we don't want to emit an event.
shouldEmitAuditEvent := !isTeleportServer
if shouldEmitAuditEvent {
if err := a.actionForKindSession(apidefaults.Namespace, sessionID); err != nil {
c, e := make(chan apievents.AuditEvent), make(chan error, 1)
e <- trace.Wrap(err)
return c, e
}

// We can only determine the session type after the streaming started. For
// this reason, we delay the emit audit event until the first event or if
// the streaming returns an error.
cb := func(evt apievents.AuditEvent, _ error) {
if err := a.authServer.emitter.EmitAuditEvent(a.authServer.closeCtx, &apievents.SessionRecordingAccess{
Metadata: apievents.Metadata{
Type: events.SessionRecordingAccessEvent,
Code: events.SessionRecordingAccessCode,
},
SessionID: sessionID.String(),
UserMetadata: a.context.Identity.GetIdentity().GetUserMetadata(),
SessionType: string(sessionTypeFromStartEvent(evt)),
Format: metadata.SessionRecordingFormatFromContext(ctx),
}); err != nil {
return createErrorChannel(err)
log.WithError(err).Errorf("Failed to emit stream session event audit event")
}
}

return a.alog.StreamSessionEvents(ctx, sessionID, startIndex)
return a.alog.StreamSessionEvents(events.ContextWithSessionStartCallback(ctx, cb), sessionID, startIndex)
}

// sessionTypeFromStartEvent determines the session type given the session start
// event.
func sessionTypeFromStartEvent(sessionStart apievents.AuditEvent) types.SessionKind {
switch e := sessionStart.(type) {
case *apievents.SessionStart:
if e.KubernetesCluster != "" {
return types.KubernetesSessionKind
}
return types.SSHSessionKind
case *apievents.DatabaseSessionStart:
return types.DatabaseSessionKind
case *apievents.AppSessionStart:
return types.AppSessionKind
case *apievents.WindowsDesktopSessionStart:
return types.WindowsDesktopSessionKind
default:
return types.UnknownSessionKind
}
}

// CreateApp creates a new application resource.
Expand Down
84 changes: 84 additions & 0 deletions lib/auth/auth_with_roles_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import (
mfav1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/mfa/v1"
trustpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/trust/v1"
userpreferencesv1 "github.com/gravitational/teleport/api/gen/proto/go/userpreferences/v1"
"github.com/gravitational/teleport/api/metadata"
"github.com/gravitational/teleport/api/mfa"
"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
Expand Down Expand Up @@ -2099,6 +2100,89 @@ func TestStreamSessionEvents(t *testing.T) {
require.Equal(t, username, event.User)
}

// TestStreamSessionEvents ensures that when a user streams a session's events
// a "session recording access" event is emitted with the correct session type.
func TestStreamSessionEvents_SessionType(t *testing.T) {
t.Parallel()

authServerConfig := TestAuthServerConfig{
Dir: t.TempDir(),
Clock: clockwork.NewFakeClockAt(time.Now().Round(time.Second).UTC()),
}
require.NoError(t, authServerConfig.CheckAndSetDefaults())

uploader := eventstest.NewMemoryUploader()
localLog, err := events.NewAuditLog(events.AuditLogConfig{
DataDir: authServerConfig.Dir,
ServerID: authServerConfig.ClusterName,
Clock: authServerConfig.Clock,
UploadHandler: uploader,
})
require.NoError(t, err)
authServerConfig.AuditLog = localLog

as, err := NewTestAuthServer(authServerConfig)
require.NoError(t, err)

srv, err := as.NewTestTLSServer()
require.NoError(t, err)
t.Cleanup(func() { srv.Close() })

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

username := "user"
user, _, err := CreateUserAndRole(srv.Auth(), username, []string{}, nil)
require.NoError(t, err)

identity := TestUser(user.GetName())
clt, err := srv.NewClient(identity)
require.NoError(t, err)
sessionID := session.NewID()

streamer, err := events.NewProtoStreamer(events.ProtoStreamerConfig{
Uploader: uploader,
})
require.NoError(t, err)
stream, err := streamer.CreateAuditStream(ctx, sessionID)
require.NoError(t, err)
// The event is not required to pass through the auth server, we only need
// the upload to be present.
require.NoError(t, stream.RecordEvent(ctx, eventstest.PrepareEvent(&apievents.DatabaseSessionStart{
Metadata: apievents.Metadata{
Type: events.DatabaseSessionStartEvent,
Code: events.DatabaseSessionStartCode,
},
SessionMetadata: apievents.SessionMetadata{
SessionID: sessionID.String(),
},
})))
require.NoError(t, stream.Complete(ctx))

accessedFormat := teleport.PTY
clt.StreamSessionEvents(metadata.WithSessionRecordingFormatContext(ctx, accessedFormat), sessionID, 0)

// Perform the listing an eventually loop to ensure the event is emitted.
var searchEvents []apievents.AuditEvent
require.EventuallyWithT(t, func(t *assert.CollectT) {
var err error
searchEvents, _, err = srv.AuthServer.AuditLog.SearchEvents(ctx, events.SearchEventsRequest{
From: srv.Clock().Now().Add(-time.Hour),
To: srv.Clock().Now().Add(time.Hour),
EventTypes: []string{events.SessionRecordingAccessEvent},
Limit: 1,
Order: types.EventOrderDescending,
})
assert.NoError(t, err)
assert.Len(t, searchEvents, 1, "expected one event but got %d", len(searchEvents))
}, 5*time.Second, 200*time.Millisecond)

event := searchEvents[0].(*apievents.SessionRecordingAccess)
require.Equal(t, username, event.User)
require.Equal(t, string(types.DatabaseSessionKind), event.SessionType)
require.Equal(t, accessedFormat, event.Format)
}

// TestAPILockedOut tests Auth API when there are locks involved.
func TestAPILockedOut(t *testing.T) {
t.Parallel()
Expand Down
60 changes: 60 additions & 0 deletions lib/events/auditlog.go
Original file line number Diff line number Diff line change
Expand Up @@ -978,9 +978,23 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID
e := make(chan error, 1)
c := make(chan apievents.AuditEvent)

sessionStartCh := make(chan apievents.AuditEvent, 1)
if startCb, err := sessionStartCallbackFromContext(ctx); err == nil {
go func() {
evt, ok := <-sessionStartCh
if !ok {
startCb(nil, trace.NotFound("session start event not found"))
return
}

startCb(evt, nil)
}()
}

rawSession, err := os.CreateTemp(l.playbackDir, string(sessionID)+".stream.tar.*")
if err != nil {
e <- trace.Wrap(trace.ConvertSystemError(err), "creating temporary stream file")
close(sessionStartCh)
return c, e
}
// The file is still perfectly usable after unlinking it, and the space it's
Expand All @@ -997,6 +1011,7 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID
if err := os.Remove(rawSession.Name()); err != nil {
_ = rawSession.Close()
e <- trace.Wrap(trace.ConvertSystemError(err), "removing temporary stream file")
close(sessionStartCh)
return c, e
}

Expand All @@ -1007,6 +1022,7 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID
err = trace.NotFound("a recording for session %v was not found", sessionID)
}
e <- trace.Wrap(err)
close(sessionStartCh)
return c, e
}
l.log.WithFields(log.Fields{
Expand All @@ -1016,6 +1032,8 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID

go func() {
defer rawSession.Close()
defer close(sessionStartCh)

// this shouldn't be necessary as the position should be already 0 (Download
// takes an io.WriterAt), but it's better to be safe than sorry
if _, err := rawSession.Seek(0, io.SeekStart); err != nil {
Expand All @@ -1026,6 +1044,7 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID
protoReader := NewProtoReader(rawSession)
defer protoReader.Close()

firstEvent := true
for {
if ctx.Err() != nil {
e <- trace.Wrap(ctx.Err())
Expand All @@ -1042,6 +1061,11 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID
return
}

if firstEvent {
sessionStartCh <- event
firstEvent = false
}

if event.GetIndex() >= startIndex {
select {
case c <- event:
Expand Down Expand Up @@ -1136,3 +1160,39 @@ func (l *AuditLog) periodicSpaceMonitor() {
}
}
}

// streamSessionEventsContextKey represent context keys used by
// StreamSessionEvents function.
type streamSessionEventsContextKey string

const (
// sessionStartCallbackContextKey is the context key used to store the
// session start callback function.
sessionStartCallbackContextKey streamSessionEventsContextKey = "session-start"
)

// SessionStartCallback is the function used when streaming reaches the start
// event. If any error, such as session not found, the event will be nil, and
// the error will be set.
type SessionStartCallback func(startEvent apievents.AuditEvent, err error)

// ContextWithSessionStartCallback returns a context.Context containing a
// session start event callback.
func ContextWithSessionStartCallback(ctx context.Context, cb SessionStartCallback) context.Context {
return context.WithValue(ctx, sessionStartCallbackContextKey, cb)
}

// sessionStartCallbackFromContext returns the session start callback from
// context.Context.
func sessionStartCallbackFromContext(ctx context.Context) (SessionStartCallback, error) {
if ctx == nil {
return nil, trace.BadParameter("context is nil")
}

cb, ok := ctx.Value(sessionStartCallbackContextKey).(SessionStartCallback)
if !ok {
return nil, trace.BadParameter("session start callback function was not found in the context")
}

return cb, nil
}
Loading