Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 48 additions & 17 deletions lib/events/pgevents/pgevents.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@ const (
)

const (
// A note on "session_id uuid NOT NULL":
//
// Some session IDs aren't UUIDs. See [Log.deriveSessionID] for an example.
// The wiser choice of type would be "session_id text", ie, handling session
// IDs as an opaque identifier.
//
// If you are writing a new backend and stumbled on this comment, do not use
// a storage UUID type for session IDs. Use a string type.
schemaV1Table = `CREATE TABLE events (
event_time timestamptz NOT NULL,
event_id uuid NOT NULL,
Expand Down Expand Up @@ -360,21 +368,14 @@ var _ events.AuditLogger = (*Log)(nil)
// EmitAuditEvent implements [events.AuditLogger].
func (l *Log) EmitAuditEvent(ctx context.Context, event apievents.AuditEvent) error {
ctx = context.WithoutCancel(ctx)
var sessionID uuid.UUID
if s := events.GetSessionID(event); s != "" {
u, err := uuid.Parse(s)
if err != nil {
return trace.Wrap(err)
}
sessionID = u
}

eventJSON, err := utils.FastMarshal(event)
if err != nil {
return trace.Wrap(err)
}

eventID := uuid.New()
sessionID := l.deriveSessionID(ctx, events.GetSessionID(event))

start := time.Now()
// if an event with the same event_id exists, it means that we inserted it
Expand Down Expand Up @@ -412,15 +413,6 @@ func (l *Log) searchEvents(
eventTypes []string, cond *types.WhereExpr, sessionID string,
limit int, order types.EventOrder, startKey string,
) ([]apievents.AuditEvent, string, error) {
var sessionUUID uuid.UUID
if sessionID != "" {
var err error
sessionUUID, err = uuid.Parse(sessionID)
if err != nil {
return nil, "", trace.Wrap(err)
}
}

if limit <= 0 {
limit = defaults.EventsIterationLimit
}
Expand All @@ -447,6 +439,8 @@ func (l *Log) searchEvents(
}
}

sessionUUID := l.deriveSessionID(ctx, sessionID)

var qb strings.Builder
qb.WriteString("DECLARE cur CURSOR FOR SELECT" +
" events.event_time, events.event_id, events.event_data" +
Expand Down Expand Up @@ -595,3 +589,40 @@ func (l *Log) GetEventExportChunks(ctx context.Context, req *auditlogpb.GetEvent
func (l *Log) SearchSessionEvents(ctx context.Context, req events.SearchSessionEventsRequest) ([]apievents.AuditEvent, string, error) {
return l.searchEvents(ctx, req.From, req.To, events.SessionRecordingEvents, req.Cond, req.SessionID, req.Limit, req.Order, req.StartKey)
}

// sessionIDBase is a randomly-generated UUID used as the basis for deriving
// an UUID from session IDs. See [Log.deriveSessionID].
var sessionIDBase = uuid.MustParse("e481e221-77b0-4b9e-be98-bc2e486b751b")

func (l *Log) deriveSessionID(ctx context.Context, sessionID string) uuid.UUID {
if sessionID == "" {
return uuid.Nil // return zero UUID for backwards compat
}

u, err := uuid.Parse(sessionID)
if err == nil {
return u
}

// Some session IDs aren't UUIDs. For example, App session IDs are 32-byte
// values encoded as hex. Whether the assumption of UUIDs is philosophically
// correct is immaterial, what matters is that we do not drop the audit
// event.
//
// To avoid dropping the event while conforming to the existing schema we
// deterministically derive an UUID from the session ID.
//
// Note that derived IDs are UUIDv5 (instead of the usual UUIDv4 from
// uuid.Parse), so that could be used as a hint to which UUIDs are original or
// derived.
//
// * https://github.com/gravitational/teleport/blob/63537e3da5a22b61d9218863f1ed535a31d229ea/lib/auth/sessions.go#L521
derived := uuid.NewSHA1(sessionIDBase, []byte(sessionID))

l.log.DebugContext(ctx,
"Failed to parse event session ID, using derived ID",
"error", err,
"derived_id", derived,
)
return derived
}
135 changes: 111 additions & 24 deletions lib/events/pgevents/pgevents_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,17 @@ import (
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/testing/protocmp"

apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
pgcommon "github.com/gravitational/teleport/lib/backend/pgbk/common"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/events/test"
"github.com/gravitational/teleport/lib/utils"
)
Expand All @@ -38,52 +45,132 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}

// TELEPORT_TEST_PGEVENTS_URL is a connection string similar to the one used by
// "audit_events_uri" (in teleport.yaml).
// For example: "postgresql://teleport@localhost:5432/teleport_test1?sslcert=/path/to/cert.pem&sslkey=/path/to/key.pem&sslrootcert=/path/to/ca.pem&sslmode=verify-full"
const urlEnvVar = "TELEPORT_TEST_PGEVENTS_URL"

func TestPostgresEvents(t *testing.T) {
s, ok := os.LookupEnv(urlEnvVar)
if !ok {
t.Skipf("Missing %v environment variable.", urlEnvVar)
}

u, err := url.Parse(s)
require.NoError(t, err)

var cfg Config
require.NoError(t, cfg.SetFromURL(u))
// Don't t.Parallel(), relies on the database backed by urlEnvVar.
log := newLogForTesting(t)

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
ctx := context.Background()

log, err := New(ctx, cfg)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, log.Close()) })
truncateEvents := func(t *testing.T) {
_, err := log.pool.Exec(ctx, "TRUNCATE events")
require.NoError(t, err, "truncate events")
}

suite := test.EventsSuite{
Log: log,
Clock: clockwork.NewRealClock(),
}

// the tests in the suite expect a blank slate each time
setup := func(t *testing.T) {
_, err := log.pool.Exec(ctx, "TRUNCATE events")
require.NoError(t, err)
}

t.Run("SessionEventsCRUD", func(t *testing.T) {
setup(t)
// The tests in the suite expect a blank slate each time.
truncateEvents(t)
suite.SessionEventsCRUD(t)
})
t.Run("EventPagination", func(t *testing.T) {
setup(t)
truncateEvents(t)
suite.EventPagination(t)
})
t.Run("SearchSessionEventsBySessionID", func(t *testing.T) {
setup(t)
truncateEvents(t)
suite.SearchSessionEventsBySessionID(t)
})
}

// TestLog_nonStandardSessionID tests for
// https://github.com/gravitational/teleport/issues/46207.
func TestLog_nonStandardSessionID(t *testing.T) {
// Don't t.Parallel(), relies on the database backed by urlEnvVar.
eventsLog := newLogForTesting(t)

// Example app event. Only the session ID matters for the test, everything
// else is realistic but irrelevant here.
eventTime := time.Now()
appStartEvent := &apievents.AppSessionStart{
Metadata: apievents.Metadata{
Type: events.AppSessionStartEvent,
Code: events.AppSessionStartCode,
ClusterName: "zarq",
Time: eventTime,
},
ServerMetadata: apievents.ServerMetadata{
ServerVersion: "17.2.2",
ServerID: "18d877c6-c8ab-46fc-9806-b638c0d6c556",
ServerNamespace: apidefaults.Namespace,
},
SessionMetadata: apievents.SessionMetadata{
// IMPORTANT: not an UUID!
SessionID: "f8571503d72f35938ce5001b792baebcce3183719ae947fde1ed685f7848facc",
},
UserMetadata: apievents.UserMetadata{
User: "alpaca",
UserKind: apievents.UserKind_USER_KIND_HUMAN,
},
PublicAddr: "dumper.zarq.dev",
AppMetadata: apievents.AppMetadata{
AppURI: "http://127.0.0.1:52932",
AppPublicAddr: "dumper.zarq.dev",
AppName: "dumper",
},
}

ctx := context.Background()

// Emit event with non-standard session ID.
require.NoError(t,
eventsLog.EmitAuditEvent(ctx, appStartEvent),
"emit audit event",
)

// Search event by the same (non-standard) session ID.
// SearchSessionEvents has a hard-coded list of eventTypes that excludes App
// events, so we must use searchEvents instead.
before := eventTime.Add(-1 * time.Second)
after := eventTime.Add(1 * time.Second)
appEvents, _, err := eventsLog.searchEvents(ctx,
before, // fromTime
after, // toTime
[]string{appStartEvent.Metadata.Type}, // eventTypes
nil, // cond
appStartEvent.SessionID,
2, // limit
types.EventOrderAscending,
"", // startKey
)
require.NoError(t, err, "search session events")
want := []apievents.AuditEvent{appStartEvent}
if diff := cmp.Diff(want, appEvents, protocmp.Transform()); diff != "" {
t.Errorf("searchEvents mismatch (-want +got)\n%s", diff)
}
}

func newLogForTesting(t *testing.T) *Log {
t.Helper()

connString, ok := os.LookupEnv(urlEnvVar)
if !ok {
t.Skipf("Missing %v environment variable.", urlEnvVar)
}
u, err := url.Parse(connString)
require.NoError(t, err, "parse Postgres connString from %s", urlEnvVar)

var cfg Config
require.NoError(t, cfg.SetFromURL(u), "cfg.SetFromURL")

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

eventsLog, err := New(ctx, cfg)
require.NoError(t, err, "create new Log")
t.Cleanup(func() { assert.NoError(t, eventsLog.Close(), "close events log") })

return eventsLog
}

func TestConfig(t *testing.T) {
configs := map[string]*Config{
"postgres://foo#auth_mode=azure": {
Expand Down