Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Use device_one_time_keys_count to match MSC3202 #14565

Merged
merged 3 commits into from
Nov 28, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/14565.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
In transactions that include the experimental `org.matrix.msc3202.device_one_time_keys_count` key, include a duplicate key of `org.matrix.msc3202.device_one_time_keys_count` to match the name proposed by [MSC3202](https://github.com/matrix-org/matrix-spec-proposals/blob/travis/msc/otk-dl-appservice/proposals/3202-encrypted-appservices.md).
AndrewFerr marked this conversation as resolved.
Show resolved Hide resolved
10 changes: 5 additions & 5 deletions synapse/appservice/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@

logger = logging.getLogger(__name__)

# Type for the `device_one_time_key_counts` field in an appservice transaction
# Type for the `device_one_time_keys_count` field in an appservice transaction
# user ID -> {device ID -> {algorithm -> count}}
TransactionOneTimeKeyCounts = Dict[str, Dict[str, Dict[str, int]]]
TransactionOneTimeKeysCount = Dict[str, Dict[str, Dict[str, int]]]

# Type for the `device_unused_fallback_key_types` field in an appservice transaction
# user ID -> {device ID -> [algorithm]}
Expand Down Expand Up @@ -376,7 +376,7 @@ def __init__(
events: List[EventBase],
ephemeral: List[JsonDict],
to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
one_time_keys_count: TransactionOneTimeKeysCount,
unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceListUpdates,
):
Expand All @@ -385,7 +385,7 @@ def __init__(
self.events = events
self.ephemeral = ephemeral
self.to_device_messages = to_device_messages
self.one_time_key_counts = one_time_key_counts
self.one_time_keys_count = one_time_keys_count
self.unused_fallback_keys = unused_fallback_keys
self.device_list_summary = device_list_summary

Expand All @@ -402,7 +402,7 @@ async def send(self, as_api: "ApplicationServiceApi") -> bool:
events=self.events,
ephemeral=self.ephemeral,
to_device_messages=self.to_device_messages,
one_time_key_counts=self.one_time_key_counts,
one_time_keys_count=self.one_time_keys_count,
unused_fallback_keys=self.unused_fallback_keys,
device_list_summary=self.device_list_summary,
txn_id=self.id,
Expand Down
11 changes: 7 additions & 4 deletions synapse/appservice/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from synapse.api.errors import CodeMessageException
from synapse.appservice import (
ApplicationService,
TransactionOneTimeKeyCounts,
TransactionOneTimeKeysCount,
TransactionUnusedFallbackKeys,
)
from synapse.events import EventBase
Expand Down Expand Up @@ -262,7 +262,7 @@ async def push_bulk(
events: List[EventBase],
ephemeral: List[JsonDict],
to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
one_time_keys_count: TransactionOneTimeKeysCount,
unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceListUpdates,
txn_id: Optional[int] = None,
Expand Down Expand Up @@ -310,10 +310,13 @@ async def push_bulk(

# TODO: Update to stable prefixes once MSC3202 completes FCP merge
if service.msc3202_transaction_extensions:
if one_time_key_counts:
if one_time_keys_count:
body[
"org.matrix.msc3202.device_one_time_key_counts"
] = one_time_key_counts
] = one_time_keys_count
body[
"org.matrix.msc3202.device_one_time_keys_count"
] = one_time_keys_count
if unused_fallback_keys:
body[
"org.matrix.msc3202.device_unused_fallback_key_types"
Expand Down
16 changes: 8 additions & 8 deletions synapse/appservice/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
from synapse.appservice import (
ApplicationService,
ApplicationServiceState,
TransactionOneTimeKeyCounts,
TransactionOneTimeKeysCount,
TransactionUnusedFallbackKeys,
)
from synapse.appservice.api import ApplicationServiceApi
Expand Down Expand Up @@ -258,7 +258,7 @@ async def _send_request(self, service: ApplicationService) -> None:
):
return

one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None
one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None
unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None

if (
Expand All @@ -269,7 +269,7 @@ async def _send_request(self, service: ApplicationService) -> None:
# for the users which are mentioned in this transaction,
# as well as the appservice's sender.
(
one_time_key_counts,
one_time_keys_count,
unused_fallback_keys,
) = await self._compute_msc3202_otk_counts_and_fallback_keys(
service, events, ephemeral, to_device_messages_to_send
Expand All @@ -281,7 +281,7 @@ async def _send_request(self, service: ApplicationService) -> None:
events,
ephemeral,
to_device_messages_to_send,
one_time_key_counts,
one_time_keys_count,
unused_fallback_keys,
device_list_summary,
)
Expand All @@ -296,7 +296,7 @@ async def _compute_msc3202_otk_counts_and_fallback_keys(
events: Iterable[EventBase],
ephemerals: Iterable[JsonDict],
to_device_messages: Iterable[JsonDict],
) -> Tuple[TransactionOneTimeKeyCounts, TransactionUnusedFallbackKeys]:
) -> Tuple[TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys]:
"""
Given a list of the events, ephemeral messages and to-device messages,
- first computes a list of application services users that may have
Expand Down Expand Up @@ -367,7 +367,7 @@ async def send(
events: List[EventBase],
ephemeral: Optional[List[JsonDict]] = None,
to_device_messages: Optional[List[JsonDict]] = None,
one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None,
one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None,
unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None,
device_list_summary: Optional[DeviceListUpdates] = None,
) -> None:
Expand All @@ -380,7 +380,7 @@ async def send(
events: The persistent events to include in the transaction.
ephemeral: The ephemeral events to include in the transaction.
to_device_messages: The to-device messages to include in the transaction.
one_time_key_counts: Counts of remaining one-time keys for relevant
one_time_keys_count: Counts of remaining one-time keys for relevant
appservice devices in the transaction.
unused_fallback_keys: Lists of unused fallback keys for relevant
appservice devices in the transaction.
Expand All @@ -397,7 +397,7 @@ async def send(
events=events,
ephemeral=ephemeral or [],
to_device_messages=to_device_messages or [],
one_time_key_counts=one_time_key_counts or {},
one_time_keys_count=one_time_keys_count or {},
unused_fallback_keys=unused_fallback_keys or {},
device_list_summary=device_list_summary or DeviceListUpdates(),
)
Expand Down
6 changes: 5 additions & 1 deletion synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,11 @@ async def upload_keys_for_user(
result = await self.store.count_e2e_one_time_keys(user_id, device_id)

set_tag("one_time_key_counts", str(result))
return {"one_time_key_counts": result}
set_tag("one_time_keys_count", str(result))
AndrewFerr marked this conversation as resolved.
Show resolved Hide resolved
return {
"one_time_key_counts": result,
"one_time_keys_count": result,
}
AndrewFerr marked this conversation as resolved.
Show resolved Hide resolved

async def _upload_one_time_keys_for_user(
self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -1426,14 +1426,14 @@ async def generate_sync_result(

logger.debug("Fetching OTK data")
device_id = sync_config.device_id
one_time_key_counts: JsonDict = {}
one_time_keys_count: JsonDict = {}
unused_fallback_key_types: List[str] = []
if device_id:
# TODO: We should have a way to let clients differentiate between the states of:
# * no change in OTK count since the provided since token
# * the server has zero OTKs left for this device
# Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
one_time_key_counts = await self.store.count_e2e_one_time_keys(
one_time_keys_count = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
unused_fallback_key_types = (
Expand Down Expand Up @@ -1463,7 +1463,7 @@ async def generate_sync_result(
archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device,
device_lists=device_lists,
device_one_time_keys_count=one_time_key_counts,
device_one_time_keys_count=one_time_keys_count,
device_unused_fallback_key_types=unused_fallback_key_types,
next_batch=sync_result_builder.now_token,
)
Expand Down
2 changes: 1 addition & 1 deletion synapse/rest/client/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class KeyUploadServlet(RestServlet):
response, e.g.:

{
"one_time_key_counts": {
"one_time_keys_count": {
AndrewFerr marked this conversation as resolved.
Show resolved Hide resolved
"curve25519": 10,
"signed_curve25519": 20
}
Expand Down
10 changes: 5 additions & 5 deletions synapse/storage/databases/main/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ApplicationService,
ApplicationServiceState,
AppServiceTransaction,
TransactionOneTimeKeyCounts,
TransactionOneTimeKeysCount,
TransactionUnusedFallbackKeys,
)
from synapse.config.appservice import load_appservices
Expand Down Expand Up @@ -260,7 +260,7 @@ async def create_appservice_txn(
events: List[EventBase],
ephemeral: List[JsonDict],
to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
one_time_keys_count: TransactionOneTimeKeysCount,
unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceListUpdates,
) -> AppServiceTransaction:
Expand All @@ -273,7 +273,7 @@ async def create_appservice_txn(
events: A list of persistent events to put in the transaction.
ephemeral: A list of ephemeral events to put in the transaction.
to_device_messages: A list of to-device messages to put in the transaction.
one_time_key_counts: Counts of remaining one-time keys for relevant
one_time_keys_count: Counts of remaining one-time keys for relevant
appservice devices in the transaction.
unused_fallback_keys: Lists of unused fallback keys for relevant
appservice devices in the transaction.
Expand All @@ -299,7 +299,7 @@ def _create_appservice_txn(txn: LoggingTransaction) -> AppServiceTransaction:
events=events,
ephemeral=ephemeral,
to_device_messages=to_device_messages,
one_time_key_counts=one_time_key_counts,
one_time_keys_count=one_time_keys_count,
unused_fallback_keys=unused_fallback_keys,
device_list_summary=device_list_summary,
)
Expand Down Expand Up @@ -379,7 +379,7 @@ def _get_oldest_unsent_txn(
events=events,
ephemeral=[],
to_device_messages=[],
one_time_key_counts={},
one_time_keys_count={},
unused_fallback_keys={},
device_list_summary=DeviceListUpdates(),
)
Expand Down
8 changes: 4 additions & 4 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

from synapse.api.constants import DeviceKeyAlgorithms
from synapse.appservice import (
TransactionOneTimeKeyCounts,
TransactionOneTimeKeysCount,
TransactionUnusedFallbackKeys,
)
from synapse.logging.opentracing import log_kv, set_tag, trace
Expand Down Expand Up @@ -514,7 +514,7 @@ def _count_e2e_one_time_keys(txn: LoggingTransaction) -> Dict[str, int]:

async def count_bulk_e2e_one_time_keys_for_as(
self, user_ids: Collection[str]
) -> TransactionOneTimeKeyCounts:
) -> TransactionOneTimeKeysCount:
"""
Counts, in bulk, the one-time keys for all the users specified.
Intended to be used by application services for populating OTK counts in
Expand All @@ -528,7 +528,7 @@ async def count_bulk_e2e_one_time_keys_for_as(

def _count_bulk_e2e_one_time_keys_txn(
txn: LoggingTransaction,
) -> TransactionOneTimeKeyCounts:
) -> TransactionOneTimeKeysCount:
user_in_where_clause, user_parameters = make_in_list_sql_clause(
self.database_engine, "user_id", user_ids
)
Expand All @@ -541,7 +541,7 @@ def _count_bulk_e2e_one_time_keys_txn(
"""
txn.execute(sql, user_parameters)

result: TransactionOneTimeKeyCounts = {}
result: TransactionOneTimeKeysCount = {}

for user_id, device_id, algorithm, count in txn:
# We deliberately construct empty dictionaries for
Expand Down
6 changes: 3 additions & 3 deletions tests/appservice/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_single_service_up_txn_sent(self):
events=events,
ephemeral=[],
to_device_messages=[], # txn made and saved
one_time_key_counts={},
one_time_keys_count={},
unused_fallback_keys={},
device_list_summary=DeviceListUpdates(),
)
Expand All @@ -96,7 +96,7 @@ def test_single_service_down(self):
events=events,
ephemeral=[],
to_device_messages=[], # txn made and saved
one_time_key_counts={},
one_time_keys_count={},
unused_fallback_keys={},
device_list_summary=DeviceListUpdates(),
)
Expand Down Expand Up @@ -125,7 +125,7 @@ def test_single_service_up_txn_not_sent(self):
events=events,
ephemeral=[],
to_device_messages=[],
one_time_key_counts={},
one_time_keys_count={},
unused_fallback_keys={},
device_list_summary=DeviceListUpdates(),
)
Expand Down
4 changes: 2 additions & 2 deletions tests/handlers/test_appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from synapse.api.constants import EduTypes, EventTypes
from synapse.appservice import (
ApplicationService,
TransactionOneTimeKeyCounts,
TransactionOneTimeKeysCount,
TransactionUnusedFallbackKeys,
)
from synapse.handlers.appservice import ApplicationServicesHandler
Expand Down Expand Up @@ -1123,7 +1123,7 @@ def test_application_services_receive_otk_counts_and_fallback_key_usages_with_pd
# Capture what was sent as an AS transaction.
self.send_mock.assert_called()
last_args, _last_kwargs = self.send_mock.call_args
otks: Optional[TransactionOneTimeKeyCounts] = last_args[self.ARG_OTK_COUNTS]
otks: Optional[TransactionOneTimeKeysCount] = last_args[self.ARG_OTK_COUNTS]
unused_fallbacks: Optional[TransactionUnusedFallbackKeys] = last_args[
self.ARG_FALLBACK_KEYS
]
Expand Down
24 changes: 20 additions & 4 deletions tests/handlers/test_e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ def test_reupload_one_time_keys(self) -> None:
)
)
self.assertDictEqual(
res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
res,
{
"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0},
"one_time_keys_count": {"alg1": 1, "alg2": 2, "signed_curve25519": 0},
},
)

# we should be able to change the signature without a problem
Expand All @@ -74,7 +78,11 @@ def test_reupload_one_time_keys(self) -> None:
)
)
self.assertDictEqual(
res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
res,
{
"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0},
"one_time_keys_count": {"alg1": 1, "alg2": 2, "signed_curve25519": 0},
},
)

def test_change_one_time_keys(self) -> None:
Expand All @@ -94,7 +102,11 @@ def test_change_one_time_keys(self) -> None:
)
)
self.assertDictEqual(
res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
res,
{
"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0},
"one_time_keys_count": {"alg1": 1, "alg2": 2, "signed_curve25519": 0},
},
)

# Error when changing string key
Expand Down Expand Up @@ -148,7 +160,11 @@ def test_claim_one_time_key(self) -> None:
)
)
self.assertDictEqual(
res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}}
res,
{
"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0},
"one_time_keys_count": {"alg1": 1, "signed_curve25519": 0},
},
AndrewFerr marked this conversation as resolved.
Show resolved Hide resolved
)

res2 = self.get_success(
Expand Down