Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Some refactors around receipts stream #16426

Merged
merged 5 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16426.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor some code to simplify and better type receipts stream adjacent code.
4 changes: 2 additions & 2 deletions synapse/handlers/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ async def handle_room_events(events: Iterable[EventBase]) -> None:

def notify_interested_services_ephemeral(
self,
stream_key: str,
stream_key: StreamKeyType,
new_token: Union[int, RoomStreamToken],
users: Collection[Union[str, UserID]],
) -> None:
Expand Down Expand Up @@ -326,7 +326,7 @@ def notify_interested_services_ephemeral(
async def _notify_interested_services_ephemeral(
self,
services: List[ApplicationService],
stream_key: str,
stream_key: StreamKeyType,
new_token: int,
users: Collection[Union[str, UserID]],
) -> None:
Expand Down
6 changes: 4 additions & 2 deletions synapse/handlers/push_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.push_rule import RuleNotFoundException
from synapse.synapse_rust.push import get_base_rule_ids
from synapse.types import JsonDict, UserID
from synapse.types import JsonDict, StreamKeyType, UserID

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -114,7 +114,9 @@ def notify_user(self, user_id: str) -> None:
user_id: the user ID the change is for.
"""
stream_id = self._main_store.get_max_push_rules_stream_id()
self._notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
self._notifier.on_new_event(
StreamKeyType.PUSH_RULES, stream_id, users=[user_id]
)

async def push_rules_for_user(
self, user: UserID
Expand Down
4 changes: 1 addition & 3 deletions synapse/handlers/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,7 @@ async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
StreamKeyType.RECEIPT, max_batch_id, rooms=affected_room_ids
)
# Note that the min here shouldn't be relied upon to be accurate.
await self.hs.get_pusherpool().on_new_receipts(
min_batch_id, max_batch_id, affected_room_ids
)
await self.hs.get_pusherpool().on_new_receipts({r.user_id for r in receipts})

return True

Expand Down
17 changes: 8 additions & 9 deletions synapse/notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __init__(

def notify(
self,
stream_key: str,
stream_key: StreamKeyType,
stream_id: Union[int, RoomStreamToken],
time_now_ms: int,
) -> None:
Expand Down Expand Up @@ -454,7 +454,7 @@ def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken) -> None:

def on_new_event(
self,
stream_key: str,
stream_key: StreamKeyType,
new_token: Union[int, RoomStreamToken],
users: Optional[Collection[Union[str, UserID]]] = None,
rooms: Optional[StrCollection] = None,
Expand Down Expand Up @@ -655,30 +655,29 @@ async def check_for_updates(
events: List[Union[JsonDict, EventBase]] = []
end_token = from_token

for name, source in self.event_sources.sources.get_sources():
keyname = "%s_key" % name
before_id = getattr(before_token, keyname)
after_id = getattr(after_token, keyname)
for keyname, source in self.event_sources.sources.get_sources():
before_id = before_token.get_field(keyname)
after_id = after_token.get_field(keyname)
if before_id == after_id:
continue

new_events, new_key = await source.get_new_events(
user=user,
from_key=getattr(from_token, keyname),
from_key=from_token.get_field(keyname),
limit=limit,
is_guest=is_peeking,
room_ids=room_ids,
explicit_room_id=explicit_room_id,
)

if name == "room":
if keyname == StreamKeyType.ROOM:
new_events = await filter_events_for_client(
self._storage_controllers,
user.to_string(),
new_events,
is_peeking=is_peeking,
)
elif name == "presence":
elif keyname == StreamKeyType.PRESENCE:
now = self.clock.time_msec()
new_events[:] = [
{
Expand Down
2 changes: 1 addition & 1 deletion synapse/push/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _start_processing(self) -> None:
raise NotImplementedError()

@abc.abstractmethod
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
def on_new_receipts(self) -> None:
raise NotImplementedError()

@abc.abstractmethod
Expand Down
2 changes: 1 addition & 1 deletion synapse/push/emailpusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def on_stop(self) -> None:
pass
self.timed_call = None

def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
def on_new_receipts(self) -> None:
# We could wake up and cancel the timer but there tend to be quite a
# lot of read receipts so it's probably less work to just let the
# timer fire
Expand Down
2 changes: 1 addition & 1 deletion synapse/push/httppusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def on_started(self, should_check_for_notifs: bool) -> None:
if should_check_for_notifs:
self._start_processing()

def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
def on_new_receipts(self) -> None:
# Note that the min here shouldn't be relied upon to be accurate.

# We could check the receipts are actually m.read receipts here,
Expand Down
12 changes: 2 additions & 10 deletions synapse/push/pusherpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,20 +292,12 @@ async def _on_new_notifications(self, max_token: RoomStreamToken) -> None:
except Exception:
logger.exception("Exception in pusher on_new_notifications")

async def on_new_receipts(
self, min_stream_id: int, max_stream_id: int, affected_room_ids: Iterable[str]
) -> None:
async def on_new_receipts(self, users_affected: StrCollection) -> None:
if not self.pushers:
# nothing to do here.
return

try:
# Need to subtract 1 from the minimum because the lower bound here
# is not inclusive
users_affected = await self.store.get_users_sent_receipts_between(
min_stream_id - 1, max_stream_id
)

for u in users_affected:
# Don't push if the user account has expired
expired = await self._account_validity_handler.is_user_expired(u)
Expand All @@ -314,7 +306,7 @@ async def on_new_receipts(

if u in self.pushers:
for p in self.pushers[u].values():
p.on_new_receipts(min_stream_id, max_stream_id)
p.on_new_receipts()

except Exception:
logger.exception("Exception in pusher on_new_receipts")
Expand Down
4 changes: 1 addition & 3 deletions synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,7 @@ async def on_rdata(
self.notifier.on_new_event(
StreamKeyType.RECEIPT, token, rooms=[row.room_id for row in rows]
)
await self._pusher_pool.on_new_receipts(
token, token, {row.room_id for row in rows}
)
await self._pusher_pool.on_new_receipts({row.user_id for row in rows})
elif stream_name == ToDeviceStream.NAME:
entities = [row.entity for row in rows if row.entity.startswith("@")]
if entities:
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/e2e_room_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ async def add_e2e_room_keys(
"message": "Set room key",
"room_id": room_id,
"session_id": session_id,
StreamKeyType.ROOM: room_key,
StreamKeyType.ROOM.value: room_key,
}
)

Expand Down
15 changes: 10 additions & 5 deletions synapse/streams/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Iterator, Tuple
from typing import TYPE_CHECKING, Sequence, Tuple

import attr

Expand All @@ -23,7 +23,7 @@
from synapse.handlers.typing import TypingNotificationEventSource
from synapse.logging.opentracing import trace
from synapse.streams import EventSource
from synapse.types import StreamToken
from synapse.types import StreamKeyType, StreamToken

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand All @@ -37,9 +37,14 @@ class _EventSourcesInner:
receipt: ReceiptEventSource
account_data: AccountDataEventSource

def get_sources(self) -> Iterator[Tuple[str, EventSource]]:
for attribute in attr.fields(_EventSourcesInner):
yield attribute.name, getattr(self, attribute.name)
def get_sources(self) -> Sequence[Tuple[StreamKeyType, EventSource]]:
return [
(StreamKeyType.ROOM, self.room),
(StreamKeyType.PRESENCE, self.presence),
(StreamKeyType.TYPING, self.typing),
(StreamKeyType.RECEIPT, self.receipt),
(StreamKeyType.ACCOUNT_DATA, self.account_data),
]


class EventSources:
Expand Down
59 changes: 44 additions & 15 deletions synapse/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
Any,
ClassVar,
Dict,
Final,
List,
Literal,
Mapping,
Match,
MutableMapping,
Expand All @@ -34,6 +34,7 @@
Type,
TypeVar,
Union,
overload,
)

import attr
Expand Down Expand Up @@ -649,20 +650,20 @@ async def to_string(self, store: "DataStore") -> str:
return "s%d" % (self.stream,)


class StreamKeyType:
class StreamKeyType(Enum):
"""Known stream types.

A stream is a list of entities ordered by an incrementing "stream token".
"""

ROOM: Final = "room_key"
PRESENCE: Final = "presence_key"
TYPING: Final = "typing_key"
RECEIPT: Final = "receipt_key"
ACCOUNT_DATA: Final = "account_data_key"
PUSH_RULES: Final = "push_rules_key"
TO_DEVICE: Final = "to_device_key"
DEVICE_LIST: Final = "device_list_key"
ROOM = "room_key"
PRESENCE = "presence_key"
TYPING = "typing_key"
RECEIPT = "receipt_key"
ACCOUNT_DATA = "account_data_key"
PUSH_RULES = "push_rules_key"
TO_DEVICE = "to_device_key"
DEVICE_LIST = "device_list_key"
UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key"


Expand Down Expand Up @@ -784,7 +785,7 @@ async def to_string(self, store: "DataStore") -> str:
def room_stream_id(self) -> int:
return self.room_key.stream

def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken":
def copy_and_advance(self, key: StreamKeyType, new_value: Any) -> "StreamToken":
"""Advance the given key in the token to a new value if and only if the
new value is after the old value.

Expand All @@ -797,16 +798,44 @@ def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken":
return new_token

new_token = self.copy_and_replace(key, new_value)
new_id = int(getattr(new_token, key))
old_id = int(getattr(self, key))
new_id = new_token.get_field(key)
old_id = self.get_field(key)

if old_id < new_id:
return new_token
else:
return self

def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken":
return attr.evolve(self, **{key: new_value})
def copy_and_replace(self, key: StreamKeyType, new_value: Any) -> "StreamToken":
return attr.evolve(self, **{key.value: new_value})

@overload
def get_field(self, key: Literal[StreamKeyType.ROOM]) -> RoomStreamToken:
...

@overload
def get_field(
self,
key: Literal[
StreamKeyType.ACCOUNT_DATA,
StreamKeyType.DEVICE_LIST,
StreamKeyType.PRESENCE,
StreamKeyType.PUSH_RULES,
StreamKeyType.RECEIPT,
StreamKeyType.TO_DEVICE,
StreamKeyType.TYPING,
StreamKeyType.UN_PARTIAL_STATED_ROOMS,
],
) -> int:
...

@overload
def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
...
clokep marked this conversation as resolved.
Show resolved Hide resolved

def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
"""Returns the stream ID for the given key."""
return getattr(self, key.value)


StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0, 0)
Expand Down
8 changes: 4 additions & 4 deletions tests/handlers/test_appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.rest.client import login, receipts, register, room, sendtodevice
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomStreamToken
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType
from synapse.util import Clock
from synapse.util.stringutils import random_string

Expand Down Expand Up @@ -304,7 +304,7 @@ def test_notify_interested_services_ephemeral(self) -> None:
)

self.handler.notify_interested_services_ephemeral(
"receipt_key", 580, ["@fakerecipient:example.com"]
StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"]
)
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
interested_service, ephemeral=[event]
Expand Down Expand Up @@ -332,7 +332,7 @@ def test_notify_interested_services_ephemeral_out_of_order(self) -> None:
)

self.handler.notify_interested_services_ephemeral(
"receipt_key", 580, ["@fakerecipient:example.com"]
StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"]
)
# This method will be called, but with an empty list of events
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
Expand Down Expand Up @@ -634,7 +634,7 @@ def test_sending_read_receipt_batches_to_application_services(self) -> None:
self.get_success(
self.hs.get_application_service_handler()._notify_interested_services_ephemeral(
services=[interested_appservice],
stream_key="receipt_key",
stream_key=StreamKeyType.RECEIPT,
new_token=stream_token,
users=[self.exclusive_as_user],
)
Expand Down
Loading
Loading