Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Remove more usages of cursor_to_dict. (#16551)
Browse files Browse the repository at this point in the history
Mostly to improve type safety.
  • Loading branch information
clokep authored Oct 26, 2023
1 parent 85e5f2d commit 679c691
Show file tree
Hide file tree
Showing 26 changed files with 193 additions and 134 deletions.
1 change: 1 addition & 0 deletions changelog.d/16551.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints.
18 changes: 9 additions & 9 deletions synapse/handlers/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import urllib.parse
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple

import attr

from synapse.api.errors import (
CodeMessageException,
Codes,
Expand Down Expand Up @@ -357,9 +359,9 @@ async def send_threepid_validation(

# Check to see if a session already exists and that it is not yet
# marked as validated
if session and session.get("validated_at") is None:
session_id = session["session_id"]
last_send_attempt = session["last_send_attempt"]
if session and session.validated_at is None:
session_id = session.session_id
last_send_attempt = session.last_send_attempt

# Check that the send_attempt is higher than previous attempts
if send_attempt <= last_send_attempt:
Expand Down Expand Up @@ -480,27 +482,25 @@ async def validate_threepid_session(

# We don't actually know which medium this 3PID is. Thus we first assume it's email,
# and if validation fails we try msisdn
validation_session = None

# Try to validate as email
if self.hs.config.email.can_verify_email:
# Get a validated session matching these details
validation_session = await self.store.get_threepid_validation_session(
"email", client_secret, sid=sid, validated=True
)

if validation_session:
return validation_session
if validation_session:
return attr.asdict(validation_session)

# Try to validate as msisdn
if self.hs.config.registration.account_threepid_delegate_msisdn:
# Ask our delegated msisdn identity server
validation_session = await self.threepid_from_creds(
return await self.threepid_from_creds(
self.hs.config.registration.account_threepid_delegate_msisdn,
threepid_creds,
)

return validation_session
return None

async def proxy_msisdn_submit_token(
self, id_server: str, client_secret: str, sid: str, token: str
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/ui_auth/checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,9 @@ async def _check_threepid(self, medium: str, authdict: dict) -> dict:

if row:
threepid = {
"medium": row["medium"],
"address": row["address"],
"validated_at": row["validated_at"],
"medium": row.medium,
"address": row.address,
"validated_at": row.validated_at,
}

# Valid threepid returned, delete from the db
Expand Down
5 changes: 1 addition & 4 deletions synapse/media/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,10 +949,7 @@ async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]:

deleted = 0

for media in old_media:
origin = media["media_origin"]
media_id = media["media_id"]
file_id = media["filesystem_id"]
for origin, media_id, file_id in old_media:
key = (origin, media_id)

logger.info("Deleting: %r", key)
Expand Down
14 changes: 13 additions & 1 deletion synapse/rest/admin/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,19 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
destinations, total = await self._store.get_destinations_paginate(
start, limit, destination, order_by, direction
)
response = {"destinations": destinations, "total": total}
response = {
"destinations": [
{
"destination": r[0],
"retry_last_ts": r[1],
"retry_interval": r[2],
"failure_ts": r[3],
"last_successful_stream_ordering": r[4],
}
for r in destinations
],
"total": total,
}
if (start + limit) < total:
response["next_token"] = str(start + len(destinations))

Expand Down
12 changes: 11 additions & 1 deletion synapse/rest/admin/rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,17 @@ async def on_GET(
room_id, _ = await self.resolve_room_id(room_identifier)

extremities = await self.store.get_forward_extremities_for_room(room_id)
return HTTPStatus.OK, {"count": len(extremities), "results": extremities}
result = [
{
"event_id": ex[0],
"state_group": ex[1],
"depth": ex[2],
"received_ts": ex[3],
}
for ex in extremities
]

return HTTPStatus.OK, {"count": len(extremities), "results": result}


class RoomEventContextServlet(RestServlet):
Expand Down
13 changes: 12 additions & 1 deletion synapse/rest/admin/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,18 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
users_media, total = await self.store.get_users_media_usage_paginate(
start, limit, from_ts, until_ts, order_by, direction, search_term
)
ret = {"users": users_media, "total": total}
ret = {
"users": [
{
"user_id": r[0],
"displayname": r[1],
"media_count": r[2],
"media_length": r[3],
}
for r in users_media
],
"total": total,
}
if (start + limit) < total:
ret["next_token"] = start + len(users_media)

Expand Down
30 changes: 3 additions & 27 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
Tuple,
Type,
TypeVar,
Union,
cast,
overload,
)
Expand Down Expand Up @@ -1047,43 +1046,20 @@ def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
results = [dict(zip(col_headers, row)) for row in cursor]
return results

@overload
async def execute(
self, desc: str, decoder: Literal[None], query: str, *args: Any
) -> List[Tuple[Any, ...]]:
...

@overload
async def execute(
self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
) -> R:
...

async def execute(
self,
desc: str,
decoder: Optional[Callable[[Cursor], R]],
query: str,
*args: Any,
) -> Union[List[Tuple[Any, ...]], R]:
async def execute(self, desc: str, query: str, *args: Any) -> List[Tuple[Any, ...]]:
"""Runs a single query for a result set.
Args:
desc: description of the transaction, for logging and metrics
decoder - The function which can resolve the cursor results to
something meaningful.
query - The query string to execute
*args - Query args.
Returns:
The result of decoder(results)
"""

def interaction(txn: LoggingTransaction) -> Union[List[Tuple[Any, ...]], R]:
def interaction(txn: LoggingTransaction) -> List[Tuple[Any, ...]]:
txn.execute(query, args)
if decoder:
return decoder(txn)
else:
return txn.fetchall()
return txn.fetchall()

return await self.runInteraction(desc, interaction)

Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/censor_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ async def _censor_redactions(self) -> None:
"""

rows = await self.db_pool.execute(
"_censor_redactions_fetch", None, sql, before_ts, 100
"_censor_redactions_fetch", sql, before_ts, 100
)

updates = []
Expand Down
3 changes: 1 addition & 2 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,7 +894,6 @@ async def get_all_devices_changed(

rows = await self.db_pool.execute(
"get_all_devices_changed",
None,
sql,
from_key,
to_key,
Expand Down Expand Up @@ -978,7 +977,7 @@ async def get_users_whose_signatures_changed(
WHERE from_user_id = ? AND stream_id > ?
"""
rows = await self.db_pool.execute(
"get_users_whose_signatures_changed", None, sql, user_id, from_key
"get_users_whose_signatures_changed", sql, user_id, from_key
)
return {user for row in rows for user in db_to_json(row[0])}
else:
Expand Down
1 change: 0 additions & 1 deletion synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ async def get_e2e_device_keys_for_federation_query(
"""
rows = await self.db_pool.execute(
"get_e2e_device_keys_for_federation_query_check",
None,
sql,
now_stream_id,
user_id,
Expand Down
7 changes: 2 additions & 5 deletions synapse/storage/databases/main/events_bg_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,12 +1310,9 @@ def process(txn: Cursor) -> None:

# ANALYZE the new column to build stats on it, to encourage PostgreSQL to use the
# indexes on it.
# We need to pass execute a dummy function to handle the txn's result otherwise
# it tries to call fetchall() on it and fails because there's no result to fetch.
await self.db_pool.execute(
await self.db_pool.runInteraction(
"background_analyze_new_stream_ordering_column",
lambda txn: None,
"ANALYZE events(stream_ordering2)",
lambda txn: txn.execute("ANALYZE events(stream_ordering2)"),
)

await self.db_pool.runInteraction(
Expand Down
15 changes: 10 additions & 5 deletions synapse/storage/databases/main/events_forward_extremities.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import Any, Dict, List
from typing import List, Optional, Tuple, cast

from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
Expand Down Expand Up @@ -91,12 +91,17 @@ def delete_forward_extremities_for_room_txn(txn: LoggingTransaction) -> int:

async def get_forward_extremities_for_room(
self, room_id: str
) -> List[Dict[str, Any]]:
"""Get list of forward extremities for a room."""
) -> List[Tuple[str, int, int, Optional[int]]]:
"""
Get list of forward extremities for a room.
Returns:
A list of tuples of event_id, state_group, depth, and received_ts.
"""

def get_forward_extremities_for_room_txn(
txn: LoggingTransaction,
) -> List[Dict[str, Any]]:
) -> List[Tuple[str, int, int, Optional[int]]]:
sql = """
SELECT event_id, state_group, depth, received_ts
FROM event_forward_extremities
Expand All @@ -106,7 +111,7 @@ def get_forward_extremities_for_room_txn(
"""

txn.execute(sql, (room_id,))
return self.db_pool.cursor_to_dict(txn)
return cast(List[Tuple[str, int, int, Optional[int]]], txn.fetchall())

return await self.db_pool.runInteraction(
"get_forward_extremities_for_room",
Expand Down
19 changes: 11 additions & 8 deletions synapse/storage/databases/main/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ async def store_remote_media_thumbnail(

async def get_remote_media_ids(
self, before_ts: int, include_quarantined_media: bool
) -> List[Dict[str, str]]:
) -> List[Tuple[str, str, str]]:
"""
Retrieve a list of server name, media ID tuples from the remote media cache.
Expand All @@ -664,21 +664,24 @@ async def get_remote_media_ids(
A list of tuples containing:
* The server name of homeserver where the media originates from,
* The ID of the media.
* The filesystem ID.
"""

sql = """
SELECT media_origin, media_id, filesystem_id
FROM remote_media_cache
WHERE last_access_ts < ?
"""
sql = (
"SELECT media_origin, media_id, filesystem_id"
" FROM remote_media_cache"
" WHERE last_access_ts < ?"
)

if include_quarantined_media is False:
# Only include media that has not been quarantined
sql += """
AND quarantined_by IS NULL
"""

return await self.db_pool.execute(
"get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts
return cast(
List[Tuple[str, str, str]],
await self.db_pool.execute("get_remote_media_ids", sql, before_ts),
)

async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
Expand Down
Loading

0 comments on commit 679c691

Please sign in to comment.