
Allow moving account data and receipts streams off master #9104

Merged · 14 commits · Jan 18, 2021
Changes from 1 commit
40 changes: 1 addition & 39 deletions synapse/replication/slave/storage/account_data.py
@@ -15,47 +15,9 @@
# limitations under the License.

from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.databases.main.tags import TagsWorkerStore


class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
self._account_data_id_gen = SlavedIdTracker(
db_conn,
"account_data",
"stream_id",
extra_tables=[
("room_account_data", "stream_id"),
("room_tags_revisions", "stream_id"),
],
)

super().__init__(database, db_conn, hs)

def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()

def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
elif stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
for row in rows:
if not row.room_id:
self.get_global_account_data_by_type_for_user.invalidate(
(row.data_type, row.user_id)
)
self.get_account_data_for_user.invalidate((row.user_id,))
self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
self.get_account_data_for_room_and_type.invalidate(
(row.user_id, row.room_id, row.data_type)
)
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
pass
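
Net effect of this hunk: the slaved store loses its bespoke id tracker and replication handling and becomes a pure shim over the worker store. A sketch of the resulting module, assuming only the three now-unused imports are dropped:

# synapse/replication/slave/storage/account_data.py after this change (sketch).
# The id tracking and cache invalidation now live in AccountDataWorkerStore.
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.databases.main.tags import TagsWorkerStore


class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
    pass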
93 changes: 70 additions & 23 deletions synapse/storage/databases/main/account_data.py
@@ -14,14 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import logging
from typing import Dict, List, Optional, Set, Tuple

from synapse.api.constants import AccountDataTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -32,7 +34,7 @@

# The ABCMeta metaclass ensures that it cannot be instantiated without
# the abstract methods being implemented.
class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
class AccountDataWorkerStore(SQLBaseStore):
"""This is an abstract base class where subclasses must implement
`get_max_account_data_stream_id` which can be called in the initializer.
"""
@@ -43,16 +45,58 @@ def __init__(self, database: DatabasePool, db_conn, hs):
"AccountDataAndTagsChangeCache", account_max
)

self._instance_name = hs.get_instance_name()

if isinstance(database.engine, PostgresEngine):
self._can_write_to_account_data = self._instance_name == "master"

self._account_data_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
stream_name="account_data",
instance_name=self._instance_name,
tables=[
("room_account_data", "instance_name", "stream_id"),
("room_tags_revisions", "instance_name", "stream_id"),
("account_data", "instance_name", "stream_id"),
],
sequence_name="account_data_sequence",
writers=["master"],
)
else:
self._can_write_to_account_data = True

# We shouldn't be running in worker mode with SQLite, but it's useful
# to support it for unit tests.
#
# If this process is the writer then we need to use
# `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
# updated over replication. (Multiple writers are not supported for
# SQLite).
Comment on lines +64 to +70

Member: Seems like we have a lot of duplicated code here; can we abstract any of that? I'm unsure if they're 100% identical. 😢

Member: (If we can...probably best to do it as a follow-up, but between this, the HTTP endpoints, etc. it seems like a ton of boilerplate.)

if hs.get_instance_name() in hs.config.worker.writers.account_data:
self._account_data_id_gen = StreamIdGenerator(
db_conn,
"room_account_data",
"stream_id",
extra_tables=[("room_tags_revisions", "stream_id")],
)
else:
self._account_data_id_gen = SlavedIdTracker(
db_conn,
"room_account_data",
"stream_id",
extra_tables=[("room_tags_revisions", "stream_id")],
)

super().__init__(database, db_conn, hs)
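
On the duplication flagged in the review thread above, one possible follow-up (a hedged sketch, not part of this PR; the helper name is hypothetical) is to factor the writer/replica branch into a helper shared by the stores that need it:

from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage.util.id_generators import StreamIdGenerator


def _build_stream_id_tracker(hs, db_conn, table, column, extra_tables, writers):
    """Return a StreamIdGenerator on writer instances, else a SlavedIdTracker.

    Both classes take (db_conn, table, column, extra_tables=...), so the only
    difference between the two branches is which class gets constructed.
    """
    if hs.get_instance_name() in writers:
        # This instance allocates new ids for the stream.
        return StreamIdGenerator(db_conn, table, column, extra_tables=extra_tables)
    # Read-only replica: just follow the position advertised over replication.
    return SlavedIdTracker(db_conn, table, column, extra_tables=extra_tables)

The SQLite branch above would then reduce to a single call, e.g. (assuming the account_data writers config this PR introduces):

self._account_data_id_gen = _build_stream_id_tracker(
    hs,
    db_conn,
    "room_account_data",
    "stream_id",
    extra_tables=[("room_tags_revisions", "stream_id")],
    writers=hs.config.worker.writers.account_data,
)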

@abc.abstractmethod
def get_max_account_data_stream_id(self):
"""Get the current max stream ID for account data stream

Returns:
int
"""
raise NotImplementedError()
return self._account_data_id_gen.get_current_token()

@cached()
async def get_account_data_for_user(
@@ -307,26 +351,29 @@ async def ignored_by(self, user_id: str) -> Set[str]:
)
)

def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
elif stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
for row in rows:
if not row.room_id:
self.get_global_account_data_by_type_for_user.invalidate(
(row.data_type, row.user_id)
)
self.get_account_data_for_user.invalidate((row.user_id,))
self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
self.get_account_data_for_room_and_type.invalidate(
(row.user_id, row.room_id, row.data_type)
)
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
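
For orientation, a minimal sketch of this method being driven by a replication update. This is illustrative only: store stands for any store mixing in AccountDataWorkerStore, and the row shape assumes AccountDataStream rows carry (user_id, room_id, data_type), with room_id unset for global account data:

from collections import namedtuple

# Illustrative stand-in for AccountDataStream's row type.
Row = namedtuple("Row", ["user_id", "room_id", "data_type"])

# A global (room-less) account data change for @alice written at position 42.
store.process_replication_rows(
    AccountDataStream.NAME,
    "master",  # the instance that wrote the change
    42,
    [Row("@alice:example.com", None, "m.push_rules")],
)

# The id generator has advanced to (at least) the token ...
assert store.get_max_account_data_stream_id() >= 42
# ... and @alice's per-user caches were invalidated, so the next read refetches.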

class AccountDataStore(AccountDataWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
self._account_data_id_gen = StreamIdGenerator(
db_conn,
"room_account_data",
"stream_id",
extra_tables=[("room_tags_revisions", "stream_id")],
)

super().__init__(database, db_conn, hs)

def get_max_account_data_stream_id(self) -> int:
"""Get the current max stream id for the private user data stream

Returns:
The maximum stream ID.
"""
return self._account_data_id_gen.get_current_token()

class AccountDataStore(AccountDataWorkerStore):
async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
) -> int:
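The body of add_account_data_to_room is truncated in this view; based only on the signature shown, a hedged usage sketch (names illustrative, to run on the writer instance):

# Hypothetical call site on the process where _can_write_to_account_data holds.
max_stream_id = await store.add_account_data_to_room(
    user_id="@alice:example.com",
    room_id="!room:example.com",
    account_data_type="m.fully_read",
    content={"event_id": "$read_up_to"},
)
# max_stream_id is the stream position allocated for this write; callers can
# hand it to the notifier so sync streams pick the change up at the right point.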