Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix bug where a new writer advances their token too quickly #16473

Merged
merged 10 commits into from
Oct 23, 2023
1 change: 1 addition & 0 deletions changelog.d/16473.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a long-standing, exceedingly rare edge case where the first event persisted by a new event persister worker might not be sent down `/sync`.
2 changes: 1 addition & 1 deletion synapse/replication/http/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ async def send_request(

data[_STREAM_POSITION_KEY] = {
"streams": {
stream.NAME: stream.current_token(local_instance_name)
stream.NAME: stream.minimal_local_current_token()
for stream in streams
},
"instance_name": local_instance_name,
Expand Down
129 changes: 83 additions & 46 deletions synapse/replication/tcp/streams/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.util.id_generators import AbstractStreamIdGenerator

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -107,22 +108,10 @@ def parse_row(cls, row: StreamRow) -> Any:
def __init__(
self,
local_instance_name: str,
current_token_function: Callable[[str], Token],
update_function: UpdateFunction,
):
"""Instantiate a Stream

`current_token_function` and `update_function` are callbacks which
should be implemented by subclasses.

`current_token_function` takes an instance name, which is a writer to
the stream, and returns the position in the stream of the writer (as
viewed from the current process). On the writer process this is where
the writer has successfully written up to, whereas on other processes
this is the position which we have received updates up to over
replication. (Note that most streams have a single writer and so their
implementations ignore the instance name passed in).

`update_function` is called to get updates for this stream between a
pair of stream tokens. See the `UpdateFunction` type definition for more
info.
Expand All @@ -133,12 +122,28 @@ def __init__(
update_function: callback go get stream updates, as above
"""
self.local_instance_name = local_instance_name
self.current_token = current_token_function
self.update_function = update_function

# The token from which we last asked for updates
self.last_token = self.current_token(self.local_instance_name)

def current_token(self, instance_name: str) -> Token:
"""This takes an instance name, which is a writer to
the stream, and returns the position in the stream of the writer (as
viewed from the current process).
"""
# We can't make this an abstract class as it makes mypy unhappy.
raise NotImplementedError()

def minimal_local_current_token(self) -> Token:
"""Tries to return a minimal current token for the local instance,
i.e. for writers this would be the last successful write.

If local instance is not a writer (or has written yet) then falls back
to returning the normal "current token".
"""
raise NotImplementedError()

def discard_updates_and_advance(self) -> None:
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
Expand Down Expand Up @@ -190,6 +195,25 @@ async def get_updates_since(
return updates, upto_token, limited


class _StreamFromIdGen(Stream):
"""Helper class for simple streams that use a stream ID generator"""

def __init__(
self,
local_instance_name: str,
update_function: UpdateFunction,
stream_id_gen: "AbstractStreamIdGenerator",
):
self._stream_id_gen = stream_id_gen
super().__init__(local_instance_name, update_function)

def current_token(self, instance_name: str) -> Token:
return self._stream_id_gen.get_current_token_for_writer(instance_name)

def minimal_local_current_token(self) -> Token:
return self._stream_id_gen.get_minimal_local_current_token()


def current_token_without_instance(
current_token: Callable[[], int]
) -> Callable[[str], int]:
Expand Down Expand Up @@ -242,17 +266,21 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
self._current_token,
self.store.get_all_new_backfill_event_rows,
)

def _current_token(self, instance_name: str) -> int:
def current_token(self, instance_name: str) -> Token:
# The backfill stream over replication operates on *positive* numbers,
# which means we need to negate it.
return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name)

def minimal_local_current_token(self) -> Token:
# The backfill stream over replication operates on *positive* numbers,
# which means we need to negate it.
return -self.store._backfill_id_gen.get_minimal_local_current_token()

class PresenceStream(Stream):

class PresenceStream(_StreamFromIdGen):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PresenceStreamRow:
user_id: str
Expand Down Expand Up @@ -283,9 +311,7 @@ def __init__(self, hs: "HomeServer"):
update_function = make_http_update_function(hs, self.NAME)

super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_current_presence_token),
update_function,
hs.get_instance_name(), update_function, store._presence_id_gen
)


Expand All @@ -305,13 +331,18 @@ class PresenceFederationStreamRow:
ROW_TYPE = PresenceFederationStreamRow

def __init__(self, hs: "HomeServer"):
federation_queue = hs.get_presence_handler().get_federation_queue()
self._federation_queue = hs.get_presence_handler().get_federation_queue()
super().__init__(
hs.get_instance_name(),
federation_queue.get_current_token,
federation_queue.get_replication_rows,
self._federation_queue.get_replication_rows,
)

def current_token(self, instance_name: str) -> Token:
return self._federation_queue.get_current_token(instance_name)

def minimal_local_current_token(self) -> Token:
return self._federation_queue.get_current_token(self.local_instance_name)


class TypingStream(Stream):
@attr.s(slots=True, frozen=True, auto_attribs=True)
Expand Down Expand Up @@ -341,20 +372,25 @@ def __init__(self, hs: "HomeServer"):
update_function: Callable[
[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]
] = typing_writer_handler.get_all_typing_updates
current_token_function = typing_writer_handler.get_current_token
self.current_token_function = typing_writer_handler.get_current_token
else:
# Query the typing writer process
update_function = make_http_update_function(hs, self.NAME)
current_token_function = hs.get_typing_handler().get_current_token
self.current_token_function = hs.get_typing_handler().get_current_token

super().__init__(
hs.get_instance_name(),
current_token_without_instance(current_token_function),
update_function,
)

def current_token(self, instance_name: str) -> Token:
return self.current_token_function()

def minimal_local_current_token(self) -> Token:
return self.current_token_function()

class ReceiptsStream(Stream):

class ReceiptsStream(_StreamFromIdGen):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ReceiptsStreamRow:
room_id: str
Expand All @@ -371,12 +407,12 @@ def __init__(self, hs: "HomeServer"):
store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_max_receipt_stream_id),
store.get_all_updated_receipts,
store._receipts_id_gen,
)


class PushRulesStream(Stream):
class PushRulesStream(_StreamFromIdGen):
"""A user has changed their push rules"""

@attr.s(slots=True, frozen=True, auto_attribs=True)
Expand All @@ -387,20 +423,16 @@ class PushRulesStreamRow:
ROW_TYPE = PushRulesStreamRow

def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
store = hs.get_datastores().main

super().__init__(
hs.get_instance_name(),
self._current_token,
self.store.get_all_push_rule_updates,
store.get_all_push_rule_updates,
store._push_rules_stream_id_gen,
)

def _current_token(self, instance_name: str) -> int:
push_rules_token = self.store.get_max_push_rules_stream_id()
return push_rules_token


class PushersStream(Stream):
class PushersStream(_StreamFromIdGen):
"""A user has added/changed/removed a pusher"""

@attr.s(slots=True, frozen=True, auto_attribs=True)
Expand All @@ -418,8 +450,8 @@ def __init__(self, hs: "HomeServer"):

super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_pushers_stream_token),
store.get_all_updated_pushers_rows,
store._pushers_id_gen,
)


Expand Down Expand Up @@ -447,15 +479,20 @@ class CachesStreamRow:
ROW_TYPE = CachesStreamRow

def __init__(self, hs: "HomeServer"):
store = hs.get_datastores().main
self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
store.get_cache_stream_token_for_writer,
store.get_all_updated_caches,
self.store.get_all_updated_caches,
)

def current_token(self, instance_name: str) -> Token:
return self.store.get_cache_stream_token_for_writer(instance_name)

def minimal_local_current_token(self) -> Token:
return self.current_token(self.local_instance_name)


class DeviceListsStream(Stream):
class DeviceListsStream(_StreamFromIdGen):
"""Either a user has updated their devices or a remote server needs to be
told about a device update.
"""
Expand All @@ -473,8 +510,8 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
current_token_without_instance(self.store.get_device_stream_token),
self._update_function,
self.store._device_list_id_gen,
)

async def _update_function(
Expand Down Expand Up @@ -525,7 +562,7 @@ async def _update_function(
return updates, upper_limit_token, devices_limited or signatures_limited


class ToDeviceStream(Stream):
class ToDeviceStream(_StreamFromIdGen):
"""New to_device messages for a client"""

@attr.s(slots=True, frozen=True, auto_attribs=True)
Expand All @@ -539,12 +576,12 @@ def __init__(self, hs: "HomeServer"):
store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_to_device_stream_token),
store.get_all_new_device_messages,
store._device_inbox_id_gen,
)


class AccountDataStream(Stream):
class AccountDataStream(_StreamFromIdGen):
"""Global or per room account data was changed"""

@attr.s(slots=True, frozen=True, auto_attribs=True)
Expand All @@ -560,8 +597,8 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
current_token_without_instance(self.store.get_max_account_data_stream_id),
self._update_function,
self.store._account_data_id_gen,
)

async def _update_function(
Expand Down
8 changes: 3 additions & 5 deletions synapse/replication/tcp/streams/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
import attr

from synapse.replication.tcp.streams._base import (
Stream,
StreamRow,
StreamUpdateResult,
Token,
_StreamFromIdGen,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -119,17 +119,15 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow):
TypeToRow = {Row.TypeId: Row for Row in _EventRows}


class EventsStream(Stream):
class EventsStream(_StreamFromIdGen):
"""We received a new event, or an event went from being an outlier to not"""

NAME = "events"

def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
self._store._stream_id_gen.get_current_token_for_writer,
self._update_function,
hs.get_instance_name(), self._update_function, self._store._stream_id_gen
)

async def _update_function(
Expand Down
15 changes: 11 additions & 4 deletions synapse/replication/tcp/streams/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from synapse.replication.tcp.streams._base import (
Stream,
Token,
current_token_without_instance,
make_http_update_function,
)
Expand Down Expand Up @@ -47,7 +48,7 @@ def __init__(self, hs: "HomeServer"):
# will be a real FederationSender, which has stubs for current_token and
# get_replication_rows.)
federation_sender = hs.get_federation_sender()
current_token = current_token_without_instance(
self.current_token_func = current_token_without_instance(
federation_sender.get_current_token
)
update_function: Callable[
Expand All @@ -57,15 +58,21 @@ def __init__(self, hs: "HomeServer"):
elif hs.should_send_federation():
# federation sender: Query master process
update_function = make_http_update_function(hs, self.NAME)
current_token = self._stub_current_token
self.current_token_func = self._stub_current_token

else:
# other worker: stub out the update function (we're not interested in
# any updates so when we get a POSITION we do nothing)
update_function = self._stub_update_function
current_token = self._stub_current_token
self.current_token_func = self._stub_current_token

super().__init__(hs.get_instance_name(), current_token, update_function)
super().__init__(hs.get_instance_name(), update_function)

def current_token(self, instance_name: str) -> Token:
return self.current_token_func(instance_name)

def minimal_local_current_token(self) -> Token:
return self.current_token(self.local_instance_name)

@staticmethod
def _stub_current_token(instance_name: str) -> int:
Expand Down
Loading
Loading