Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Delete device messages asynchronously and in staged batches #16240

Merged
merged 13 commits into from
Sep 6, 2023
1 change: 1 addition & 0 deletions changelog.d/16240.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Delete device messages asynchronously and in staged batches using the task scheduler.
48 changes: 48 additions & 0 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,12 @@
)
from synapse.types import (
JsonDict,
JsonMapping,
ScheduledTask,
StrCollection,
StreamKeyType,
StreamToken,
TaskStatus,
UserID,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
Expand All @@ -62,6 +65,7 @@

logger = logging.getLogger(__name__)

DELETE_DEVICE_MSGS_TASK_NAME = "delete_device_messages"
MAX_DEVICE_DISPLAY_NAME_LEN = 100
DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000

Expand All @@ -78,6 +82,7 @@ def __init__(self, hs: "HomeServer"):
self._appservice_handler = hs.get_application_service_handler()
self._state_storage = hs.get_storage_controllers().state
self._auth_handler = hs.get_auth_handler()
self._event_sources = hs.get_event_sources()
self.server_name = hs.hostname
self._msc3852_enabled = hs.config.experimental.msc3852_enabled
self._query_appservices_for_keys = (
Expand Down Expand Up @@ -386,6 +391,7 @@ def __init__(self, hs: "HomeServer"):
self._account_data_handler = hs.get_account_data_handler()
self._storage_controllers = hs.get_storage_controllers()
self.db_pool = hs.get_datastores().main.db_pool
self._task_scheduler = hs.get_task_scheduler()

self.device_list_updater = DeviceListUpdater(hs, self)

Expand Down Expand Up @@ -419,6 +425,10 @@ def __init__(self, hs: "HomeServer"):
self._delete_stale_devices,
)

self._task_scheduler.register_action(
self._delete_device_messages, DELETE_DEVICE_MSGS_TASK_NAME
)

def _check_device_name_length(self, name: Optional[str]) -> None:
"""
Checks whether a device name is longer than the maximum allowed length.
Expand Down Expand Up @@ -530,6 +540,7 @@ async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
user_id: The user to delete devices from.
device_ids: The list of device IDs to delete
"""
to_device_stream_id = self._event_sources.get_current_token().to_device_key

try:
await self.store.delete_devices(user_id, device_ids)
Expand Down Expand Up @@ -559,12 +570,49 @@ async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
f"org.matrix.msc3890.local_notification_settings.{device_id}",
)

# Delete device messages asynchronously and in batches using the task scheduler
await self._task_scheduler.schedule_task(
DELETE_DEVICE_MSGS_TASK_NAME,
resource_id=device_id,
params={
"user_id": user_id,
"device_id": device_id,
"up_to_stream_id": to_device_stream_id,
},
)

# Pushers are deleted after `delete_access_tokens_for_user` is called so that
# modules using `on_logged_out` hook can use them if needed.
await self.hs.get_pusherpool().remove_pushers_by_devices(user_id, device_ids)

await self.notify_device_update(user_id, device_ids)

DEVICE_MSGS_DELETE_BATCH_LIMIT = 100

async def _delete_device_messages(
self,
task: ScheduledTask,
) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
"""Scheduler task to delete device messages in batch of `DEVICE_MSGS_DELETE_BATCH_LIMIT`."""
assert task.params is not None
user_id = task.params["user_id"]
device_id = task.params["device_id"]
up_to_stream_id = task.params["up_to_stream_id"]

res = await self.store.delete_messages_for_device(
user_id=user_id,
device_id=device_id,
up_to_stream_id=up_to_stream_id,
limit=DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT,
)

if res < DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT:
return TaskStatus.COMPLETE, None, None
else:
# There is probably still device messages to be deleted, let's keep the task active and it will be run
# again in a subsequent scheduler loop run (probably the next one, if not too many tasks are running).
return TaskStatus.ACTIVE, None, None

async def update_device(self, user_id: str, device_id: str, content: dict) -> None:
"""Update the given device

Expand Down
4 changes: 1 addition & 3 deletions synapse/handlers/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class BasePresenceHandler(abc.ABC):
writer"""

def __init__(self, hs: "HomeServer"):
self.hs = hs
MatMaul marked this conversation as resolved.
Show resolved Hide resolved
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
Expand Down Expand Up @@ -426,8 +427,6 @@ def __exit__(
class WorkerPresenceHandler(BasePresenceHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs

self._presence_writer_instance = hs.config.worker.writers.presence[0]

# Route presence EDUs to the right worker
Expand Down Expand Up @@ -691,7 +690,6 @@ async def bump_presence_active_time(
class PresenceHandler(BasePresenceHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.wheel_timer: WheelTimer[str] = WheelTimer()
self.notifier = hs.get_notifier()

Expand Down
16 changes: 13 additions & 3 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.handlers.device import DELETE_DEVICE_MSGS_TASK_NAME
from synapse.handlers.relations import BundledAggregations
from synapse.logging import issue9533_logger
from synapse.logging.context import current_context
Expand Down Expand Up @@ -268,6 +269,7 @@ def __init__(self, hs: "HomeServer"):
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
self._device_handler = hs.get_device_handler()
self._task_scheduler = hs.get_task_scheduler()

self.should_calculate_push_rules = hs.config.push.enable_push

Expand Down Expand Up @@ -360,11 +362,19 @@ async def _wait_for_sync_for_user(
# (since we now know that the device has received them)
if since_token is not None:
since_stream_id = since_token.to_device_key
deleted = await self.store.delete_messages_for_device(
sync_config.user.to_string(), sync_config.device_id, since_stream_id
# Delete device messages asynchronously and in batches using the task scheduler
await self._task_scheduler.schedule_task(
DELETE_DEVICE_MSGS_TASK_NAME,
resource_id=sync_config.device_id,
params={
"user_id": sync_config.user.to_string(),
"device_id": sync_config.device_id,
"up_to_stream_id": since_stream_id,
},
)
logger.debug(
"Deleted %d to-device messages up to %d", deleted, since_stream_id
"Deletion of to-device messages up to %d scheduled",
since_stream_id,
)

if timeout == 0 or since_token is None or full_state:
Expand Down
26 changes: 20 additions & 6 deletions synapse/storage/databases/main/deviceinbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,13 +445,18 @@ def get_device_messages_txn(

@trace
async def delete_messages_for_device(
self, user_id: str, device_id: Optional[str], up_to_stream_id: int
self,
user_id: str,
device_id: Optional[str],
up_to_stream_id: int,
limit: int,
) -> int:
"""
Args:
user_id: The recipient user_id.
device_id: The recipient device_id.
up_to_stream_id: Where to delete messages up to.
limit: maximum number of messages to delete

Returns:
The number of messages deleted.
Expand All @@ -472,12 +477,16 @@ async def delete_messages_for_device(
log_kv({"message": "No changes in cache since last check"})
return 0

ROW_ID_NAME = self.database_engine.row_id_name

def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
sql = (
"DELETE FROM device_inbox"
" WHERE user_id = ? AND device_id = ?"
" AND stream_id <= ?"
)
sql = f"""
DELETE FROM device_inbox WHERE {ROW_ID_NAME} IN (
SELECT {ROW_ID_NAME} FROM device_inbox
WHERE user_id = ? AND device_id = ? AND stream_id <= ?
LIMIT {limit}
)
"""
txn.execute(sql, (user_id, device_id, up_to_stream_id))
return txn.rowcount

Expand All @@ -487,6 +496,11 @@ def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:

log_kv({"message": f"deleted {count} messages for device", "count": count})

# In this case we don't know if we hit the limit or the delete is complete
# so let's not update the cache.
if count == limit:
return count

# Update the cache, ensuring that we only ever increase the value
updated_last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), 0
Expand Down
8 changes: 0 additions & 8 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -1766,14 +1766,6 @@ def _delete_devices_txn(txn: LoggingTransaction) -> None:
keyvalues={"user_id": user_id, "hidden": False},
)

self.db_pool.simple_delete_many_txn(
txn,
table="device_inbox",
column="device_id",
values=device_ids,
keyvalues={"user_id": user_id},
)

self.db_pool.simple_delete_many_txn(
txn,
table="device_auth_providers",
Expand Down
6 changes: 1 addition & 5 deletions synapse/storage/databases/main/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,11 +939,7 @@ async def _background_receipts_linearized_unique_index(
receipts."""

def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None:
if isinstance(self.database_engine, PostgresEngine):
ROW_ID_NAME = "ctid"
else:
ROW_ID_NAME = "rowid"

ROW_ID_NAME = self.database_engine.row_id_name
# Identify any duplicate receipts arising from
# https://github.com/matrix-org/synapse/issues/14406.
# The following query takes less than a minute on matrix.org.
Expand Down
6 changes: 6 additions & 0 deletions synapse/storage/engines/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ def server_version(self) -> str:
"""Gets a string giving the server version. For example: '3.22.0'"""
...

@property
@abc.abstractmethod
def row_id_name(self) -> str:
"""Gets the literal name representing a row id for this engine."""
...

@abc.abstractmethod
def in_transaction(self, conn: ConnectionType) -> bool:
"""Whether the connection is currently in a transaction."""
Expand Down
4 changes: 4 additions & 0 deletions synapse/storage/engines/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,10 @@ def server_version(self) -> str:
else:
return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)

@property
def row_id_name(self) -> str:
return "ctid"

def in_transaction(self, conn: psycopg2.extensions.connection) -> bool:
return conn.status != psycopg2.extensions.STATUS_READY

Expand Down
4 changes: 4 additions & 0 deletions synapse/storage/engines/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ def server_version(self) -> str:
"""Gets a string giving the server version. For example: '3.22.0'."""
return "%i.%i.%i" % sqlite3.sqlite_version_info

@property
def row_id_name(self) -> str:
return "rowid"

def in_transaction(self, conn: sqlite3.Connection) -> bool:
return conn.in_transaction

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


from synapse.storage.database import LoggingTransaction
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.prepare_database import get_statements

FIX_INDEXES = """
Expand All @@ -37,7 +37,7 @@


def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None:
rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid"
rowid = database_engine.row_id_name

# remove duplicates from group_users & group_invites tables
cur.execute(
Expand Down
17 changes: 7 additions & 10 deletions synapse/util/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class TaskScheduler:
LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000 # 24hrs

def __init__(self, hs: "HomeServer"):
self._hs = hs
self._store = hs.get_datastores().main
self._clock = hs.get_clock()
self._running_tasks: Set[str] = set()
Expand All @@ -97,8 +98,6 @@ def __init__(self, hs: "HomeServer"):
"handle_scheduled_tasks",
self._handle_scheduled_tasks,
)
else:
self.replication_client = hs.get_replication_command_handler()
MatMaul marked this conversation as resolved.
Show resolved Hide resolved

def register_action(
self,
Expand Down Expand Up @@ -133,7 +132,7 @@ async def schedule_task(
params: Optional[JsonMapping] = None,
) -> str:
"""Schedule a new potentially resumable task. A function matching the specified
`action` should have been previously registered with `register_action`.
`action` should have be registered with `register_action` before the task is run.

Args:
action: the name of a previously registered action
Expand All @@ -149,11 +148,6 @@ async def schedule_task(
Returns:
The id of the scheduled task
"""
if action not in self._actions:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check is moved inside _launch_task instead, since actions only need to be registered on the background worker and may not be registered on other ones.

raise Exception(
f"No function associated with action {action} of the scheduled task"
)

status = TaskStatus.SCHEDULED
if timestamp is None or timestamp < self._clock.time_msec():
timestamp = self._clock.time_msec()
Expand All @@ -175,7 +169,7 @@ async def schedule_task(
if self._run_background_tasks:
await self._launch_task(task)
else:
self.replication_client.send_new_active_task(task.id)
self._hs.get_replication_command_handler().send_new_active_task(task.id)

return task.id

Expand Down Expand Up @@ -315,7 +309,10 @@ async def _launch_task(self, task: ScheduledTask) -> None:
"""
assert self._run_background_tasks

assert task.action in self._actions
if task.action not in self._actions:
raise Exception(
f"No function associated with action {task.action} of the scheduled task {task.id}"
)
function = self._actions[task.action]

async def wrapper() -> None:
Expand Down
Loading
Loading