Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Convert simple_select_one_txn and simple_select_one to return tuples. (
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored Nov 9, 2023
1 parent ff716b4 commit ab3f1b3
Show file tree
Hide file tree
Showing 33 changed files with 283 additions and 279 deletions.
1 change: 1 addition & 0 deletions changelog.d/16612.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints.
3 changes: 1 addition & 2 deletions synapse/_scripts/synapse_port_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,7 @@ async def setup_table(self, table: str) -> Tuple[str, int, int, int, int]:
backward_chunk = 0
already_ported = 0
else:
forward_chunk = row["forward_rowid"]
backward_chunk = row["backward_rowid"]
forward_chunk, backward_chunk = row

if total_to_port is None:
already_ported, total_to_port = await self._get_total_count_to_port(
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ async def _upgrade_room(
self,
requester: Requester,
old_room_id: str,
old_room: Dict[str, Any],
old_room: Tuple[bool, str, bool],
new_room_id: str,
new_version: RoomVersion,
tombstone_event: EventBase,
Expand All @@ -279,7 +279,7 @@ async def _upgrade_room(
Args:
requester: the user requesting the upgrade
old_room_id: the id of the room to be replaced
old_room: a dict containing room information for the room to be replaced,
old_room: a tuple containing room information for the room to be replaced,
as returned by `RoomWorkerStore.get_room`.
new_room_id: the id of the replacement room
new_version: the version to upgrade the room to
Expand All @@ -299,7 +299,7 @@ async def _upgrade_room(
await self.store.store_room(
room_id=new_room_id,
room_creator_user_id=user_id,
is_public=old_room["is_public"],
is_public=old_room[0],
room_version=new_version,
)

Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/room_member.py
Original file line number Diff line number Diff line change
Expand Up @@ -1260,7 +1260,8 @@ async def transfer_room_state_on_room_upgrade(
# Add new room to the room directory if the old room was there
# Remove old room from the room directory
old_room = await self.store.get_room(old_room_id)
if old_room is not None and old_room["is_public"]:
# If the old room exists and is public.
if old_room is not None and old_room[0]:
await self.store.set_room_is_public(old_room_id, False)
await self.store.set_room_is_public(room_id, True)

Expand Down
3 changes: 2 additions & 1 deletion synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1860,7 +1860,8 @@ async def room_is_in_public_room_list(self, room_id: str) -> bool:
if not room:
return False

return room.get("is_public", False)
# The first item is whether the room is public.
return room[0]

async def add_room_to_public_room_list(self, room_id: str) -> None:
"""Publishes a room to the public room list.
Expand Down
8 changes: 4 additions & 4 deletions synapse/rest/admin/rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,8 @@ async def on_GET(
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)

ret = await self.store.get_room(room_id)
if not ret:
room = await self.store.get_room(room_id)
if not room:
raise NotFoundError("Room not found")

members = await self.store.get_users_in_room(room_id)
Expand Down Expand Up @@ -442,8 +442,8 @@ async def on_GET(
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)

ret = await self.store.get_room(room_id)
if not ret:
room = await self.store.get_room(room_id)
if not room:
raise NotFoundError("Room not found")

event_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
Expand Down
2 changes: 1 addition & 1 deletion synapse/rest/client/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
if room is None:
raise NotFoundError("Unknown room")

return 200, {"visibility": "public" if room["is_public"] else "private"}
return 200, {"visibility": "public" if room[0] else "private"}

class PutBody(RequestBodyModel):
visibility: Literal["public", "private"] = "public"
Expand Down
10 changes: 5 additions & 5 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -1597,7 +1597,7 @@ async def simple_select_one(
retcols: Collection[str],
allow_none: Literal[False] = False,
desc: str = "simple_select_one",
) -> Dict[str, Any]:
) -> Tuple[Any, ...]:
...

@overload
Expand All @@ -1608,7 +1608,7 @@ async def simple_select_one(
retcols: Collection[str],
allow_none: Literal[True] = True,
desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]:
) -> Optional[Tuple[Any, ...]]:
...

async def simple_select_one(
Expand All @@ -1618,7 +1618,7 @@ async def simple_select_one(
retcols: Collection[str],
allow_none: bool = False,
desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]:
) -> Optional[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it.
Expand Down Expand Up @@ -2127,7 +2127,7 @@ def simple_select_one_txn(
keyvalues: Dict[str, Any],
retcols: Collection[str],
allow_none: bool = False,
) -> Optional[Dict[str, Any]]:
) -> Optional[Tuple[Any, ...]]:
select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table)

if keyvalues:
Expand All @@ -2145,7 +2145,7 @@ def simple_select_one_txn(
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,))

return dict(zip(retcols, row))
return row

async def simple_delete_one(
self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
Expand Down
43 changes: 13 additions & 30 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,33 +255,16 @@ async def get_device(
A dict containing the device information, or `None` if the device does not
exist.
"""
return await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
allow_none=True,
)

async def get_device_opt(
self, user_id: str, device_id: str
) -> Optional[Dict[str, Any]]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
Args:
user_id: The ID of the user which owns the device
device_id: The ID of the device to retrieve
Returns:
A dict containing the device information, or None if the device does not exist.
"""
return await self.db_pool.simple_select_one(
row = await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
allow_none=True,
)
if row is None:
return None
return {"user_id": row[0], "device_id": row[1], "display_name": row[2]}

async def get_devices_by_user(
self, user_id: str
Expand Down Expand Up @@ -1221,9 +1204,7 @@ async def get_dehydrated_device(
retcols=["device_id", "device_data"],
allow_none=True,
)
return (
(row["device_id"], json_decoder.decode(row["device_data"])) if row else None
)
return (row[0], json_decoder.decode(row[1])) if row else None

def _store_dehydrated_device_txn(
self,
Expand Down Expand Up @@ -2326,13 +2307,15 @@ async def get_device_change_last_converted_pos(self) -> Tuple[int, str]:
`FALSE` have not been converted.
"""

row = await self.db_pool.simple_select_one(
table="device_lists_changes_converted_stream_position",
keyvalues={},
retcols=["stream_id", "room_id"],
desc="get_device_change_last_converted_pos",
return cast(
Tuple[int, str],
await self.db_pool.simple_select_one(
table="device_lists_changes_converted_stream_position",
keyvalues={},
retcols=["stream_id", "room_id"],
desc="get_device_change_last_converted_pos",
),
)
return row["stream_id"], row["room_id"]

async def set_device_change_last_converted_pos(
self,
Expand Down
31 changes: 19 additions & 12 deletions synapse/storage/databases/main/e2e_room_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,19 +506,26 @@ def _get_e2e_room_keys_version_info_txn(txn: LoggingTransaction) -> JsonDict:
# it isn't there.
raise StoreError(404, "No backup with that version exists")

result = self.db_pool.simple_select_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
retcols=("version", "algorithm", "auth_data", "etag"),
allow_none=False,
row = cast(
Tuple[int, str, str, Optional[int]],
self.db_pool.simple_select_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={
"user_id": user_id,
"version": this_version,
"deleted": 0,
},
retcols=("version", "algorithm", "auth_data", "etag"),
allow_none=False,
),
)
assert result is not None # see comment on `simple_select_one_txn`
result["auth_data"] = db_to_json(result["auth_data"])
result["version"] = str(result["version"])
if result["etag"] is None:
result["etag"] = 0
return result
return {
"auth_data": db_to_json(row[2]),
"version": str(row[0]),
"algorithm": row[1],
"etag": 0 if row[3] is None else row[3],
}

return await self.db_pool.runInteraction(
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
Expand Down
4 changes: 1 addition & 3 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -1266,9 +1266,7 @@ async def _claim_e2e_fallback_keys_simple(
if row is None:
continue

key_id = row["key_id"]
key_json = row["key_json"]
used = row["used"]
key_id, key_json, used = row

# Mark fallback key as used if not already.
if not used and mark_as_used:
Expand Down
24 changes: 10 additions & 14 deletions synapse/storage/databases/main/event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ async def get_auth_chain_ids(
# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id) # type: ignore[attr-defined]
if room["has_auth_chain_index"]:
# If the room has an auth chain index.
if room[1]:
try:
return await self.db_pool.runInteraction(
"get_auth_chain_ids_chains",
Expand Down Expand Up @@ -411,7 +412,8 @@ async def get_auth_chain_difference(
# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id) # type: ignore[attr-defined]
if room["has_auth_chain_index"]:
# If the room has an auth chain index.
if room[1]:
try:
return await self.db_pool.runInteraction(
"get_auth_chain_difference_chains",
Expand Down Expand Up @@ -1437,24 +1439,18 @@ def _get_backfill_events(
)

if event_lookup_result is not None:
event_type, depth, stream_ordering = event_lookup_result
logger.debug(
"_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s",
room_id,
seed_event_id,
event_lookup_result["depth"],
event_lookup_result["stream_ordering"],
event_lookup_result["type"],
depth,
stream_ordering,
event_type,
)

if event_lookup_result["depth"]:
queue.put(
(
-event_lookup_result["depth"],
-event_lookup_result["stream_ordering"],
seed_event_id,
event_lookup_result["type"],
)
)
if depth:
queue.put((-depth, -stream_ordering, seed_event_id, event_type))

while not queue.empty() and len(event_id_results) < limit:
try:
Expand Down
3 changes: 1 addition & 2 deletions synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1934,8 +1934,7 @@ def _handle_redact_relations(
if row is None:
return

redacted_relates_to = row["relates_to_id"]
rel_type = row["relation_type"]
redacted_relates_to, rel_type = row
self.db_pool.simple_delete_txn(
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
)
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1998,7 +1998,7 @@ async def get_event_ordering(self, event_id: str) -> Tuple[int, int]:
if not res:
raise SynapseError(404, "Could not find event %s" % (event_id,))

return int(res["topological_ordering"]), int(res["stream_ordering"])
return int(res[0]), int(res[1])

async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry
Expand Down
30 changes: 23 additions & 7 deletions synapse/storage/databases/main/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,17 @@ async def get_local_media(self, media_id: str) -> Optional[LocalMedia]:
)
if row is None:
return None
return LocalMedia(media_id=media_id, **row)
return LocalMedia(
media_id=media_id,
media_type=row[0],
media_length=row[1],
upload_name=row[2],
created_ts=row[3],
quarantined_by=row[4],
url_cache=row[5],
last_access_ts=row[6],
safe_from_quarantine=row[7],
)

async def get_local_media_by_user_paginate(
self,
Expand Down Expand Up @@ -541,7 +551,17 @@ async def get_cached_remote_media(
)
if row is None:
return row
return RemoteMedia(media_origin=origin, media_id=media_id, **row)
return RemoteMedia(
media_origin=origin,
media_id=media_id,
media_type=row[0],
media_length=row[1],
upload_name=row[2],
created_ts=row[3],
filesystem_id=row[4],
last_access_ts=row[5],
quarantined_by=row[6],
)

async def store_cached_remote_media(
self,
Expand Down Expand Up @@ -665,11 +685,7 @@ async def get_remote_media_thumbnail(
if row is None:
return None
return ThumbnailInfo(
width=row["thumbnail_width"],
height=row["thumbnail_height"],
method=row["thumbnail_method"],
type=row["thumbnail_type"],
length=row["thumbnail_length"],
width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
)

@trace
Expand Down
Loading

0 comments on commit ab3f1b3

Please sign in to comment.