Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add some type hints to datastore #12717

Merged
merged 16 commits into from
May 17, 2022
1 change: 1 addition & 0 deletions changelog.d/12717.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add some type hints to datastore.
2 changes: 0 additions & 2 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ exclude = (?x)
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/event_federation.py
|synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/roommember.py
|synapse/storage/schema/

|tests/api/test_auth.py
Expand Down
67 changes: 34 additions & 33 deletions synapse/storage/databases/main/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@
import calendar
import logging
import time
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, Dict, List, Tuple, cast

from synapse.metrics import GaugeBucketCollector
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
Expand Down Expand Up @@ -71,8 +75,8 @@ def __init__(
self._last_user_visit_update = self._get_start_of_day()

@wrap_as_background_process("read_forward_extremities")
async def _read_forward_extremities(self):
def fetch(txn):
async def _read_forward_extremities(self) -> None:
def fetch(txn: LoggingTransaction) -> List[Tuple[int, int]]:
txn.execute(
"""
SELECT t1.c, t2.c
Expand All @@ -85,7 +89,7 @@ def fetch(txn):
) t2 ON t1.room_id = t2.room_id
"""
)
return txn.fetchall()
return cast(List[Tuple[int, int]], txn.fetchall())

res = await self.db_pool.runInteraction("read_forward_extremities", fetch)

Expand All @@ -95,28 +99,28 @@ def fetch(txn):
(x[0] - 1) * x[1] for x in res if x[1]
)

async def count_daily_e2ee_messages(self):
async def count_daily_e2ee_messages(self) -> int:
"""
Returns an estimate of the number of messages sent in the last day.

If it has been significantly less or more than one day since the last
call to this function, it will return None.
"""

def _count_messages(txn):
def _count_messages(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(*) FROM events
WHERE type = 'm.room.encrypted'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count

return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)

async def count_daily_sent_e2ee_messages(self):
def _count_messages(txn):
async def count_daily_sent_e2ee_messages(self) -> int:
def _count_messages(txn: LoggingTransaction) -> int:
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
like_clause = "%:" + self.hs.hostname
Expand All @@ -136,43 +140,43 @@ def _count_messages(txn):
"count_daily_sent_e2ee_messages", _count_messages
)

async def count_daily_active_e2ee_rooms(self):
def _count(txn):
async def count_daily_active_e2ee_rooms(self) -> int:
def _count(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
WHERE type = 'm.room.encrypted'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count

return await self.db_pool.runInteraction(
"count_daily_active_e2ee_rooms", _count
)

async def count_daily_messages(self):
async def count_daily_messages(self) -> int:
"""
Returns an estimate of the number of messages sent in the last day.

If it has been significantly less or more than one day since the last
call to this function, it will return None.
"""

def _count_messages(txn):
def _count_messages(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(*) FROM events
WHERE type = 'm.room.message'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count

return await self.db_pool.runInteraction("count_messages", _count_messages)

async def count_daily_sent_messages(self):
def _count_messages(txn):
async def count_daily_sent_messages(self) -> int:
def _count_messages(txn: LoggingTransaction) -> int:
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
like_clause = "%:" + self.hs.hostname
Expand All @@ -192,15 +196,15 @@ def _count_messages(txn):
"count_daily_sent_messages", _count_messages
)

async def count_daily_active_rooms(self):
def _count(txn):
async def count_daily_active_rooms(self) -> int:
def _count(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
WHERE type = 'm.room.message'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count

return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
Expand All @@ -226,7 +230,7 @@ async def count_monthly_users(self) -> int:
"count_monthly_users", self._count_users, thirty_days_ago
)

def _count_users(self, txn, time_from):
def _count_users(self, txn: LoggingTransaction, time_from: int) -> int:
"""
Returns number of users seen in the past time_from period
"""
Expand All @@ -238,7 +242,7 @@ def _count_users(self, txn, time_from):
) u
"""
txn.execute(sql, (time_from,))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count

async def count_r30_users(self) -> Dict[str, int]:
Expand All @@ -252,7 +256,7 @@ async def count_r30_users(self) -> Dict[str, int]:
A mapping of counts globally as well as broken out by platform.
"""

def _count_r30_users(txn):
def _count_r30_users(txn: LoggingTransaction) -> Dict[str, int]:
thirty_days_in_secs = 86400 * 30
now = int(self._clock.time())
thirty_days_ago_in_secs = now - thirty_days_in_secs
Expand Down Expand Up @@ -317,7 +321,7 @@ def _count_r30_users(txn):

txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))

(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
results["all"] = count

return results
Expand All @@ -344,7 +348,7 @@ async def count_r30v2_users(self) -> Dict[str, int]:
- "web" (any web application -- it's not possible to distinguish Element Web here)
"""

def _count_r30v2_users(txn):
def _count_r30v2_users(txn: LoggingTransaction) -> Dict[str, int]:
thirty_days_in_secs = 86400 * 30
now = int(self._clock.time())
sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs
Expand Down Expand Up @@ -441,19 +445,16 @@ def _count_r30v2_users(txn):
thirty_days_in_secs * 1000,
),
)
row = txn.fetchone()
if row is None:
results["all"] = 0
else:
results["all"] = row[0]
(count,) = cast(Tuple[int], txn.fetchone())
results["all"] = count
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved

return results

return await self.db_pool.runInteraction(
"count_r30v2_users", _count_r30v2_users
)

def _get_start_of_day(self):
def _get_start_of_day(self) -> int:
"""
Returns millisecond unixtime for start of UTC day.
"""
Expand All @@ -467,7 +468,7 @@ async def generate_user_daily_visits(self) -> None:
Generates daily visit data for use in cohort/ retention analysis
"""

def _generate_user_daily_visits(txn):
def _generate_user_daily_visits(txn: LoggingTransaction) -> None:
logger.info("Calling _generate_user_daily_visits")
today_start = self._get_start_of_day()
a_day_in_milliseconds = 24 * 60 * 60 * 1000
Expand Down
Loading