Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert DeviceLastConnectionInfo to attrs. #16507

Merged
merged 3 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16507.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints.
23 changes: 7 additions & 16 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
)
from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Set, Tuple

from synapse.api import errors
from synapse.api.constants import EduTypes, EventTypes
Expand All @@ -41,6 +31,7 @@
run_as_background_process,
wrap_as_background_process,
)
from synapse.storage.databases.main.client_ips import DeviceLastConnectionInfo
from synapse.types import (
JsonDict,
JsonMapping,
Expand Down Expand Up @@ -1008,14 +999,14 @@ async def handle_room_un_partial_stated(self, room_id: str) -> None:


def _update_device_from_client_ips(
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
device: JsonDict, client_ips: Mapping[Tuple[str, str], DeviceLastConnectionInfo]
) -> None:
ip = client_ips.get((device["user_id"], device["device_id"]), {})
ip = client_ips.get((device["user_id"], device["device_id"]))
device.update(
{
"last_seen_user_agent": ip.get("user_agent"),
"last_seen_ts": ip.get("last_seen"),
"last_seen_ip": ip.get("ip"),
"last_seen_user_agent": ip.user_agent if ip else None,
"last_seen_ts": ip.last_seen if ip else None,
"last_seen_ip": ip.ip if ip else None,
}
)

Expand Down
46 changes: 26 additions & 20 deletions synapse/storage/databases/main/client_ips.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import logging
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast

import attr
from typing_extensions import TypedDict

from synapse.metrics.background_process_metrics import wrap_as_background_process
Expand Down Expand Up @@ -42,7 +43,8 @@
LAST_SEEN_GRANULARITY = 120 * 1000


class DeviceLastConnectionInfo(TypedDict):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceLastConnectionInfo:
"""Metadata for the last connection seen for a user and device combination"""

# These types must match the columns in the `devices` table
Expand Down Expand Up @@ -499,24 +501,29 @@ async def _get_last_client_ip_by_device_from_database(
device_id: If None fetches all devices for the user

Returns:
A dictionary mapping a tuple of (user_id, device_id) to dicts, with
keys giving the column names from the devices table.
A dictionary mapping a tuple of (user_id, device_id) to DeviceLastConnectionInfo.
"""

keyvalues = {"user_id": user_id}
if device_id is not None:
keyvalues["device_id"] = device_id

res = cast(
List[DeviceLastConnectionInfo],
await self.db_pool.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
),
res = await self.db_pool.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
)

return {(d["user_id"], d["device_id"]): d for d in res}
return {
(d["user_id"], d["device_id"]): DeviceLastConnectionInfo(
user_id=d["user_id"],
device_id=d["device_id"],
ip=d["ip"],
user_agent=d["user_agent"],
last_seen=d["last_seen"],
)
for d in res
}

async def _get_user_ip_and_agents_from_database(
self, user: UserID, since_ts: int = 0
Expand Down Expand Up @@ -683,8 +690,7 @@ async def get_last_client_ip_by_device(
device_id: If None fetches all devices for the user

Returns:
A dictionary mapping a tuple of (user_id, device_id) to dicts, with
keys giving the column names from the devices table.
A dictionary mapping a tuple of (user_id, device_id) to DeviceLastConnectionInfo.
"""
ret = await self._get_last_client_ip_by_device_from_database(user_id, device_id)

Expand All @@ -705,13 +711,13 @@ async def get_last_client_ip_by_device(
continue

if not device_id or did == device_id:
ret[(user_id, did)] = {
"user_id": user_id,
"ip": ip,
"user_agent": user_agent,
"device_id": did,
"last_seen": last_seen,
}
ret[(user_id, did)] = DeviceLastConnectionInfo(
user_id=user_id,
ip=ip,
user_agent=user_agent,
device_id=did,
last_seen=last_seen,
)
return ret

async def get_user_ip_and_agents(
Expand Down
137 changes: 70 additions & 67 deletions tests/storage/test_client_ips.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
from synapse.http.site import XForwardedForRequest
from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.storage.databases.main.client_ips import (
LAST_SEEN_GRANULARITY,
DeviceLastConnectionInfo,
)
from synapse.types import UserID
from synapse.util import Clock

Expand Down Expand Up @@ -65,15 +68,15 @@ def test_insert_new_client_ip(self) -> None:
)

r = result[(user_id, device_id)]
self.assertLessEqual(
{
"user_id": user_id,
"device_id": device_id,
"ip": "ip",
"user_agent": "user_agent",
"last_seen": 12345678000,
}.items(),
r.items(),
self.assertEqual(
DeviceLastConnectionInfo(
user_id=user_id,
device_id=device_id,
ip="ip",
user_agent="user_agent",
last_seen=12345678000,
),
r,
)

def test_insert_new_client_ip_none_device_id(self) -> None:
Expand Down Expand Up @@ -201,13 +204,13 @@ def test_get_last_client_ip_by_device(self, after_persisting: bool) -> None:
self.assertEqual(
result,
{
(user_id, device_id): {
"user_id": user_id,
"device_id": device_id,
"ip": "ip",
"user_agent": "user_agent",
"last_seen": 12345678000,
},
(user_id, device_id): DeviceLastConnectionInfo(
user_id=user_id,
device_id=device_id,
ip="ip",
user_agent="user_agent",
last_seen=12345678000,
),
},
)

Expand Down Expand Up @@ -292,20 +295,20 @@ def test_get_last_client_ip_by_device_combined_data(self) -> None:
self.assertEqual(
result,
{
(user_id, device_id_1): {
"user_id": user_id,
"device_id": device_id_1,
"ip": "ip_1",
"user_agent": "user_agent_1",
"last_seen": 12345678000,
},
(user_id, device_id_2): {
"user_id": user_id,
"device_id": device_id_2,
"ip": "ip_2",
"user_agent": "user_agent_3",
"last_seen": 12345688000 + LAST_SEEN_GRANULARITY,
},
(user_id, device_id_1): DeviceLastConnectionInfo(
user_id=user_id,
device_id=device_id_1,
ip="ip_1",
user_agent="user_agent_1",
last_seen=12345678000,
),
(user_id, device_id_2): DeviceLastConnectionInfo(
user_id=user_id,
device_id=device_id_2,
ip="ip_2",
user_agent="user_agent_3",
last_seen=12345688000 + LAST_SEEN_GRANULARITY,
),
},
)

Expand Down Expand Up @@ -526,15 +529,15 @@ def test_devices_last_seen_bg_update(self) -> None:
)

r = result[(user_id, device_id)]
self.assertLessEqual(
{
"user_id": user_id,
"device_id": device_id,
"ip": None,
"user_agent": None,
"last_seen": None,
}.items(),
r.items(),
self.assertEqual(
DeviceLastConnectionInfo(
user_id=user_id,
device_id=device_id,
ip=None,
user_agent=None,
last_seen=None,
),
r,
)

# Register the background update to run again.
Expand All @@ -561,15 +564,15 @@ def test_devices_last_seen_bg_update(self) -> None:
)

r = result[(user_id, device_id)]
self.assertLessEqual(
{
"user_id": user_id,
"device_id": device_id,
"ip": "ip",
"user_agent": "user_agent",
"last_seen": 0,
}.items(),
r.items(),
self.assertEqual(
DeviceLastConnectionInfo(
user_id=user_id,
device_id=device_id,
ip="ip",
user_agent="user_agent",
last_seen=0,
),
r,
)

def test_old_user_ips_pruned(self) -> None:
Expand Down Expand Up @@ -640,15 +643,15 @@ def test_old_user_ips_pruned(self) -> None:
)

r = result2[(user_id, device_id)]
self.assertLessEqual(
{
"user_id": user_id,
"device_id": device_id,
"ip": "ip",
"user_agent": "user_agent",
"last_seen": 0,
}.items(),
r.items(),
self.assertEqual(
DeviceLastConnectionInfo(
user_id=user_id,
device_id=device_id,
ip="ip",
user_agent="user_agent",
last_seen=0,
),
r,
)

def test_invalid_user_agents_are_ignored(self) -> None:
Expand Down Expand Up @@ -777,13 +780,13 @@ def _runtest(
self.store.get_last_client_ip_by_device(self.user_id, device_id)
)
r = result[(self.user_id, device_id)]
self.assertLessEqual(
{
"user_id": self.user_id,
"device_id": device_id,
"ip": expected_ip,
"user_agent": "Mozzila pizza",
"last_seen": 123456100,
}.items(),
r.items(),
self.assertEqual(
DeviceLastConnectionInfo(
user_id=self.user_id,
device_id=device_id,
ip=expected_ip,
user_agent="Mozzila pizza",
last_seen=123456100,
),
r,
)
Loading