Skip to content

Commit

Permalink
Initialise: handle state blocks from a gappy sync
Browse files Browse the repository at this point in the history
  • Loading branch information
David Robertson committed Apr 17, 2023
1 parent 666823d commit accec63
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 26 deletions.
117 changes: 91 additions & 26 deletions state/accumulator.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,24 +150,31 @@ type InitialiseResult struct {
// This function:
// - Stores these events
// - Sets up the current snapshot based on the state list given.
func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (res InitialiseResult, err error) {
func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (res InitialiseResult, outerErr error) {
if len(state) == 0 {
return res, nil
}
err = sqlutil.WithTransaction(a.db, func(txn *sqlx.Tx) error {
// Attempt to short-circuit. This has to be done inside a transaction to make sure
outerErr = sqlutil.WithTransaction(a.db, func(txn *sqlx.Tx) error {
// This has to be done inside a transaction to make sure
// we don't race with multiple calls to Initialise with the same room ID.
snapshotID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID)
if err != nil {
return fmt.Errorf("error fetching snapshot id for room %s: %s", roomID, err)
}
if snapshotID > 0 {
// we only initialise rooms once
logger.Info().Str("room_id", roomID).Int64("snapshot_id", snapshotID).Msg("Accumulator.Initialise called but current snapshot already exists, bailing early")
return nil
unknownRoom := snapshotID == 0
if !unknownRoom {
const warningMsg = "Accumulator.Initialise called when current snapshot already exists. Patching in events"
logger.Warn().Str("room_id", roomID).Int64("snapshot_id", snapshotID).Msg(warningMsg)
sentry.WithScope(func(scope *sentry.Scope) {
scope.SetContext("sliding-sync", map[string]interface{}{
"room_id": roomID,
"snapshot_id": snapshotID,
})
sentry.CaptureException(fmt.Errorf(warningMsg))
})
}

// Insert the events
// Parse the events
events := make([]Event, len(state))
for i := range events {
events[i] = Event{
Expand All @@ -176,10 +183,41 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (res In
IsState: true,
}
}
if err := ensureFieldsSet(events); err != nil {
if err = ensureFieldsSet(events); err != nil {
return fmt.Errorf("events malformed: %s", err)
}
eventIDToNID, err := a.eventsTable.Insert(txn, events, false)

// Determine which events should be inserted.
var insertEvents []Event
if unknownRoom {
insertEvents = events
} else {
// Select the events which do not have a NID
eventIDs := make([]string, len(events))
for i := range events {
eventIDs[i] = events[i].ID
}
unknownEventIDs, err := a.eventsTable.SelectUnknownEventIDs(txn, eventIDs)
if err != nil {
return fmt.Errorf("error determing which event IDs are unknown")
}
if len(unknownEventIDs) == 0 {
// All events known. Odd, but nothing to do.
return nil
}
Outer:
for i := range events {
for j := range unknownEventIDs {
if events[i].ID == unknownEventIDs[j] {
insertEvents = append(insertEvents, events[i])
continue Outer
}
}
}
}

// Insert new events
eventIDToNID, err := a.eventsTable.Insert(txn, insertEvents, false)
if err != nil {
return fmt.Errorf("failed to insert events: %w", err)
}
Expand All @@ -192,24 +230,45 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (res In
return nil
}

// pull out the event NIDs we just inserted
membershipEventIDs := make(map[string]struct{}, len(events))
for _, event := range events {
if event.Type == "m.room.member" {
membershipEventIDs[event.ID] = struct{}{}
// Determine the NIDs in the snapshot which includes the new state events
var memberNIDs, otherNIDs []int64
if unknownRoom {
// Split the new NIDs into membership and nonmemberships.
membershipEventIDs := make(map[string]struct{}, len(events))
for _, event := range events {
if event.Type == "m.room.member" {
membershipEventIDs[event.ID] = struct{}{}
}
}
}
memberNIDs := make([]int64, 0, len(eventIDToNID))
otherNIDs := make([]int64, 0, len(eventIDToNID))
for evID, nid := range eventIDToNID {
if _, exists := membershipEventIDs[evID]; exists {
memberNIDs = append(memberNIDs, int64(nid))
} else {
otherNIDs = append(otherNIDs, int64(nid))
memberNIDs = make([]int64, 0, len(eventIDToNID))
otherNIDs = make([]int64, 0, len(eventIDToNID))
for evID, nid := range eventIDToNID {
if _, exists := membershipEventIDs[evID]; exists {
memberNIDs = append(memberNIDs, int64(nid))
} else {
otherNIDs = append(otherNIDs, int64(nid))
}
}
if err != nil {
return fmt.Errorf("failed to insert snapshot: %w", err)
}
} else {
// Update the existing snapshot, then extract NIDs.
stateEvents, err := a.strippedEventsForSnapshot(txn, snapshotID)
if err != nil {
return fmt.Errorf("failed to load stripped state events for snapshot %d: %s", snapshotID, err)
}
var newStripped StrippedEvents
for _, ev := range insertEvents {
stateEvents, _, err = a.calculateNewSnapshot(stateEvents, ev)
if err != nil {
return fmt.Errorf("failed to calculateNewSnapshot: %s", err)
}
}
memberNIDs, otherNIDs = newStripped.NIDs()
}

// Make a current snapshot
// Insert the new snapshot
snapshot := &SnapshotRow{
RoomID: roomID,
MembershipEvents: pq.Int64Array(memberNIDs),
Expand All @@ -220,6 +279,12 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (res In
return fmt.Errorf("failed to insert snapshot: %w", err)
}
res.AddedEvents = true
if !unknownRoom {
res.PrependTimelineEvents = make([]json.RawMessage, len(insertEvents))
for i := range insertEvents {
res.PrependTimelineEvents[i] = insertEvents[i].JSON
}
}
latestNID := int64(0)
for _, nid := range otherNIDs {
if nid > latestNID {
Expand All @@ -244,10 +309,10 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (res In
// will have an associated state snapshot ID on the event.

// Set the snapshot ID as the current state
res.snapID = snapshot.SnapshotID
res.SnapshotID = snapshot.SnapshotID
return a.roomsTable.Upsert(txn, info, snapshot.SnapshotID, latestNID)
})
return res, err
return res, outerErr
}

// Accumulate internal state from a user's sync response. The timeline order MUST be in the order
Expand Down
19 changes: 19 additions & 0 deletions state/event_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,25 @@ func (t *EventTable) SelectStrippedEventsByIDs(txn *sqlx.Tx, verifyAll bool, ids

}

// SelectUnknownEventIDs accepts a list of event IDs and returns the subset of those which are not known to the DB.
// The order of event IDs in the return value is not guaranteed.
func (t *EventTable) SelectUnknownEventIDs(txn *sqlx.Tx, maybeUnknownEventIDs []string) ([]string, error) {
queryStr := `
WITH maybe_unknown_events(event_id) AS (SELECT unnest($1::text[]))
SELECT event_id
FROM maybe_unknown_events LEFT JOIN syncv3_events USING(event_id)
WHERE event_nid IS NULL;`

var unknownEventIDs []string
var err error
if txn != nil {
err = txn.Select(&unknownEventIDs, queryStr, maybeUnknownEventIDs)
} else {
err = t.db.Select(&unknownEventIDs, queryStr, maybeUnknownEventIDs)
}
return unknownEventIDs, err
}

// UpdateBeforeSnapshotID sets the before_state_snapshot_id field to `snapID` for the given NIDs.
func (t *EventTable) UpdateBeforeSnapshotID(txn *sqlx.Tx, eventNID, snapID, replacesNID int64) error {
_, err := txn.Exec(
Expand Down

0 comments on commit accec63

Please sign in to comment.