Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix a typing issue in the federation client code found with mypy #7089

Merged
merged 2 commits into from
Mar 19, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7089.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix typing issue in federation client found with mypy.
clokep marked this conversation as resolved.
Show resolved Hide resolved
24 changes: 9 additions & 15 deletions synapse/federation/federation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,7 @@

from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
RoomVersion,
)
from synapse.api.room_versions import EventFormatVersions, RoomVersion
from synapse.crypto.event_signing import check_event_content_hash
from synapse.crypto.keyring import Keyring
from synapse.events import EventBase, make_event_from_dict
Expand All @@ -55,13 +51,15 @@ def __init__(self, hs):
self.store = hs.get_datastore()
self._clock = hs.get_clock()

def _check_sigs_and_hash(self, room_version: str, pdu: EventBase) -> Deferred:
def _check_sigs_and_hash(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not saying that we should undo it (it's certainly much cleaner now), but for the record: this seems to go beyond what is needed to fix the issue here. afaict only the calls to _check_sigs_and_hash_and_fetch needed changing; _check_sigs_and_hash, _check_sigs_and_hashes and _check_sigs_on_pdus were ok.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it does go a bit beyond the minimum, but it seemed clearer / more efficient to not switch back and forth between RoomVersion and a str repeatedly.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup, fair.

self, room_version: RoomVersion, pdu: EventBase
) -> Deferred:
return make_deferred_yieldable(
self._check_sigs_and_hashes(room_version, [pdu])[0]
)

def _check_sigs_and_hashes(
self, room_version: str, pdus: List[EventBase]
self, room_version: RoomVersion, pdus: List[EventBase]
) -> List[Deferred]:
"""Checks that each of the received events is correctly signed by the
sending server.
Expand Down Expand Up @@ -146,7 +144,7 @@ class PduToCheckSig(


def _check_sigs_on_pdus(
keyring: Keyring, room_version: str, pdus: Iterable[EventBase]
keyring: Keyring, room_version: RoomVersion, pdus: Iterable[EventBase]
) -> List[Deferred]:
"""Check that the given events are correctly signed

Expand Down Expand Up @@ -191,10 +189,6 @@ def _check_sigs_on_pdus(
for p in pdus
]

v = KNOWN_ROOM_VERSIONS.get(room_version)
if not v:
raise RuntimeError("Unrecognized room version %s" % (room_version,))

# First we check that the sender event is signed by the sender's domain
# (except if its a 3pid invite, in which case it may be sent by any server)
pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
Expand All @@ -204,7 +198,7 @@ def _check_sigs_on_pdus(
(
p.sender_domain,
p.redacted_pdu_json,
p.pdu.origin_server_ts if v.enforce_key_validity else 0,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
p.pdu.event_id,
)
for p in pdus_to_check_sender
Expand All @@ -227,7 +221,7 @@ def sender_err(e, pdu_to_check):
# event id's domain (normally only the case for joins/leaves), and add additional
# checks. Only do this if the room version has a concept of event ID domain
# (ie, the room version uses old-style non-hash event IDs).
if v.event_format == EventFormatVersions.V1:
if room_version.event_format == EventFormatVersions.V1:
pdus_to_check_event_id = [
p
for p in pdus_to_check
Expand All @@ -239,7 +233,7 @@ def sender_err(e, pdu_to_check):
(
get_domain_from_id(p.pdu.event_id),
p.redacted_pdu_json,
p.pdu.origin_server_ts if v.enforce_key_validity else 0,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
p.pdu.event_id,
)
for p in pdus_to_check_event_id
Expand Down
19 changes: 8 additions & 11 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,7 @@ async def backfill(
# FIXME: We should handle signature failures more gracefully.
pdus[:] = await make_deferred_yieldable(
defer.gatherResults(
self._check_sigs_and_hashes(room_version.identifier, pdus),
consumeErrors=True,
self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True,
).addErrback(unwrapFirstError)
)

Expand Down Expand Up @@ -291,9 +290,7 @@ async def get_pdu(
pdu = pdu_list[0]

# Check signatures are correct.
signed_pdu = await self._check_sigs_and_hash(
room_version.identifier, pdu
)
signed_pdu = await self._check_sigs_and_hash(room_version, pdu)

break

Expand Down Expand Up @@ -350,7 +347,7 @@ async def _check_sigs_and_hash_and_fetch(
self,
origin: str,
pdus: List[EventBase],
room_version: str,
room_version: RoomVersion,
outlier: bool = False,
include_none: bool = False,
) -> List[EventBase]:
Expand Down Expand Up @@ -396,7 +393,7 @@ def handle_check_result(pdu: EventBase, deferred: Deferred):
self.get_pdu(
destinations=[pdu.origin],
event_id=pdu.event_id,
room_version=room_version, # type: ignore
room_version=room_version,
outlier=outlier,
timeout=10000,
)
Expand Down Expand Up @@ -434,7 +431,7 @@ async def get_event_auth(self, destination, room_id, event_id):
]

signed_auth = await self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True, room_version=room_version.identifier
destination, auth_chain, outlier=True, room_version=room_version
)

signed_auth.sort(key=lambda e: e.depth)
Expand Down Expand Up @@ -661,7 +658,7 @@ async def send_request(destination) -> Dict[str, Any]:
destination,
list(pdus.values()),
outlier=True,
room_version=room_version.identifier,
room_version=room_version,
)

valid_pdus_map = {p.event_id: p for p in valid_pdus}
Expand Down Expand Up @@ -756,7 +753,7 @@ async def send_invite(
pdu = event_from_pdu_json(pdu_dict, room_version)

# Check signatures are correct.
pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
pdu = await self._check_sigs_and_hash(room_version, pdu)

# FIXME: We should handle signature failures more gracefully.

Expand Down Expand Up @@ -948,7 +945,7 @@ async def get_missing_events(
]

signed_events = await self._check_sigs_and_hash_and_fetch(
destination, events, outlier=False, room_version=room_version.identifier
destination, events, outlier=False, room_version=room_version
)
except HttpResponseException as e:
if not e.code == 400:
Expand Down
8 changes: 4 additions & 4 deletions synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ async def on_invite_request(
pdu = event_from_pdu_json(content, room_version)
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, pdu.room_id)
pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
pdu = await self._check_sigs_and_hash(room_version, pdu)
ret_pdu = await self.handler.on_invite_request(origin, pdu, room_version)
time_now = self._clock.time_msec()
return {"event": ret_pdu.get_pdu_json(time_now)}
Expand All @@ -425,7 +425,7 @@ async def on_send_join_request(self, origin, content, room_id):

logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)

pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
pdu = await self._check_sigs_and_hash(room_version, pdu)

res_pdus = await self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
Expand Down Expand Up @@ -455,7 +455,7 @@ async def on_send_leave_request(self, origin, content, room_id):

logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)

pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
pdu = await self._check_sigs_and_hash(room_version, pdu)

await self.handler.on_send_leave_request(origin, pdu)
return {}
Expand Down Expand Up @@ -611,7 +611,7 @@ async def _handle_received_pdu(self, origin, pdu):
logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)

# We've already checked that we know the room version by this point
room_version = await self.store.get_room_version_id(pdu.room_id)
room_version = await self.store.get_room_version(pdu.room_id)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One difference here is that get_room_version_id has an LRU cache in front of it while get_room_version does not. I'm unsure how much this matters in practice, but thought it was worth pointing out.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't really matter, because get_room_version is a thin wrapper around get_room_version_id; the slow bit (going to the db) is still cached.


# Check signature.
try:
Expand Down