Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Accept & store thread IDs for receipts (implement MSC3771). (#13782)
Browse files Browse the repository at this point in the history
Updates the `/receipts` endpoint and receipt EDU handler to parse a
`thread_id` from the body and insert it in the database.
  • Loading branch information
clokep authored Sep 23, 2022
1 parent 03c2bfb commit efd108b
Show file tree
Hide file tree
Showing 17 changed files with 173 additions and 41 deletions.
1 change: 1 addition & 0 deletions changelog.d/13782.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).
2 changes: 2 additions & 0 deletions synapse/config/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
# MSC3786 (Add a default push rule to ignore m.room.server_acl events)
self.msc3786_enabled: bool = experimental.get("msc3786_enabled", False)

# MSC3771: Thread read receipts
self.msc3771_enabled: bool = experimental.get("msc3771_enabled", False)
# MSC3772: A push rule for mutual relations.
self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False)

Expand Down
23 changes: 21 additions & 2 deletions synapse/handlers/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def __init__(self, hs: "HomeServer"):
self.clock = self.hs.get_clock()
self.state = hs.get_state_handler()

self._msc3771_enabled = hs.config.experimental.msc3771_enabled

async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None:
"""Called when we receive an EDU of type m.receipt from a remote HS."""
receipts = []
Expand Down Expand Up @@ -91,13 +93,23 @@ async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None
)
continue

# Check if these receipts apply to a thread.
thread_id = None
data = user_values.get("data", {})
if self._msc3771_enabled and isinstance(data, dict):
thread_id = data.get("thread_id")
# If the thread ID is invalid, consider it missing.
if not isinstance(thread_id, str):
thread_id = None

receipts.append(
ReadReceipt(
room_id=room_id,
receipt_type=receipt_type,
user_id=user_id,
event_ids=user_values["event_ids"],
data=user_values.get("data", {}),
thread_id=thread_id,
data=data,
)
)

Expand All @@ -114,6 +126,7 @@ async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
receipt.receipt_type,
receipt.user_id,
receipt.event_ids,
receipt.thread_id,
receipt.data,
)

Expand Down Expand Up @@ -146,7 +159,12 @@ async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
return True

async def received_client_receipt(
self, room_id: str, receipt_type: str, user_id: str, event_id: str
self,
room_id: str,
receipt_type: str,
user_id: str,
event_id: str,
thread_id: Optional[str],
) -> None:
"""Called when a client tells us a local user has read up to the given
event_id in the room.
Expand All @@ -156,6 +174,7 @@ async def received_client_receipt(
receipt_type=receipt_type,
user_id=user_id,
event_ids=[event_id],
thread_id=thread_id,
data={"ts": int(self.clock.time_msec())},
)

Expand Down
3 changes: 2 additions & 1 deletion synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,8 @@ async def _on_new_receipts(
receipt.receipt_type,
receipt.user_id,
[receipt.event_id],
receipt.data,
thread_id=receipt.thread_id,
data=receipt.data,
)
await self.federation_sender.send_read_receipt(receipt_info)

Expand Down
1 change: 1 addition & 0 deletions synapse/replication/tcp/streams/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ class ReceiptsStreamRow:
receipt_type: str
user_id: str
event_id: str
thread_id: Optional[str]
data: dict

NAME = "receipts"
Expand Down
2 changes: 2 additions & 0 deletions synapse/rest/client/read_marker.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ async def on_POST(
receipt_type,
user_id=requester.user.to_string(),
event_id=event_id,
# Setting the thread ID is not possible with the /read_markers endpoint.
thread_id=None,
)

return 200, {}
Expand Down
14 changes: 13 additions & 1 deletion synapse/rest/client/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(self, hs: "HomeServer"):
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.FULLY_READ,
}
self._msc3771_enabled = hs.config.experimental.msc3771_enabled

async def on_POST(
self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
Expand All @@ -61,7 +62,17 @@ async def on_POST(
f"Receipt type must be {', '.join(self._known_receipt_types)}",
)

parse_json_object_from_request(request, allow_empty_body=False)
body = parse_json_object_from_request(request)

# Pull the thread ID, if one exists.
thread_id = None
if self._msc3771_enabled:
if "thread_id" in body:
thread_id = body.get("thread_id")
if not thread_id or not isinstance(thread_id, str):
raise SynapseError(
400, "thread_id field must be a non-empty string"
)

await self.presence_handler.bump_presence_active_time(requester.user)

Expand All @@ -77,6 +88,7 @@ async def on_POST(
receipt_type,
user_id=requester.user.to_string(),
event_id=event_id,
thread_id=thread_id,
)

return 200, {}
Expand Down
2 changes: 2 additions & 0 deletions synapse/rest/client/versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
"org.matrix.msc3030": self.config.experimental.msc3030_enabled,
# Adds support for thread relations, per MSC3440.
"org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above
# Support for thread read receipts.
"org.matrix.msc3771": self.config.experimental.msc3771_enabled,
# Allows moderators to fetch redacted event content as described in MSC2815
"fi.mau.msc2815": self.config.experimental.msc2815_enabled,
# Adds support for login token requests as per MSC3882
Expand Down
2 changes: 2 additions & 0 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@
"local_media_repository_thumbnails": "local_media_repository_thumbnails_method_idx",
"remote_media_cache_thumbnails": "remote_media_repository_thumbnails_method_idx",
"event_push_summary": "event_push_summary_unique_index",
"receipts_linearized": "receipts_linearized_unique_index",
"receipts_graph": "receipts_graph_unique_index",
}


Expand Down
87 changes: 64 additions & 23 deletions synapse/storage/databases/main/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,9 @@ def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:

async def get_all_updated_receipts(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, list]], int, bool]:
) -> Tuple[
List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], int, bool
]:
"""Get updates for receipts replication stream.
Args:
Expand All @@ -567,9 +569,13 @@ async def get_all_updated_receipts(

def get_all_updated_receipts_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, list]], int, bool]:
) -> Tuple[
List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]],
int,
bool,
]:
sql = """
SELECT stream_id, room_id, receipt_type, user_id, event_id, data
SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data
FROM receipts_linearized
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
Expand All @@ -578,8 +584,8 @@ def get_all_updated_receipts_txn(
txn.execute(sql, (last_id, current_id, limit))

updates = cast(
List[Tuple[int, list]],
[(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn],
List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]],
[(r[0], r[1:6] + (db_to_json(r[6]),)) for r in txn],
)

limited = False
Expand Down Expand Up @@ -631,6 +637,7 @@ def _insert_linearized_receipt_txn(
receipt_type: str,
user_id: str,
event_id: str,
thread_id: Optional[str],
data: JsonDict,
stream_id: int,
) -> Optional[int]:
Expand All @@ -657,12 +664,27 @@ def _insert_linearized_receipt_txn(
# We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts
if stream_ordering is not None:
sql = (
"SELECT stream_ordering, event_id FROM events"
" INNER JOIN receipts_linearized AS r USING (event_id, room_id)"
" WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
if thread_id is None:
thread_clause = "r.thread_id IS NULL"
thread_args: Tuple[str, ...] = ()
else:
thread_clause = "r.thread_id = ?"
thread_args = (thread_id,)

sql = f"""
SELECT stream_ordering, event_id FROM events
INNER JOIN receipts_linearized AS r USING (event_id, room_id)
WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ? AND {thread_clause}
"""
txn.execute(
sql,
(
room_id,
receipt_type,
user_id,
)
+ thread_args,
)
txn.execute(sql, (room_id, receipt_type, user_id))

for so, eid in txn:
if int(so) >= stream_ordering:
Expand All @@ -682,21 +704,28 @@ def _insert_linearized_receipt_txn(
self._receipts_stream_cache.entity_has_changed, room_id, stream_id
)

keyvalues = {
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
}
where_clause = ""
if thread_id is None:
where_clause = "thread_id IS NULL"
else:
keyvalues["thread_id"] = thread_id

self.db_pool.simple_upsert_txn(
txn,
table="receipts_linearized",
keyvalues={
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
},
keyvalues=keyvalues,
values={
"stream_id": stream_id,
"event_id": event_id,
"event_stream_ordering": stream_ordering,
"data": json_encoder.encode(data),
"thread_id": None,
},
where_clause=where_clause,
# receipts_linearized has a unique constraint on
# (user_id, room_id, receipt_type), so no need to lock
lock=False,
Expand Down Expand Up @@ -748,6 +777,7 @@ async def insert_receipt(
receipt_type: str,
user_id: str,
event_ids: List[str],
thread_id: Optional[str],
data: dict,
) -> Optional[Tuple[int, int]]:
"""Insert a receipt, either from local client or remote server.
Expand Down Expand Up @@ -780,6 +810,7 @@ async def insert_receipt(
receipt_type,
user_id,
linearized_event_id,
thread_id,
data,
stream_id=stream_id,
# Read committed is actually beneficial here because we check for a receipt with
Expand All @@ -794,7 +825,8 @@ async def insert_receipt(

now = self._clock.time_msec()
logger.debug(
"RR for event %s in %s (%i ms old)",
"Receipt %s for event %s in %s (%i ms old)",
receipt_type,
linearized_event_id,
room_id,
now - event_ts,
Expand All @@ -807,6 +839,7 @@ async def insert_receipt(
receipt_type,
user_id,
event_ids,
thread_id,
data,
)

Expand All @@ -821,6 +854,7 @@ def _insert_graph_receipt_txn(
receipt_type: str,
user_id: str,
event_ids: List[str],
thread_id: Optional[str],
data: JsonDict,
) -> None:
assert self._can_write_to_receipts
Expand All @@ -832,19 +866,26 @@ def _insert_graph_receipt_txn(
# FIXME: This shouldn't invalidate the whole cache
txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,))

keyvalues = {
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
}
where_clause = ""
if thread_id is None:
where_clause = "thread_id IS NULL"
else:
keyvalues["thread_id"] = thread_id

self.db_pool.simple_upsert_txn(
txn,
table="receipts_graph",
keyvalues={
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
},
keyvalues=keyvalues,
values={
"event_ids": json_encoder.encode(event_ids),
"data": json_encoder.encode(data),
"thread_id": None,
},
where_clause=where_clause,
# receipts_graph has a unique constraint on
# (user_id, room_id, receipt_type), so no need to lock
lock=False,
Expand Down
1 change: 1 addition & 0 deletions synapse/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,7 @@ class ReadReceipt:
receipt_type: str
user_id: str
event_ids: List[str]
thread_id: Optional[str]
data: JsonDict


Expand Down
21 changes: 18 additions & 3 deletions tests/federation/test_federation_sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@ def test_send_receipts(self):

sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
"room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
"room_id",
"m.read",
"user_id",
["event_id"],
thread_id=None,
data={"ts": 1234},
)
self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))

Expand Down Expand Up @@ -89,7 +94,12 @@ def test_send_receipts_with_backoff(self):

sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
"room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
"room_id",
"m.read",
"user_id",
["event_id"],
thread_id=None,
data={"ts": 1234},
)
self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))

Expand Down Expand Up @@ -121,7 +131,12 @@ def test_send_receipts_with_backoff(self):

# send the second RR
receipt = ReadReceipt(
"room_id", "m.read", "user_id", ["other_id"], {"ts": 1234}
"room_id",
"m.read",
"user_id",
["other_id"],
thread_id=None,
data={"ts": 1234},
)
self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
self.pump()
Expand Down
Loading

0 comments on commit efd108b

Please sign in to comment.