Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not make snapshots for lone leave events #235

Merged
merged 2 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pubsub/v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ type V2AccountData struct {
func (*V2AccountData) Type() string { return "V2AccountData" }

type V2LeaveRoom struct {
UserID string
RoomID string
UserID string
RoomID string
LeaveEvent json.RawMessage
}

func (*V2LeaveRoom) Type() string { return "V2LeaveRoom" }
Expand Down
45 changes: 31 additions & 14 deletions state/accumulator.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia
// to exist in the database, and the sync stream is already linearised for us.
// - Else it creates a new room state snapshot if the timeline contains state events (as this now represents the current state)
// - It adds entries to the membership log for membership events.
func (a *Accumulator) Accumulate(txn *sqlx.Tx, roomID string, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
// The first stage of accumulating events is mostly around validation around what the upstream HS sends us. For accumulation to work correctly
// we expect:
// - there to be no duplicate events
Expand All @@ -308,6 +308,36 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, roomID string, prevBatch string,
return 0, nil, err // nothing to do
}

// Given a timeline of [E1, E2, S3, E4, S5, S6, E7] (E=message event, S=state event)
// And a prior state snapshot of SNAP0 then the BEFORE snapshot IDs are grouped as:
// E1,E2,S3 => SNAP0
// E4, S5 => (SNAP0 + S3)
// S6 => (SNAP0 + S3 + S5)
// E7 => (SNAP0 + S3 + S5 + S6)
// We can track this by loading the current snapshot ID (after snapshot) then rolling forward
// the timeline until we hit a state event, at which point we make a new snapshot but critically
// do NOT assign the new state event in the snapshot so as to represent the state before the event.
snapID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID)
if err != nil {
return 0, nil, err
}

// if we have just got a leave event for the polling user, and there is no snapshot for this room already, then
// we do NOT want to add this event to the events table, nor do we want to make a room snapshot. This is because
// this leave event is an invite rejection, rather than a normal event. Invite rejections cannot be processed in
// a normal way because we lack room state (no create event, PLs, etc). If we were to process the invite rejection,
// the room state would just be a single event: this leave event, which is wrong.
if len(dedupedEvents) == 1 &&
dedupedEvents[0].Type == "m.room.member" &&
(dedupedEvents[0].Membership == "leave" || dedupedEvents[0].Membership == "_leave") &&
dedupedEvents[0].StateKey == userID &&
snapID == 0 {
logger.Info().Str("event_id", dedupedEvents[0].ID).Str("room_id", roomID).Str("user_id", userID).Err(err).Msg(
"Accumulator: skipping processing of leave event, as no snapshot exists",
)
return 0, nil, nil
}

eventIDToNID, err := a.eventsTable.Insert(txn, dedupedEvents, false)
if err != nil {
return 0, nil, err
Expand Down Expand Up @@ -339,19 +369,6 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, roomID string, prevBatch string,
}
}

// Given a timeline of [E1, E2, S3, E4, S5, S6, E7] (E=message event, S=state event)
// And a prior state snapshot of SNAP0 then the BEFORE snapshot IDs are grouped as:
// E1,E2,S3 => SNAP0
// E4, S5 => (SNAP0 + S3)
// S6 => (SNAP0 + S3 + S5)
// E7 => (SNAP0 + S3 + S5 + S6)
// We can track this by loading the current snapshot ID (after snapshot) then rolling forward
// the timeline until we hit a state event, at which point we make a new snapshot but critically
// do NOT assign the new state event in the snapshot so as to represent the state before the event.
snapID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID)
if err != nil {
return 0, nil, err
}
for _, ev := range newEvents {
var replacesNID int64
// the snapshot ID we assign to this event is unaffected by whether /this/ event is state or not,
Expand Down
14 changes: 9 additions & 5 deletions state/accumulator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ import (
"github.com/tidwall/gjson"
)

var (
userID = "@me:localhost"
)

func TestAccumulatorInitialise(t *testing.T) {
roomID := "!TestAccumulatorInitialise:localhost"
roomEvents := []json.RawMessage{
Expand Down Expand Up @@ -118,7 +122,7 @@ func TestAccumulatorAccumulate(t *testing.T) {
var numNew int
var latestNIDs []int64
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
numNew, latestNIDs, err = accumulator.Accumulate(txn, roomID, "", newEvents)
numNew, latestNIDs, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
return err
})
if err != nil {
Expand Down Expand Up @@ -192,7 +196,7 @@ func TestAccumulatorAccumulate(t *testing.T) {

// subsequent calls do nothing and are not an error
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, roomID, "", newEvents)
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
return err
})
if err != nil {
Expand Down Expand Up @@ -228,7 +232,7 @@ func TestAccumulatorMembershipLogs(t *testing.T) {
[]byte(`{"event_id":"` + roomEventIDs[7] + `", "type":"m.room.member", "state_key":"@me:localhost","unsigned":{"prev_content":{"membership":"join", "displayname":"Me"}}, "content":{"membership":"leave"}}`),
}
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, roomID, "", roomEvents)
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", roomEvents)
return err
})
if err != nil {
Expand Down Expand Up @@ -355,7 +359,7 @@ func TestAccumulatorDupeEvents(t *testing.T) {
}

err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, roomID, "", joinRoom.Timeline.Events)
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", joinRoom.Timeline.Events)
return err
})
if err != nil {
Expand Down Expand Up @@ -555,7 +559,7 @@ func TestAccumulatorConcurrency(t *testing.T) {
defer wg.Done()
subset := newEvents[:(i + 1)] // i=0 => [1], i=1 => [1,2], etc
err := sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
numNew, _, err := accumulator.Accumulate(txn, roomID, "", subset)
numNew, _, err := accumulator.Accumulate(txn, userID, roomID, "", subset)
totalNumNew += numNew
return err
})
Expand Down
4 changes: 2 additions & 2 deletions state/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,12 +307,12 @@ func (s *Storage) currentNotMembershipStateEventsInAllRooms(txn *sqlx.Tx, eventT
return result, nil
}

func (s *Storage) Accumulate(roomID, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
func (s *Storage) Accumulate(userID, roomID, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
if len(timeline) == 0 {
return 0, nil, nil
}
err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
numNew, timelineNIDs, err = s.Accumulator.Accumulate(txn, roomID, prevBatch, timeline)
numNew, timelineNIDs, err = s.Accumulator.Accumulate(txn, userID, roomID, prevBatch, timeline)
return err
})
return
Expand Down
12 changes: 6 additions & 6 deletions state/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestStorageRoomStateBeforeAndAfterEventPosition(t *testing.T) {
testutils.NewStateEvent(t, "m.room.join_rules", "", alice, map[string]interface{}{"join_rule": "invite"}),
testutils.NewStateEvent(t, "m.room.member", bob, alice, map[string]interface{}{"membership": "invite"}),
}
_, latestNIDs, err := store.Accumulate(roomID, "", events)
_, latestNIDs, err := store.Accumulate(userID, roomID, "", events)
if err != nil {
t.Fatalf("Accumulate returned error: %s", err)
}
Expand Down Expand Up @@ -161,7 +161,7 @@ func TestStorageJoinedRoomsAfterPosition(t *testing.T) {
var latestNIDs []int64
var err error
for roomID, eventMap := range roomIDToEventMap {
_, latestNIDs, err = store.Accumulate(roomID, "", eventMap)
_, latestNIDs, err = store.Accumulate(userID, roomID, "", eventMap)
if err != nil {
t.Fatalf("Accumulate on %s failed: %s", roomID, err)
}
Expand Down Expand Up @@ -351,7 +351,7 @@ func TestVisibleEventNIDsBetween(t *testing.T) {
},
}
for _, tl := range timelineInjections {
numNew, _, err := store.Accumulate(tl.RoomID, "", tl.Events)
numNew, _, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
if err != nil {
t.Fatalf("Accumulate on %s failed: %s", tl.RoomID, err)
}
Expand Down Expand Up @@ -454,7 +454,7 @@ func TestVisibleEventNIDsBetween(t *testing.T) {
t.Fatalf("LatestEventNID: %s", err)
}
for _, tl := range timelineInjections {
numNew, _, err := store.Accumulate(tl.RoomID, "", tl.Events)
numNew, _, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
if err != nil {
t.Fatalf("Accumulate on %s failed: %s", tl.RoomID, err)
}
Expand Down Expand Up @@ -534,7 +534,7 @@ func TestStorageLatestEventsInRoomsPrevBatch(t *testing.T) {
}
eventIDs := []string{}
for _, timeline := range timelines {
_, _, err = store.Accumulate(roomID, timeline.prevBatch, timeline.timeline)
_, _, err = store.Accumulate(userID, roomID, timeline.prevBatch, timeline.timeline)
if err != nil {
t.Fatalf("failed to accumulate: %s", err)
}
Expand Down Expand Up @@ -776,7 +776,7 @@ func TestAllJoinedMembers(t *testing.T) {
}, serialise(tc.InitMemberships)...))
assertNoError(t, err)

_, _, err = store.Accumulate(roomID, "foo", serialise(tc.AccumulateMemberships))
_, _, err = store.Accumulate(userID, roomID, "foo", serialise(tc.AccumulateMemberships))
assertNoError(t, err)
testCases[i].RoomID = roomID // remember this for later
}
Expand Down
10 changes: 6 additions & 4 deletions sync2/handler2/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID, prev
}

// Insert new events
numNew, latestNIDs, err := h.Store.Accumulate(roomID, prevBatch, timeline)
numNew, latestNIDs, err := h.Store.Accumulate(userID, roomID, prevBatch, timeline)
if err != nil {
logger.Err(err).Int("timeline", len(timeline)).Str("room", roomID).Msg("V2: failed to accumulate room")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
Expand Down Expand Up @@ -448,16 +448,18 @@ func (h *Handler) OnInvite(ctx context.Context, userID, roomID string, inviteSta
})
}

func (h *Handler) OnLeftRoom(ctx context.Context, userID, roomID string) {
func (h *Handler) OnLeftRoom(ctx context.Context, userID, roomID string, leaveEv json.RawMessage) {
// remove any invites for this user if they are rejecting an invite
err := h.Store.InvitesTable.RemoveInvite(userID, roomID)
if err != nil {
logger.Err(err).Str("user", userID).Str("room", roomID).Msg("failed to retire invite")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
}

h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2LeaveRoom{
UserID: userID,
RoomID: roomID,
UserID: userID,
RoomID: roomID,
LeaveEvent: leaveEv,
})
}

Expand Down
21 changes: 16 additions & 5 deletions sync2/poller.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ type V2DataReceiver interface {
// Sent when there is a room in the `invite` section of the v2 response.
OnInvite(ctx context.Context, userID, roomID string, inviteState []json.RawMessage) // invitestate in db
// Sent when there is a room in the `leave` section of the v2 response.
OnLeftRoom(ctx context.Context, userID, roomID string)
OnLeftRoom(ctx context.Context, userID, roomID string, leaveEvent json.RawMessage)
// Sent when there is a _change_ in E2EE data, not all the time
OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int)
// Sent when the poll loop terminates
Expand Down Expand Up @@ -301,11 +301,11 @@ func (h *PollerMap) OnInvite(ctx context.Context, userID, roomID string, inviteS
wg.Wait()
}

func (h *PollerMap) OnLeftRoom(ctx context.Context, userID, roomID string) {
func (h *PollerMap) OnLeftRoom(ctx context.Context, userID, roomID string, leaveEvent json.RawMessage) {
var wg sync.WaitGroup
wg.Add(1)
h.executor <- func() {
h.callbacks.OnLeftRoom(ctx, userID, roomID)
h.callbacks.OnLeftRoom(ctx, userID, roomID, leaveEvent)
wg.Done()
}
wg.Wait()
Expand Down Expand Up @@ -716,12 +716,23 @@ func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) {
}
}
for roomID, roomData := range res.Rooms.Leave {
// TODO: do we care about state?
if len(roomData.Timeline.Events) > 0 {
p.trackTimelineSize(len(roomData.Timeline.Events), roomData.Timeline.Limited)
p.receiver.Accumulate(ctx, p.userID, p.deviceID, roomID, roomData.Timeline.PrevBatch, roomData.Timeline.Events)
}
p.receiver.OnLeftRoom(ctx, p.userID, roomID)
// Pass the leave event directly to OnLeftRoom. We need to do this _in addition_ to calling Accumulate to handle
// the case where a user rejects an invite (there will be no room state, but the user still expects to see the leave event).
var leaveEvent json.RawMessage
for _, ev := range roomData.Timeline.Events {
leaveEv := gjson.ParseBytes(ev)
if leaveEv.Get("content.membership").Str == "leave" && leaveEv.Get("state_key").Str == p.userID {
leaveEvent = ev
break
}
}
if leaveEvent != nil {
p.receiver.OnLeftRoom(ctx, p.userID, roomID, leaveEvent)
}
kegsay marked this conversation as resolved.
Show resolved Hide resolved
}
for roomID, roomData := range res.Rooms.Invite {
p.receiver.OnInvite(ctx, p.userID, roomID, roomData.InviteState.Events)
Expand Down
3 changes: 2 additions & 1 deletion sync2/poller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,8 @@ func (s *mockDataReceiver) OnReceipt(ctx context.Context, userID, roomID, ephEve
}
func (s *mockDataReceiver) OnInvite(ctx context.Context, userID, roomID string, inviteState []json.RawMessage) {
}
func (s *mockDataReceiver) OnLeftRoom(ctx context.Context, userID, roomID string) {}
func (s *mockDataReceiver) OnLeftRoom(ctx context.Context, userID, roomID string, leaveEvent json.RawMessage) {
}
func (s *mockDataReceiver) OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int) {
}
func (s *mockDataReceiver) OnTerminated(ctx context.Context, userID, deviceID string) {}
Expand Down
4 changes: 2 additions & 2 deletions sync3/caches/global_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ func TestGlobalCacheLoadState(t *testing.T) {
testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]interface{}{"name": "The Room Name"}),
testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]interface{}{"name": "The Updated Room Name"}),
}
_, _, err := store.Accumulate(roomID2, "", eventsRoom2)
_, _, err := store.Accumulate(alice, roomID2, "", eventsRoom2)
if err != nil {
t.Fatalf("Accumulate: %s", err)
}

_, latestNIDs, err := store.Accumulate(roomID, "", events)
_, latestNIDs, err := store.Accumulate(alice, roomID, "", events)
if err != nil {
t.Fatalf("Accumulate: %s", err)
}
Expand Down
9 changes: 0 additions & 9 deletions sync3/caches/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,6 @@ func (u *InviteUpdate) Type() string {
return fmt.Sprintf("InviteUpdate[%s]", u.RoomID())
}

// LeftRoomUpdate corresponds to a key-value pair from a v2 sync's `leave` section.
type LeftRoomUpdate struct {
RoomUpdate
}

func (u *LeftRoomUpdate) Type() string {
return fmt.Sprintf("LeftRoomUpdate[%s]", u.RoomID())
}

// TypingEdu corresponds to a typing EDU in the `ephemeral` section of a joined room's v2 sync resposne.
type TypingUpdate struct {
RoomUpdate
Expand Down
19 changes: 17 additions & 2 deletions sync3/caches/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ func (c *UserCache) OnInvite(ctx context.Context, roomID string, inviteStateEven
c.emitOnRoomUpdate(ctx, up)
}

func (c *UserCache) OnLeftRoom(ctx context.Context, roomID string) {
func (c *UserCache) OnLeftRoom(ctx context.Context, roomID string, leaveEvent json.RawMessage) {
urd := c.LoadRoomData(roomID)
urd.IsInvite = false
urd.HasLeft = true
Expand All @@ -616,14 +616,29 @@ func (c *UserCache) OnLeftRoom(ctx context.Context, roomID string) {
c.roomToData[roomID] = urd
c.roomToDataMu.Unlock()

up := &LeftRoomUpdate{
ev := gjson.ParseBytes(leaveEvent)
stateKey := ev.Get("state_key").Str

up := &RoomEventUpdate{
RoomUpdate: &roomUpdateCache{
roomID: roomID,
// do NOT pull from the global cache as it is a snapshot of the room at the point of
// the invite: don't leak additional data!!!
globalRoomData: internal.NewRoomMetadata(roomID),
userRoomData: &urd,
},
EventData: &EventData{
Event: leaveEvent,
RoomID: roomID,
EventType: ev.Get("type").Str,
StateKey: &stateKey,
Content: ev.Get("content"),
Timestamp: ev.Get("origin_server_ts").Uint(),
Sender: ev.Get("sender").Str,
// if this is an invite rejection we need to make sure we tell the client, and not
// skip it because of the lack of a NID (this event may not be in the events table)
AlwaysProcess: true,
},
}
c.emitOnRoomUpdate(ctx, up)
}
Expand Down
10 changes: 6 additions & 4 deletions sync3/handler/connstate_live.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,13 @@ func (s *connStateLive) processLiveUpdate(ctx context.Context, up caches.Update,
if roomEventUpdate != nil && roomEventUpdate.EventData.Event != nil {
r.NumLive++
advancedPastEvent := false
if roomEventUpdate.EventData.NID <= s.loadPositions[roomEventUpdate.RoomID()] {
// this update has been accounted for by the initial:true room snapshot
advancedPastEvent = true
if !roomEventUpdate.EventData.AlwaysProcess {
if roomEventUpdate.EventData.NID <= s.loadPositions[roomEventUpdate.RoomID()] {
// this update has been accounted for by the initial:true room snapshot
advancedPastEvent = true
}
s.loadPositions[roomEventUpdate.RoomID()] = roomEventUpdate.EventData.NID
}
kegsay marked this conversation as resolved.
Show resolved Hide resolved
s.loadPositions[roomEventUpdate.RoomID()] = roomEventUpdate.EventData.NID
// we only append to the timeline if we haven't already got this event. This can happen when:
// - 2 live events for a room mid-connection
// - next request bumps a room from outside to inside the window
Expand Down
3 changes: 1 addition & 2 deletions sync3/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,6 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
// setupConnection associates this request with an existing connection or makes a new connection.
// It also sets a v2 sync poll loop going if one didn't exist already for this user.
// When this function returns, the connection is alive and active.

func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Request, containsPos bool) (*sync3.Conn, *internal.HandlerError) {
taskCtx, task := internal.StartTask(req.Context(), "setupConnection")
defer task.End()
Expand Down Expand Up @@ -697,7 +696,7 @@ func (h *SyncLiveHandler) OnLeftRoom(p *pubsub.V2LeaveRoom) {
if !ok {
return
}
userCache.(*caches.UserCache).OnLeftRoom(ctx, p.RoomID)
userCache.(*caches.UserCache).OnLeftRoom(ctx, p.RoomID, p.LeaveEvent)
}

func (h *SyncLiveHandler) OnReceipt(p *pubsub.V2Receipt) {
Expand Down
Loading
Loading