Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Move additional tasks to the background worker #8458

Merged
merged 9 commits into from
Oct 7, 2020
1 change: 1 addition & 0 deletions changelog.d/8458.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow running background tasks in a separate worker process.
4 changes: 4 additions & 0 deletions synapse/app/generic_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.server import HomeServer, cache_in_self
from synapse.storage.databases.main.censor_events import CensorEventsStore
from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
from synapse.storage.databases.main.media_repository import MediaRepositoryStore
from synapse.storage.databases.main.metrics import ServerMetricsStore
from synapse.storage.databases.main.monthly_active_users import (
Expand All @@ -135,6 +136,7 @@
from synapse.storage.databases.main.presence import UserPresenceState
from synapse.storage.databases.main.search import SearchWorkerStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.databases.main.transactions import TransactionWorkerStore
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
from synapse.storage.databases.main.user_directory import UserDirectoryStore
from synapse.types import ReadReceipt
Expand Down Expand Up @@ -466,6 +468,7 @@ class GenericWorkerSlavedStore(
SlavedAccountDataStore,
SlavedPusherStore,
CensorEventsStore,
ClientIpWorkerStore,
SlavedEventStore,
SlavedKeyStore,
RoomStore,
Expand All @@ -481,6 +484,7 @@ class GenericWorkerSlavedStore(
MediaRepositoryStore,
ServerMetricsStore,
SearchWorkerStore,
TransactionWorkerStore,
BaseSlavedStore,
):
pass
Expand Down
109 changes: 57 additions & 52 deletions synapse/storage/databases/main/client_ips.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,63 @@ def _devices_last_seen_update_txn(txn):
return updated


class ClientIpStore(ClientIpBackgroundUpdateStore):
class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was unsure which of these should be the base class? It seems we do have a few layouts:

  • worker -> background update -> main
  • worker -> main
  • (worker, background update) -> main
  • background update -> worker -> main

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it depends on if functions in the background updater depend on stuff from the worker or vice versa 🤷

def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)

self.user_ips_max_age = hs.config.user_ips_max_age

if hs.config.run_background_tasks and self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)

@wrap_as_background_process("prune_old_user_ips")
async def _prune_old_user_ips(self):
"""Removes entries in user IPs older than the configured period.
"""

if self.user_ips_max_age is None:
# Nothing to do
return

if not await self.db_pool.updates.has_completed_background_update(
"devices_last_seen"
):
# Only start pruning if we have finished populating the devices
# last seen info.
return

# We do a slightly funky SQL delete to ensure we don't try and delete
# too much at once (as the table may be very large from before we
# started pruning).
#
# This works by finding the max last_seen that is less than the given
# time, but has no more than N rows before it, deleting all rows with
# a lesser last_seen time. (We COALESCE so that the sub-SELECT always
# returns exactly one row).
sql = """
DELETE FROM user_ips
WHERE last_seen <= (
SELECT COALESCE(MAX(last_seen), -1)
FROM (
SELECT last_seen FROM user_ips
WHERE last_seen <= ?
ORDER BY last_seen ASC
LIMIT 5000
) AS u
)
"""

timestamp = self.clock.time_msec() - self.user_ips_max_age

def _prune_old_user_ips_txn(txn):
txn.execute(sql, (timestamp,))

await self.db_pool.runInteraction(
"_prune_old_user_ips", _prune_old_user_ips_txn
)


class ClientIpStore(ClientIpWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):

self.client_ip_last_seen = Cache(
Expand All @@ -360,8 +416,6 @@ def __init__(self, database: DatabasePool, db_conn, hs):

super().__init__(database, db_conn, hs)

self.user_ips_max_age = hs.config.user_ips_max_age

# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
self._batch_row_update = {}

Expand All @@ -372,9 +426,6 @@ def __init__(self, database: DatabasePool, db_conn, hs):
"before", "shutdown", self._update_client_ips_batch
)

if self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)

async def insert_client_ip(
self, user_id, access_token, ip, user_agent, device_id, now=None
):
Expand Down Expand Up @@ -525,49 +576,3 @@ async def get_user_ip_and_agents(self, user):
}
for (access_token, ip), (user_agent, last_seen) in results.items()
]

@wrap_as_background_process("prune_old_user_ips")
async def _prune_old_user_ips(self):
"""Removes entries in user IPs older than the configured period.
"""

if self.user_ips_max_age is None:
# Nothing to do
return

if not await self.db_pool.updates.has_completed_background_update(
"devices_last_seen"
):
# Only start pruning if we have finished populating the devices
# last seen info.
return

# We do a slightly funky SQL delete to ensure we don't try and delete
# too much at once (as the table may be very large from before we
# started pruning).
#
# This works by finding the max last_seen that is less than the given
# time, but has no more than N rows before it, deleting all rows with
# a lesser last_seen time. (We COALESCE so that the sub-SELECT always
# returns exactly one row).
sql = """
DELETE FROM user_ips
WHERE last_seen <= (
SELECT COALESCE(MAX(last_seen), -1)
FROM (
SELECT last_seen FROM user_ips
WHERE last_seen <= ?
ORDER BY last_seen ASC
LIMIT 5000
) AS u
)
"""

timestamp = self.clock.time_msec() - self.user_ips_max_age

def _prune_old_user_ips_txn(txn):
txn.execute(sql, (timestamp,))

await self.db_pool.runInteraction(
"_prune_old_user_ips", _prune_old_user_ips_txn
)
3 changes: 2 additions & 1 deletion synapse/storage/databases/main/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def read_forward_extremities():
"read_forward_extremities", self._read_forward_extremities
)

hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000)
if hs.config.run_background_tasks:
self._clock.looping_call(read_forward_extremities, 60 * 60 * 1000)

# Used in _generate_user_daily_visits to keep track of progress
self._last_user_visit_update = self._get_start_of_day()
Expand Down
183 changes: 92 additions & 91 deletions synapse/storage/databases/main/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,27 @@ def __init__(self, database: DatabasePool, db_conn, hs):
database.engine, find_max_generated_user_id_localpart, "user_id_seq",
)

self._account_validity = hs.config.account_validity
if hs.config.run_background_tasks and self._account_validity.enabled:
self._clock.call_later(
0.0,
run_as_background_process,
"account_validity_set_expiration_dates",
self._set_expiration_date_when_missing,
)

# Create a background job for culling expired 3PID validity tokens
def start_cull():
# run as a background process to make sure that the database transactions
# have a logcontext to report to
return run_as_background_process(
"cull_expired_threepid_validation_tokens",
self.cull_expired_threepid_validation_tokens,
)

if hs.config.run_background_tasks:
self.clock.looping_call(start_cull, THIRTY_MINUTES_IN_MS)

@cached()
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
Expand Down Expand Up @@ -778,6 +799,77 @@ def delete_threepid_session_txn(txn):
"delete_threepid_session", delete_threepid_session_txn
)

async def cull_expired_threepid_validation_tokens(self) -> None:
"""Remove threepid validation tokens with expiry dates that have passed"""

def cull_expired_threepid_validation_tokens_txn(txn, ts):
sql = """
DELETE FROM threepid_validation_token WHERE
expires < ?
"""
txn.execute(sql, (ts,))

await self.db_pool.runInteraction(
"cull_expired_threepid_validation_tokens",
cull_expired_threepid_validation_tokens_txn,
self.clock.time_msec(),
)

async def _set_expiration_date_when_missing(self):
"""
Retrieves the list of registered users that don't have an expiration date, and
adds an expiration date for each of them.
"""

def select_users_with_no_expiration_date_txn(txn):
"""Retrieves the list of registered users with no expiration date from the
database, filtering out deactivated users.
"""
sql = (
"SELECT users.name FROM users"
" LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
" WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
)
txn.execute(sql, [])

res = self.db_pool.cursor_to_dict(txn)
if res:
for user in res:
self.set_expiration_date_for_user_txn(
txn, user["name"], use_delta=True
)

await self.db_pool.runInteraction(
"get_users_with_no_expiration_date",
select_users_with_no_expiration_date_txn,
)

def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
"""Sets an expiration date to the account with the given user ID.

Args:
user_id (str): User ID to set an expiration date for.
use_delta (bool): If set to False, the expiration date for the user will be
now + validity period. If set to True, this expiration date will be a
random value in the [now + period - d ; now + period] range, d being a
delta equal to 10% of the validity period.
"""
now_ms = self._clock.time_msec()
expiration_ts = now_ms + self._account_validity.period

if use_delta:
expiration_ts = self.rand.randrange(
expiration_ts - self._account_validity.startup_job_max_delta,
expiration_ts,
)

self.db_pool.simple_upsert_txn(
txn,
"account_validity",
keyvalues={"user_id": user_id},
values={"expiration_ts_ms": expiration_ts, "email_sent": False},
)


class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
Expand Down Expand Up @@ -911,28 +1003,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)

self._account_validity = hs.config.account_validity
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors

if self._account_validity.enabled:
self._clock.call_later(
0.0,
run_as_background_process,
"account_validity_set_expiration_dates",
self._set_expiration_date_when_missing,
)

# Create a background job for culling expired 3PID validity tokens
def start_cull():
# run as a background process to make sure that the database transactions
# have a logcontext to report to
return run_as_background_process(
"cull_expired_threepid_validation_tokens",
self.cull_expired_threepid_validation_tokens,
)

hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)

async def add_access_token_to_user(
self,
user_id: str,
Expand Down Expand Up @@ -1447,22 +1519,6 @@ def start_or_continue_validation_session_txn(txn):
start_or_continue_validation_session_txn,
)

async def cull_expired_threepid_validation_tokens(self) -> None:
"""Remove threepid validation tokens with expiry dates that have passed"""

def cull_expired_threepid_validation_tokens_txn(txn, ts):
sql = """
DELETE FROM threepid_validation_token WHERE
expires < ?
"""
txn.execute(sql, (ts,))

await self.db_pool.runInteraction(
"cull_expired_threepid_validation_tokens",
cull_expired_threepid_validation_tokens_txn,
self.clock.time_msec(),
)

async def set_user_deactivated_status(
self, user_id: str, deactivated: bool
) -> None:
Expand Down Expand Up @@ -1492,61 +1548,6 @@ def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
)
txn.call_after(self.is_guest.invalidate, (user_id,))

async def _set_expiration_date_when_missing(self):
"""
Retrieves the list of registered users that don't have an expiration date, and
adds an expiration date for each of them.
"""

def select_users_with_no_expiration_date_txn(txn):
"""Retrieves the list of registered users with no expiration date from the
database, filtering out deactivated users.
"""
sql = (
"SELECT users.name FROM users"
" LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
" WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
)
txn.execute(sql, [])

res = self.db_pool.cursor_to_dict(txn)
if res:
for user in res:
self.set_expiration_date_for_user_txn(
txn, user["name"], use_delta=True
)

await self.db_pool.runInteraction(
"get_users_with_no_expiration_date",
select_users_with_no_expiration_date_txn,
)

def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
"""Sets an expiration date to the account with the given user ID.

Args:
user_id (str): User ID to set an expiration date for.
use_delta (bool): If set to False, the expiration date for the user will be
now + validity period. If set to True, this expiration date will be a
random value in the [now + period - d ; now + period] range, d being a
delta equal to 10% of the validity period.
"""
now_ms = self._clock.time_msec()
expiration_ts = now_ms + self._account_validity.period

if use_delta:
expiration_ts = self.rand.randrange(
expiration_ts - self._account_validity.startup_job_max_delta,
expiration_ts,
)

self.db_pool.simple_upsert_txn(
txn,
"account_validity",
keyvalues={"user_id": user_id},
values={"expiration_ts_ms": expiration_ts, "email_sent": False},
)


def find_max_generated_user_id_localpart(cur: Cursor) -> int:
"""
Expand Down
Loading