Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add some type hints to datastore (#12717)
Browse files Browse the repository at this point in the history
  • Loading branch information
dklimpel authored May 17, 2022
1 parent 942c30b commit 6edefef
Show file tree
Hide file tree
Showing 10 changed files with 254 additions and 161 deletions.
1 change: 1 addition & 0 deletions changelog.d/12717.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add some type hints to datastore.
2 changes: 0 additions & 2 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ exclude = (?x)
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/event_federation.py
|synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/roommember.py
|synapse/storage/schema/

|tests/api/test_auth.py
Expand Down
24 changes: 17 additions & 7 deletions synapse/federation/sender/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,17 @@
import abc
import logging
from collections import OrderedDict
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, Tuple
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Hashable,
Iterable,
List,
Optional,
Set,
Tuple,
)

import attr
from prometheus_client import Counter
Expand Down Expand Up @@ -409,7 +419,7 @@ async def handle_event(event: EventBase) -> None:
)
return

destinations: Optional[Set[str]] = None
destinations: Optional[Collection[str]] = None
if not event.prev_event_ids():
# If there are no prev event IDs then the state is empty
# and so no remote servers in the room
Expand Down Expand Up @@ -444,7 +454,7 @@ async def handle_event(event: EventBase) -> None:
)
return

destinations = {
sharded_destinations = {
d
for d in destinations
if self._federation_shard_config.should_handle(
Expand All @@ -456,12 +466,12 @@ async def handle_event(event: EventBase) -> None:
# If we are sending the event on behalf of another server
# then it already has the event and there is no reason to
# send the event to it.
destinations.discard(send_on_behalf_of)
sharded_destinations.discard(send_on_behalf_of)

logger.debug("Sending %s to %r", event, destinations)
logger.debug("Sending %s to %r", event, sharded_destinations)

if destinations:
await self._send_pdu(event, destinations)
if sharded_destinations:
await self._send_pdu(event, sharded_destinations)

now = self.clock.time_msec()
ts = await self.store.get_received_ts(event.event_id)
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,10 +411,10 @@ async def current_sync_for_user(
set_tag(SynapseTags.SYNC_RESULT, bool(sync_result))
return sync_result

async def push_rules_for_user(self, user: UserID) -> JsonDict:
async def push_rules_for_user(self, user: UserID) -> Dict[str, Dict[str, list]]:
user_id = user.to_string()
rules = await self.store.get_push_rules_for_user(user_id)
rules = format_push_rules_for_user(user, rules)
rules_raw = await self.store.get_push_rules_for_user(user_id)
rules = format_push_rules_for_user(user, rules_raw)
return rules

async def ephemeral_by_room(
Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/client/push_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,9 @@ async def on_GET(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDic
# we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference
rules = await self.store.get_push_rules_for_user(user_id)
rules_raw = await self.store.get_push_rules_for_user(user_id)

rules = format_push_rules_for_user(requester.user, rules)
rules = format_push_rules_for_user(requester.user, rules_raw)

path_parts = path.split("/")[1:]

Expand Down
4 changes: 2 additions & 2 deletions synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,13 +239,13 @@ async def get_current_users_in_room(
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return await self.store.get_joined_users_from_state(room_id, entry)

async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]:
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
return await self.get_hosts_in_room_at_events(room_id, event_ids)

async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str]
) -> Set[str]:
) -> FrozenSet[str]:
"""Get the hosts that were in a room at the given event ids
Args:
Expand Down
8 changes: 1 addition & 7 deletions synapse/storage/databases/main/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,7 @@
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import (
IdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache

Expand Down Expand Up @@ -155,8 +151,6 @@ def __init__(
],
)

self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
self._group_updates_id_gen = StreamIdGenerator(
db_conn, "local_group_updates", "stream_id"
)
Expand Down
56 changes: 28 additions & 28 deletions synapse/storage/databases/main/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,19 @@
import calendar
import logging
import time
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, Dict, List, Tuple, cast

from synapse.metrics import GaugeBucketCollector
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
from synapse.storage.types import Cursor

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -73,7 +76,7 @@ def __init__(

@wrap_as_background_process("read_forward_extremities")
async def _read_forward_extremities(self) -> None:
def fetch(txn):
def fetch(txn: LoggingTransaction) -> List[Tuple[int, int]]:
txn.execute(
"""
SELECT t1.c, t2.c
Expand All @@ -86,7 +89,7 @@ def fetch(txn):
) t2 ON t1.room_id = t2.room_id
"""
)
return txn.fetchall()
return cast(List[Tuple[int, int]], txn.fetchall())

res = await self.db_pool.runInteraction("read_forward_extremities", fetch)

Expand All @@ -104,20 +107,20 @@ async def count_daily_e2ee_messages(self) -> int:
call to this function, it will return None.
"""

def _count_messages(txn):
def _count_messages(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(*) FROM events
WHERE type = 'm.room.encrypted'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count

return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)

async def count_daily_sent_e2ee_messages(self) -> int:
def _count_messages(txn):
def _count_messages(txn: LoggingTransaction) -> int:
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
like_clause = "%:" + self.hs.hostname
Expand All @@ -130,22 +133,22 @@ def _count_messages(txn):
"""

txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count

return await self.db_pool.runInteraction(
"count_daily_sent_e2ee_messages", _count_messages
)

async def count_daily_active_e2ee_rooms(self) -> int:
def _count(txn):
def _count(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
WHERE type = 'm.room.encrypted'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count

return await self.db_pool.runInteraction(
Expand All @@ -160,20 +163,20 @@ async def count_daily_messages(self) -> int:
call to this function, it will return None.
"""

def _count_messages(txn):
def _count_messages(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(*) FROM events
WHERE type = 'm.room.message'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count

return await self.db_pool.runInteraction("count_messages", _count_messages)

async def count_daily_sent_messages(self) -> int:
def _count_messages(txn):
def _count_messages(txn: LoggingTransaction) -> int:
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
like_clause = "%:" + self.hs.hostname
Expand All @@ -186,22 +189,22 @@ def _count_messages(txn):
"""

txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count

return await self.db_pool.runInteraction(
"count_daily_sent_messages", _count_messages
)

async def count_daily_active_rooms(self) -> int:
def _count(txn):
def _count(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
WHERE type = 'm.room.message'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count

return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
Expand All @@ -227,7 +230,7 @@ async def count_monthly_users(self) -> int:
"count_monthly_users", self._count_users, thirty_days_ago
)

def _count_users(self, txn: Cursor, time_from: int) -> int:
def _count_users(self, txn: LoggingTransaction, time_from: int) -> int:
"""
Returns number of users seen in the past time_from period
"""
Expand All @@ -242,7 +245,7 @@ def _count_users(self, txn: Cursor, time_from: int) -> int:
# Mypy knows that fetchone() might return None if there are no rows.
# We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
# returns exactly one row.
(count,) = txn.fetchone() # type: ignore[misc]
(count,) = cast(Tuple[int], txn.fetchone())
return count

async def count_r30_users(self) -> Dict[str, int]:
Expand All @@ -256,7 +259,7 @@ async def count_r30_users(self) -> Dict[str, int]:
A mapping of counts globally as well as broken out by platform.
"""

def _count_r30_users(txn):
def _count_r30_users(txn: LoggingTransaction) -> Dict[str, int]:
thirty_days_in_secs = 86400 * 30
now = int(self._clock.time())
thirty_days_ago_in_secs = now - thirty_days_in_secs
Expand Down Expand Up @@ -321,7 +324,7 @@ def _count_r30_users(txn):

txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))

(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
results["all"] = count

return results
Expand All @@ -348,7 +351,7 @@ async def count_r30v2_users(self) -> Dict[str, int]:
- "web" (any web application -- it's not possible to distinguish Element Web here)
"""

def _count_r30v2_users(txn):
def _count_r30v2_users(txn: LoggingTransaction) -> Dict[str, int]:
thirty_days_in_secs = 86400 * 30
now = int(self._clock.time())
sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs
Expand Down Expand Up @@ -445,11 +448,8 @@ def _count_r30v2_users(txn):
thirty_days_in_secs * 1000,
),
)
row = txn.fetchone()
if row is None:
results["all"] = 0
else:
results["all"] = row[0]
(count,) = cast(Tuple[int], txn.fetchone())
results["all"] = count

return results

Expand All @@ -471,7 +471,7 @@ async def generate_user_daily_visits(self) -> None:
Generates daily visit data for use in cohort/ retention analysis
"""

def _generate_user_daily_visits(txn):
def _generate_user_daily_visits(txn: LoggingTransaction) -> None:
logger.info("Calling _generate_user_daily_visits")
today_start = self._get_start_of_day()
a_day_in_milliseconds = 24 * 60 * 60 * 1000
Expand Down
Loading

0 comments on commit 6edefef

Please sign in to comment.