Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
STUFF for current token
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston committed Jan 13, 2023
1 parent c82945a commit 5731165
Show file tree
Hide file tree
Showing 9 changed files with 63 additions and 22 deletions.
2 changes: 1 addition & 1 deletion synapse/handlers/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -2178,7 +2178,7 @@ def send_presence_to_destinations(

self._notifier.notify_replication()

def get_current_token(self, instance_name: str) -> int:
def get_current_token(self, instance_name: str, minimum: bool = False) -> int:
"""Get the current position of the stream.
On workers this returns the last stream ID received from replication.
Expand Down
4 changes: 2 additions & 2 deletions synapse/replication/http/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:

data[_STREAM_POSITION_KEY] = {
"streams": {
stream.NAME: stream.current_token(local_instance_name)
stream.NAME: stream.current_token(local_instance_name, True)
for stream in streams
},
"instance_name": local_instance_name,
Expand Down Expand Up @@ -443,7 +443,7 @@ async def _check_auth_and_handle(

if self.WAIT_FOR_STREAMS:
response[_STREAM_POSITION_KEY] = {
stream.NAME: stream.current_token(self._instance_name)
stream.NAME: stream.current_token(self._instance_name, True)
for stream in self._streams
}

Expand Down
8 changes: 7 additions & 1 deletion synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,13 @@ async def wait_for_stream_position(

# We measure here to get in flight counts and average waiting time.
with Measure(self._clock, "repl.wait_for_stream_position"):
logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
logger.info(
"Waiting for repl stream %r to reach %s (%s) (current: %s)",
stream_name,
position,
instance_name,
current_position,
)
await make_deferred_yieldable(deferred)
logger.info(
"Finished waiting for repl stream %r to reach %s", stream_name, position
Expand Down
8 changes: 7 additions & 1 deletion synapse/replication/tcp/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,13 @@ async def on_rdata(
rows: a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token)
logger.debug(
"%s: Received rdata %s (%s) -> %s",
self._instance_name,
stream_name,
instance_name,
token,
)
await self._replication_data_handler.on_rdata(
stream_name, instance_name, token, rows
)
Expand Down
3 changes: 2 additions & 1 deletion synapse/replication/tcp/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ async def _run_notifier_loop(self) -> None:

for stream in all_streams:
if stream.last_token == stream.current_token(
self._instance_name
self._instance_name,
minimum=stream.NAME == EventsStream.NAME,
):
continue

Expand Down
24 changes: 18 additions & 6 deletions synapse/replication/tcp/streams/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)

import attr
from typing_extensions import Protocol

from synapse.api.constants import AccountDataTypes
from synapse.replication.http.streams import ReplicationGetStreamUpdates
Expand Down Expand Up @@ -78,6 +79,11 @@
UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]]


class CurrentTokenFunction(Protocol):
def __call__(self, instance_name: str, minimum: bool = False) -> Token:
...


class Stream:
"""Base class for the streams.
Expand Down Expand Up @@ -107,7 +113,7 @@ def parse_row(cls, row: StreamRow) -> Any:
def __init__(
self,
local_instance_name: str,
current_token_function: Callable[[str], Token],
current_token_function: CurrentTokenFunction,
update_function: UpdateFunction,
):
"""Instantiate a Stream
Expand Down Expand Up @@ -192,12 +198,16 @@ async def get_updates_since(

def current_token_without_instance(
current_token: Callable[[], int]
) -> Callable[[str], int]:
) -> CurrentTokenFunction:
"""Takes a current token callback function for a single writer stream
that doesn't take an instance name parameter and wraps it in a function that
does accept an instance name parameter but ignores it.
"""
return lambda instance_name: current_token()

def expanded_current_token(instance_name: str, minimum: bool = False) -> int:
return current_token()

return expanded_current_token


def make_http_update_function(hs: "HomeServer", stream_name: str) -> UpdateFunction:
Expand Down Expand Up @@ -246,10 +256,12 @@ def __init__(self, hs: "HomeServer"):
self.store.get_all_new_backfill_event_rows,
)

def _current_token(self, instance_name: str) -> int:
def _current_token(self, instance_name: str, minimum: bool = False) -> int:
# The backfill stream over replication operates on *positive* numbers,
# which means we need to negate it.
return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name)
return -self.store._backfill_id_gen.get_current_token_for_writer(
instance_name, minimum
)


class PresenceStream(Stream):
Expand Down Expand Up @@ -395,7 +407,7 @@ def __init__(self, hs: "HomeServer"):
self.store.get_all_push_rule_updates,
)

def _current_token(self, instance_name: str) -> int:
def _current_token(self, instance_name: str, minimum: bool = False) -> int:
push_rules_token = self.store.get_max_push_rules_stream_id()
return push_rules_token

Expand Down
2 changes: 1 addition & 1 deletion synapse/replication/tcp/streams/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(self, hs: "HomeServer"):
super().__init__(hs.get_instance_name(), current_token, update_function)

@staticmethod
def _stub_current_token(instance_name: str) -> int:
def _stub_current_token(instance_name: str, minimum: bool = False) -> int:
# dummy current-token method for use on workers
return 0

Expand Down
9 changes: 7 additions & 2 deletions synapse/storage/databases/main/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def process_replication_position(
) -> None:
if stream_name == CachesStream.NAME:
if self._cache_id_gen:
logger.info("Advancing cache for %s to %s", instance_name, token)
self._cache_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)

Expand Down Expand Up @@ -402,8 +403,12 @@ def _send_invalidation_to_replication(
},
)

def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
def get_cache_stream_token_for_writer(
self, instance_name: str, minimum: bool = False
) -> int:
if self._cache_id_gen:
return self._cache_id_gen.get_current_token_for_writer(instance_name)
return self._cache_id_gen.get_current_token_for_writer(
instance_name, minimum
)
else:
return 0
25 changes: 18 additions & 7 deletions synapse/storage/util/id_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ def get_current_token(self) -> int:
raise NotImplementedError()

@abc.abstractmethod
def get_current_token_for_writer(self, instance_name: str) -> int:
def get_current_token_for_writer(
self, instance_name: str, minimum: bool = False
) -> int:
"""Returns the position of the given writer.
For streams with single writers this is equivalent to `get_current_token`.
Expand Down Expand Up @@ -262,7 +264,9 @@ def get_current_token(self) -> int:

return self._current

def get_current_token_for_writer(self, instance_name: str) -> int:
def get_current_token_for_writer(
self, instance_name: str, minimum: bool = False
) -> int:
return self.get_current_token()


Expand Down Expand Up @@ -378,6 +382,8 @@ def __init__(
self._current_positions.values(), default=1
)

self._last_persisted_position = self._persisted_upto_position

def _load_current_ids(
self,
db_conn: LoggingDatabaseConnection,
Expand Down Expand Up @@ -627,24 +633,29 @@ def _mark_id_as_finished(self, next_id: int) -> None:
if new_cur:
curr = self._current_positions.get(self._instance_name, 0)
self._current_positions[self._instance_name] = max(curr, new_cur)
self._last_persisted_position = max(curr, new_cur)

self._add_persisted_position(next_id)

def get_current_token(self) -> int:
return self.get_persisted_upto_position()

def get_current_token_for_writer(self, instance_name: str) -> int:
def get_current_token_for_writer(
self, instance_name: str, minimum: bool = False
) -> int:
# If we don't have an entry for the given instance name, we assume it's a
# new writer.
#
# For new writers we assume their initial position to be the current
# persisted up to position. This stops Synapse from doing a full table
# scan when a new writer announces itself over replication.
with self._lock:
return self._return_factor * max(
self._current_positions.get(instance_name, 0),
self._persisted_upto_position,
)
if minimum and instance_name == self._instance_name:
return self._last_persisted_position
else:
return self._return_factor * self._current_positions.get(
instance_name, self._persisted_upto_position
)

def get_positions(self) -> Dict[str, int]:
"""Get a copy of the current positon map.
Expand Down

0 comments on commit 5731165

Please sign in to comment.