Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Parse msc3706 fields in send_join response
Browse files Browse the repository at this point in the history
  • Loading branch information
richvdh committed Feb 16, 2022
1 parent 73fc488 commit 6f6c5c6
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 33 deletions.
1 change: 1 addition & 0 deletions changelog.d/12011.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Preparation for faster-room-join work: parse msc3706 fields in send_join response.
4 changes: 4 additions & 0 deletions synapse/config/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,7 @@ def read_config(self, config: JsonDict, **kwargs):

# MSC3706 (server-side support for partial state in /send_join responses)
self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False)

# experimental support for faster joins over federation (msc2775, msc3706)
# requires a target server with msc3706_enabled enabled.
self.faster_joins_enabled: bool = experimental.get("faster_joins", False)
15 changes: 14 additions & 1 deletion synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
# Copyright 2015-2022 The Matrix.org Foundation C.I.C.
# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -89,6 +89,12 @@ class SendJoinResult:
state: List[EventBase]
auth_chain: List[EventBase]

# True if 'state' elides non-critical membership events
partial_state: bool

# if 'partial_state' is set, a list of the servers in the room (otherwise empty)
servers_in_room: List[str]


class FederationClient(FederationBase):
def __init__(self, hs: "HomeServer"):
Expand Down Expand Up @@ -876,11 +882,18 @@ async def _execute(pdu: EventBase) -> None:
% (auth_chain_create_events,)
)

if response.partial_state and not response.servers_in_room:
raise InvalidResponseError(
"partial_state was set, but no servers were listed in the room"
)

return SendJoinResult(
event=event,
state=signed_state,
auth_chain=signed_auth,
origin=destination,
partial_state=response.partial_state,
servers_in_room=response.servers_in_room or [],
)

# MSC3083 defines additional error codes for room joins.
Expand Down
124 changes: 92 additions & 32 deletions synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# Copyright 2014-2022 The Matrix.org Foundation C.I.C.
# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -14,6 +14,7 @@
# limitations under the License.

import logging
import typing
import urllib
from typing import (
Any,
Expand Down Expand Up @@ -46,6 +47,9 @@
from synapse.http.matrixfederationclient import ByteParser
from synapse.types import JsonDict

if typing.TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)

# Send join responses can be huge, so we set a separate limit here. The response
Expand All @@ -57,9 +61,10 @@
class TransportLayerClient:
"""Sends federation HTTP requests to other servers"""

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
self.client = hs.get_federation_http_client()
self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled

async def get_room_state_ids(
self, destination: str, room_id: str, event_id: str
Expand Down Expand Up @@ -336,10 +341,15 @@ async def send_join_v2(
content: JsonDict,
) -> "SendJoinResponse":
path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
query_params: Dict[str, str] = {}
if self._faster_joins_enabled:
# lazy-load state on join
query_params["org.matrix.msc3706.partial_state"] = "true"

return await self.client.put_json(
destination=destination,
path=path,
args=query_params,
data=content,
parser=SendJoinParser(room_version, v1_api=False),
max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
Expand Down Expand Up @@ -1271,6 +1281,12 @@ class SendJoinResponse:
# "event" is not included in the response.
event: Optional[EventBase] = None

# The room state is incomplete
partial_state: bool = False

# List of servers in the room
servers_in_room: Optional[List[str]] = None


@ijson.coroutine
def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
Expand All @@ -1297,6 +1313,32 @@ def _event_list_parser(
events.append(event)


@ijson.coroutine
def _partial_state_parser(response: SendJoinResponse) -> Generator[None, Any, None]:
"""Helper function for use with `ijson.items_coro`
Parses the partial_state field in send_join responses
"""
while True:
val = yield
if not isinstance(val, bool):
raise TypeError("partial_state must be a boolean")
response.partial_state = val


@ijson.coroutine
def _servers_in_room_parser(response: SendJoinResponse) -> Generator[None, Any, None]:
"""Helper function for use with `ijson.items_coro`
Parses the servers_in_room field in send_join responses
"""
while True:
val = yield
if not isinstance(val, list) or any(not isinstance(x, str) for x in val):
raise TypeError("servers_in_room must be a list of strings")
response.servers_in_room = val


class SendJoinParser(ByteParser[SendJoinResponse]):
"""A parser for the response to `/send_join` requests.
Expand All @@ -1308,44 +1350,62 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
CONTENT_TYPE = "application/json"

def __init__(self, room_version: RoomVersion, v1_api: bool):
self._response = SendJoinResponse([], [], {})
self._response = SendJoinResponse([], [], event_dict={})
self._room_version = room_version
self._coros = []

# The V1 API has the shape of `[200, {...}]`, which we handle by
# prefixing with `item.*`.
prefix = "item." if v1_api else ""

self._coro_state = ijson.items_coro(
_event_list_parser(room_version, self._response.state),
prefix + "state.item",
use_float=True,
)
self._coro_auth = ijson.items_coro(
_event_list_parser(room_version, self._response.auth_events),
prefix + "auth_chain.item",
use_float=True,
)
# TODO Remove the unstable prefix when servers have updated.
#
# By re-using the same event dictionary this will cause the parsing of
# org.matrix.msc3083.v2.event and event to stomp over each other.
# Generally this should be fine.
self._coro_unstable_event = ijson.kvitems_coro(
_event_parser(self._response.event_dict),
prefix + "org.matrix.msc3083.v2.event",
use_float=True,
)
self._coro_event = ijson.kvitems_coro(
_event_parser(self._response.event_dict),
prefix + "event",
use_float=True,
)
self._coros = [
ijson.items_coro(
_event_list_parser(room_version, self._response.state),
prefix + "state.item",
use_float=True,
),
ijson.items_coro(
_event_list_parser(room_version, self._response.auth_events),
prefix + "auth_chain.item",
use_float=True,
),
# TODO Remove the unstable prefix when servers have updated.
#
# By re-using the same event dictionary this will cause the parsing of
# org.matrix.msc3083.v2.event and event to stomp over each other.
# Generally this should be fine.
ijson.kvitems_coro(
_event_parser(self._response.event_dict),
prefix + "org.matrix.msc3083.v2.event",
use_float=True,
),
ijson.kvitems_coro(
_event_parser(self._response.event_dict),
prefix + "event",
use_float=True,
),
]

if not v1_api:
self._coros.append(
ijson.items_coro(
_partial_state_parser(self._response),
"org.matrix.msc3706.partial_state",
use_float="True",
)
)

self._coros.append(
ijson.items_coro(
_servers_in_room_parser(self._response),
"org.matrix.msc3706.servers_in_room",
use_float="True",
)
)

def write(self, data: bytes) -> int:
self._coro_state.send(data)
self._coro_auth.send(data)
self._coro_unstable_event.send(data)
self._coro_event.send(data)
for c in self._coros:
c.send(data)

return len(data)

Expand Down
32 changes: 32 additions & 0 deletions tests/federation/transport/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,35 @@ def test_two_writes(self) -> None:
self.assertEqual(len(parsed_response.state), 1, parsed_response)
self.assertEqual(parsed_response.event_dict, {}, parsed_response)
self.assertIsNone(parsed_response.event, parsed_response)
self.assertFalse(parsed_response.partial_state, parsed_response)
self.assertEqual(parsed_response.servers_in_room, None, parsed_response)

def test_lazy_load(self) -> None:
"""Check that the partial_state flag is correctly parsed"""
parser = SendJoinParser(RoomVersions.V1, False)
response = {
"org.matrix.msc3706.partial_state": True,
}

serialised_response = json.dumps(response).encode()

# Send data to the parser
parser.write(serialised_response)

# Retrieve and check the parsed SendJoinResponse
parsed_response = parser.finish()
self.assertTrue(parsed_response.partial_state)

def test_servers_in_room(self) -> None:
"""Check that the servers_in_room field is correctly parsed"""
parser = SendJoinParser(RoomVersions.V1, False)
response = {"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]}

serialised_response = json.dumps(response).encode()

# Send data to the parser
parser.write(serialised_response)

# Retrieve and check the parsed SendJoinResponse
parsed_response = parser.finish()
self.assertEqual(parsed_response.servers_in_room, ["hs1", "hs2"])

0 comments on commit 6f6c5c6

Please sign in to comment.