diff --git a/changelog.d/19037.misc b/changelog.d/19037.misc new file mode 100644 index 00000000000..763050067ef --- /dev/null +++ b/changelog.d/19037.misc @@ -0,0 +1 @@ +Move unique snowflake homeserver background tasks to `start_background_tasks` (the standard pattern for this kind of thing). diff --git a/synapse/app/_base.py b/synapse/app/_base.py index a3e4b4ea4b6..b416b66ac6e 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -64,7 +64,6 @@ import synapse.util.caches from synapse.api.constants import MAX_PDU_SIZE from synapse.app import check_bind_error -from synapse.app.phone_stats_home import start_phone_stats_home from synapse.config import ConfigError from synapse.config._base import format_config_error from synapse.config.homeserver import HomeServerConfig @@ -683,15 +682,6 @@ def log_shutdown() -> None: if hs.config.worker.run_background_tasks: hs.start_background_tasks() - # TODO: This should be moved to same pattern we use for other background tasks: - # Add to `REQUIRED_ON_BACKGROUND_TASK_STARTUP` and rely on - # `start_background_tasks` to start it. - await hs.get_common_usage_metrics_manager().setup() - - # TODO: This feels like another pattern that should refactored as one of the - # `REQUIRED_ON_BACKGROUND_TASK_STARTUP` - start_phone_stats_home(hs) - if freeze: # We now freeze all allocated objects in the hopes that (almost) # everything currently allocated are things that will be used for the diff --git a/synapse/metrics/common_usage_metrics.py b/synapse/metrics/common_usage_metrics.py index 43e0913d279..3f38412fa7a 100644 --- a/synapse/metrics/common_usage_metrics.py +++ b/synapse/metrics/common_usage_metrics.py @@ -62,7 +62,7 @@ async def get_metrics(self) -> CommonUsageMetrics: """ return await self._collect() - async def setup(self) -> None: + def setup(self) -> None: """Keep the gauges for common usage metrics up to date.""" self._hs.run_as_background_process( desc="common_usage_metrics_update_gauges", diff --git a/synapse/server.py b/synapse/server.py index 1316249dda1..b63a11273a7 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -62,6 +62,7 @@ from synapse.api.filtering import Filtering from synapse.api.ratelimiting import Ratelimiter, RequestRatelimiter from synapse.app._base import unregister_sighups +from synapse.app.phone_stats_home import start_phone_stats_home from synapse.appservice.api import ApplicationServiceApi from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.config.homeserver import HomeServerConfig @@ -643,6 +644,8 @@ def start_background_tasks(self) -> None: for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP: getattr(self, "get_" + i + "_handler")() self.get_task_scheduler() + self.get_common_usage_metrics_manager().setup() + start_phone_stats_home(self) def get_reactor(self) -> ISynapseReactor: """ diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 1a2dab4c7d7..8a6394e9ef7 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -214,7 +214,12 @@ def request_factory(*args: Any, **kwargs: Any) -> SynapseRequest: client_to_server_transport.loseConnection() # there should have been exactly one request - self.assertEqual(len(requests), 1) + self.assertEqual( + len(requests), + 1, + "Expected to handle exactly one HTTP replication request but saw %d - requests=%s" + % (len(requests), requests), + ) return requests[0] diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py index 
6dea29ae153..d0c189230c4 100644 --- a/tests/replication/tcp/streams/test_account_data.py +++ b/tests/replication/tcp/streams/test_account_data.py @@ -46,28 +46,39 @@ def test_update_function_room_account_data_limit(self) -> None: # check we're testing what we think we are: no rows should yet have been # received - self.assertEqual([], self.test_handler.received_rdata_rows) + received_account_data_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == AccountDataStream.NAME + ] + self.assertEqual([], received_account_data_rows) # now reconnect to pull the updates self.reconnect() self.replicate() - # we should have received all the expected rows in the right order - received_rows = self.test_handler.received_rdata_rows + # We should have received all the expected rows in the right order + # + # Filter the updates to only include account data changes + received_account_data_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == AccountDataStream.NAME + ] for t in updates: - (stream_name, token, row) = received_rows.pop(0) + (stream_name, token, row) = received_account_data_rows.pop(0) self.assertEqual(stream_name, AccountDataStream.NAME) self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow) self.assertEqual(row.data_type, t) self.assertEqual(row.room_id, "test_room") - (stream_name, token, row) = received_rows.pop(0) + (stream_name, token, row) = received_account_data_rows.pop(0) self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow) self.assertEqual(row.data_type, "m.global") self.assertIsNone(row.room_id) - self.assertEqual([], received_rows) + self.assertEqual([], received_account_data_rows) def test_update_function_global_account_data_limit(self) -> None: """Test replication with many global account data updates""" @@ -85,32 +96,38 @@ def test_update_function_global_account_data_limit(self) -> None: store.add_account_data_to_room("test_user", "test_room", "m.per_room", {}) ) - # tell the notifier to catch up to avoid duplicate rows. 
- # workaround for https://github.com/matrix-org/synapse/issues/7360 - # FIXME remove this when the above is fixed - self.replicate() - # check we're testing what we think we are: no rows should yet have been # received - self.assertEqual([], self.test_handler.received_rdata_rows) + received_account_data_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == AccountDataStream.NAME + ] + self.assertEqual([], received_account_data_rows) # now reconnect to pull the updates self.reconnect() self.replicate() # we should have received all the expected rows in the right order - received_rows = self.test_handler.received_rdata_rows + # + # Filter the updates to only include account data changes + received_account_data_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == AccountDataStream.NAME + ] for t in updates: - (stream_name, token, row) = received_rows.pop(0) + (stream_name, token, row) = received_account_data_rows.pop(0) self.assertEqual(stream_name, AccountDataStream.NAME) self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow) self.assertEqual(row.data_type, t) self.assertIsNone(row.room_id) - (stream_name, token, row) = received_rows.pop(0) + (stream_name, token, row) = received_account_data_rows.pop(0) self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow) self.assertEqual(row.data_type, "m.per_room") self.assertEqual(row.room_id, "test_room") - self.assertEqual([], received_rows) + self.assertEqual([], received_account_data_rows) diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index 782dad39f5c..452032205f9 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -30,6 +30,7 @@ from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT from synapse.replication.tcp.streams.events import ( _MAX_STATE_UPDATES_PER_ROOM, + EventsStream, EventsStreamAllStateRow, EventsStreamCurrentStateRow, EventsStreamEventRow, @@ -82,7 +83,12 @@ def test_update_function_event_row_limit(self) -> None: # check we're testing what we think we are: no rows should yet have been # received - self.assertEqual([], self.test_handler.received_rdata_rows) + received_event_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == EventsStream.NAME + ] + self.assertEqual([], received_event_rows) # now reconnect to pull the updates self.reconnect() @@ -90,31 +96,34 @@ def test_update_function_event_row_limit(self) -> None: # we should have received all the expected rows in the right order (as # well as various cache invalidation updates which we ignore) - received_rows = [ - row for row in self.test_handler.received_rdata_rows if row[0] == "events" + # + # Filter the updates to only include event changes + received_event_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == EventsStream.NAME ] - for event in events: - stream_name, token, row = received_rows.pop(0) - self.assertEqual("events", stream_name) + stream_name, token, row = received_event_rows.pop(0) + self.assertEqual(EventsStream.NAME, stream_name) self.assertIsInstance(row, EventsStreamRow) self.assertEqual(row.type, "ev") self.assertIsInstance(row.data, EventsStreamEventRow) self.assertEqual(row.data.event_id, event.event_id) - stream_name, token, row = received_rows.pop(0) + stream_name, token, row = received_event_rows.pop(0) self.assertIsInstance(row, EventsStreamRow) self.assertIsInstance(row.data,
EventsStreamEventRow) self.assertEqual(row.data.event_id, state_event.event_id) - stream_name, token, row = received_rows.pop(0) + stream_name, token, row = received_event_rows.pop(0) self.assertEqual("events", stream_name) self.assertIsInstance(row, EventsStreamRow) self.assertEqual(row.type, "state") self.assertIsInstance(row.data, EventsStreamCurrentStateRow) self.assertEqual(row.data.event_id, state_event.event_id) - self.assertEqual([], received_rows) + self.assertEqual([], received_event_rows) @parameterized.expand( [(_STREAM_UPDATE_TARGET_ROW_COUNT, False), (_MAX_STATE_UPDATES_PER_ROOM, True)] @@ -170,9 +179,12 @@ def test_update_function_huge_state_change( self.replicate() # all those events and state changes should have landed - self.assertGreaterEqual( - len(self.test_handler.received_rdata_rows), 2 * len(events) - ) + received_event_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == EventsStream.NAME + ] + self.assertGreaterEqual(len(received_event_rows), 2 * len(events)) # disconnect, so that we can stack up the changes self.disconnect() @@ -202,7 +214,12 @@ def test_update_function_huge_state_change( # check we're testing what we think we are: no rows should yet have been # received - self.assertEqual([], self.test_handler.received_rdata_rows) + received_event_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == EventsStream.NAME + ] + self.assertEqual([], received_event_rows) # now reconnect to pull the updates self.reconnect() @@ -218,33 +235,34 @@ def test_update_function_huge_state_change( # of the states that got reverted. # - two rows for state2 - received_rows = [ - row for row in self.test_handler.received_rdata_rows if row[0] == "events" + received_event_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == EventsStream.NAME ] - # first check the first two rows, which should be the state1 event. - stream_name, token, row = received_rows.pop(0) + stream_name, token, row = received_event_rows.pop(0) self.assertEqual("events", stream_name) self.assertIsInstance(row, EventsStreamRow) self.assertEqual(row.type, "ev") self.assertIsInstance(row.data, EventsStreamEventRow) self.assertEqual(row.data.event_id, state1.event_id) - stream_name, token, row = received_rows.pop(0) + stream_name, token, row = received_event_rows.pop(0) self.assertIsInstance(row, EventsStreamRow) self.assertEqual(row.type, "state") self.assertIsInstance(row.data, EventsStreamCurrentStateRow) self.assertEqual(row.data.event_id, state1.event_id) # now the last two rows, which should be the state2 event. - stream_name, token, row = received_rows.pop(-2) + stream_name, token, row = received_event_rows.pop(-2) self.assertEqual("events", stream_name) self.assertIsInstance(row, EventsStreamRow) self.assertEqual(row.type, "ev") self.assertIsInstance(row.data, EventsStreamEventRow) self.assertEqual(row.data.event_id, state2.event_id) - stream_name, token, row = received_rows.pop(-1) + stream_name, token, row = received_event_rows.pop(-1) self.assertIsInstance(row, EventsStreamRow) self.assertEqual(row.type, "state") self.assertIsInstance(row.data, EventsStreamCurrentStateRow) @@ -254,16 +272,16 @@ def test_update_function_huge_state_change( if collapse_state_changes: # that should leave us with the rows for the PL event, the state changes # get collapsed into a single row. 
- self.assertEqual(len(received_rows), 2) + self.assertEqual(len(received_event_rows), 2) - stream_name, token, row = received_rows.pop(0) + stream_name, token, row = received_event_rows.pop(0) self.assertEqual("events", stream_name) self.assertIsInstance(row, EventsStreamRow) self.assertEqual(row.type, "ev") self.assertIsInstance(row.data, EventsStreamEventRow) self.assertEqual(row.data.event_id, pl_event.event_id) - stream_name, token, row = received_rows.pop(0) + stream_name, token, row = received_event_rows.pop(0) self.assertIsInstance(row, EventsStreamRow) self.assertEqual(row.type, "state-all") self.assertIsInstance(row.data, EventsStreamAllStateRow) @@ -271,9 +289,9 @@ def test_update_function_huge_state_change( else: # that should leave us with the rows for the PL event - self.assertEqual(len(received_rows), len(events) + 2) + self.assertEqual(len(received_event_rows), len(events) + 2) - stream_name, token, row = received_rows.pop(0) + stream_name, token, row = received_event_rows.pop(0) self.assertEqual("events", stream_name) self.assertIsInstance(row, EventsStreamRow) self.assertEqual(row.type, "ev") @@ -282,7 +300,7 @@ def test_update_function_huge_state_change( # the state rows are unsorted state_rows: List[EventsStreamCurrentStateRow] = [] - for stream_name, _, row in received_rows: + for stream_name, _, row in received_event_rows: self.assertEqual("events", stream_name) self.assertIsInstance(row, EventsStreamRow) self.assertEqual(row.type, "state") @@ -346,9 +364,12 @@ def test_update_function_state_row_limit(self) -> None: self.replicate() # all those events and state changes should have landed - self.assertGreaterEqual( - len(self.test_handler.received_rdata_rows), 2 * len(events) - ) + received_event_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == EventsStream.NAME + ] + self.assertGreaterEqual(len(received_event_rows), 2 * len(events)) # disconnect, so that we can stack up the changes self.disconnect() @@ -375,7 +396,12 @@ def test_update_function_state_row_limit(self) -> None: # check we're testing what we think we are: no rows should yet have been # received - self.assertEqual([], self.test_handler.received_rdata_rows) + received_event_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == EventsStream.NAME + ] + self.assertEqual([], received_event_rows) # now reconnect to pull the updates self.reconnect() @@ -383,14 +409,16 @@ def test_update_function_state_row_limit(self) -> None: # we should have received all the expected rows in the right order (as # well as various cache invalidation updates which we ignore) - received_rows = [ - row for row in self.test_handler.received_rdata_rows if row[0] == "events" + received_event_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == EventsStream.NAME ] - self.assertGreaterEqual(len(received_rows), len(events)) + self.assertGreaterEqual(len(received_event_rows), len(events)) for i in range(NUM_USERS): # for each user, we expect the PL event row, followed by state rows for # the PL event and each of the states that got reverted. 
- stream_name, token, row = received_rows.pop(0) + stream_name, token, row = received_event_rows.pop(0) self.assertEqual("events", stream_name) self.assertIsInstance(row, EventsStreamRow) self.assertEqual(row.type, "ev") @@ -400,7 +428,7 @@ def test_update_function_state_row_limit(self) -> None: # the state rows are unsorted state_rows: List[EventsStreamCurrentStateRow] = [] for _ in range(STATES_PER_USER + 1): - stream_name, token, row = received_rows.pop(0) + stream_name, token, row = received_event_rows.pop(0) self.assertEqual("events", stream_name) self.assertIsInstance(row, EventsStreamRow) self.assertEqual(row.type, "state") @@ -417,7 +445,7 @@ def test_update_function_state_row_limit(self) -> None: # "None" indicates the state has been deleted self.assertIsNone(sr.event_id) - self.assertEqual([], received_rows) + self.assertEqual([], received_event_rows) def test_backwards_stream_id(self) -> None: """ @@ -432,7 +460,12 @@ def test_backwards_stream_id(self) -> None: # check we're testing what we think we are: no rows should yet have been # received - self.assertEqual([], self.test_handler.received_rdata_rows) + received_event_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == EventsStream.NAME + ] + self.assertEqual([], received_event_rows) # now reconnect to pull the updates self.reconnect() @@ -440,14 +473,16 @@ def test_backwards_stream_id(self) -> None: # We should have received the expected single row (as well as various # cache invalidation updates which we ignore). - received_rows = [ - row for row in self.test_handler.received_rdata_rows if row[0] == "events" + received_event_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == EventsStream.NAME ] # There should be a single received row. - self.assertEqual(len(received_rows), 1) + self.assertEqual(len(received_event_rows), 1) - stream_name, token, row = received_rows[0] + stream_name, token, row = received_event_rows[0] self.assertEqual("events", stream_name) self.assertIsInstance(row, EventsStreamRow) self.assertEqual(row.type, "ev") @@ -468,10 +503,12 @@ def test_backwards_stream_id(self) -> None: ) # No updates have been received (because it was discarded as old). - received_rows = [ - row for row in self.test_handler.received_rdata_rows if row[0] == "events" + received_event_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == EventsStream.NAME ] - self.assertEqual(len(received_rows), 0) + self.assertEqual(len(received_event_rows), 0) # Ensure the stream has not gone backwards. current_token = worker_events_stream.current_token("master") diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py index fd81e0dc173..172968c1085 100644 --- a/tests/replication/tcp/streams/test_federation.py +++ b/tests/replication/tcp/streams/test_federation.py @@ -38,24 +38,45 @@ def test_catchup(self) -> None: Makes sure that updates sent while we are offline are received later. """ fed_sender = self.hs.get_federation_sender() - received_rows = self.test_handler.received_rdata_rows + # Send an update before we connect fed_sender.build_and_send_edu("testdest", "m.test_edu", {"a": "b"}) + # Now reconnect and pull the updates self.reconnect() + # FIXME: This seems odd: why aren't we calling `self.replicate()` here? Doing so + # currently causes other assumptions to fail (multiple HTTP replication attempts + # are made).
self.reactor.advance(0) - # check we're testing what we think we are: no rows should yet have been + # Check we're testing what we think we are: no rows should yet have been # received - self.assertEqual(received_rows, []) + # + # Filter the updates to only include federation changes + received_federation_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == FederationStream.NAME + ] + self.assertEqual(received_federation_rows, []) # We should now see an attempt to connect to the master request = self.handle_http_replication_attempt() - self.assert_request_is_get_repl_stream_updates(request, "federation") + self.assert_request_is_get_repl_stream_updates(request, FederationStream.NAME) # we should have received an update row - stream_name, token, row = received_rows.pop() - self.assertEqual(stream_name, "federation") + received_federation_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == FederationStream.NAME + ] + self.assertEqual( + len(received_federation_rows), + 1, + "Expected exactly one row for the federation stream", + ) + (stream_name, token, row) = received_federation_rows[0] + self.assertEqual(stream_name, FederationStream.NAME) self.assertIsInstance(row, FederationStream.FederationStreamRow) self.assertEqual(row.type, EduRow.TypeId) edurow = EduRow.from_data(row.data) @@ -63,19 +84,30 @@ def test_catchup(self) -> None: self.assertEqual(edurow.edu.origin, self.hs.hostname) self.assertEqual(edurow.edu.destination, "testdest") self.assertEqual(edurow.edu.content, {"a": "b"}) - - self.assertEqual(received_rows, []) + # Clear out the received rows that we've checked so we can check for new ones later + self.test_handler.received_rdata_rows.clear() # additional updates should be transferred without an HTTP hit fed_sender.build_and_send_edu("testdest", "m.test1", {"c": "d"}) - self.reactor.advance(0) + # Pull in the updates + self.replicate() + # there should be no http hit self.assertEqual(len(self.reactor.tcpClients), 0) - # ...
but we should have a row + received_federation_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == FederationStream.NAME + ] + self.assertEqual( + len(received_federation_rows), + 1, + "Expected exactly one row for the federation stream", + ) + (stream_name, token, row) = received_federation_rows[0] + self.assertEqual(stream_name, FederationStream.NAME) self.assertIsInstance(row, FederationStream.FederationStreamRow) self.assertEqual(row.type, EduRow.TypeId) edurow = EduRow.from_data(row.data) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index c2f1f8dc4a9..c5332f6b5fb 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -20,7 +20,6 @@ # type: ignore -from unittest.mock import Mock from synapse.replication.tcp.streams._base import ReceiptsStream @@ -30,9 +29,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): - def _build_replication_data_handler(self): - return Mock(wraps=super()._build_replication_data_handler()) - def test_receipt(self): self.reconnect() @@ -50,23 +46,30 @@ def test_receipt(self): self.replicate() # there should be one RDATA command - self.test_handler.on_rdata.assert_called_once() - stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] - self.assertEqual(stream_name, "receipts") - self.assertEqual(1, len(rdata_rows)) - row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0] + received_receipt_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == ReceiptsStream.NAME + ] + self.assertEqual( + len(received_receipt_rows), + 1, + "Expected exactly one row for the receipts stream", + ) + (stream_name, token, row) = received_receipt_rows[0] + self.assertEqual(stream_name, ReceiptsStream.NAME) self.assertEqual("!room:blue", row.room_id) self.assertEqual("m.read", row.receipt_type) self.assertEqual(USER_ID, row.user_id) self.assertEqual("$event:blue", row.event_id) self.assertIsNone(row.thread_id) self.assertEqual({"a": 1}, row.data) + # Clear out the received rows that we've checked so we can check for new ones later + self.test_handler.received_rdata_rows.clear() # Now let's disconnect and insert some data. 
self.disconnect() - self.test_handler.on_rdata.reset_mock() - self.get_success( self.hs.get_datastores().main.insert_receipt( "!room2:blue", @@ -79,20 +82,27 @@ def test_receipt(self): ) self.replicate() - # Nothing should have happened as we are disconnected - self.test_handler.on_rdata.assert_not_called() + # Not yet connected: no rows should yet have been received + self.assertEqual([], self.test_handler.received_rdata_rows) + # Now reconnect and pull the updates self.reconnect() - self.pump(0.1) + self.replicate() # We should now have caught up and get the missing data - self.test_handler.on_rdata.assert_called_once() - stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] - self.assertEqual(stream_name, "receipts") + received_receipt_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == ReceiptsStream.NAME + ] + self.assertEqual( + len(received_receipt_rows), + 1, + "Expected exactly one row for the receipts stream", + ) + (stream_name, token, row) = received_receipt_rows[0] + self.assertEqual(stream_name, ReceiptsStream.NAME) self.assertEqual(token, 3) - self.assertEqual(1, len(rdata_rows)) - - row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0] self.assertEqual("!room2:blue", row.room_id) self.assertEqual("m.read", row.receipt_type) self.assertEqual(USER_ID, row.user_id) diff --git a/tests/replication/tcp/streams/test_thread_subscriptions.py b/tests/replication/tcp/streams/test_thread_subscriptions.py index 04e46b9d93d..5405316048b 100644 --- a/tests/replication/tcp/streams/test_thread_subscriptions.py +++ b/tests/replication/tcp/streams/test_thread_subscriptions.py @@ -88,15 +88,15 @@ def test_thread_subscription_updates(self) -> None: # We should have received all the expected rows in the right order # Filter the updates to only include thread subscription changes - received_rows = [ - upd - for upd in self.test_handler.received_rdata_rows - if upd[0] == ThreadSubscriptionsStream.NAME + received_thread_subscription_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == ThreadSubscriptionsStream.NAME ] # Verify all the thread subscription updates for thread_id in updates: - (stream_name, token, row) = received_rows.pop(0) + (stream_name, token, row) = received_thread_subscription_rows.pop(0) self.assertEqual(stream_name, ThreadSubscriptionsStream.NAME) self.assertIsInstance(row, ThreadSubscriptionsStream.ROW_TYPE) self.assertEqual(row.user_id, "@test_user:example.org") @@ -104,14 +104,14 @@ def test_thread_subscription_updates(self) -> None: self.assertEqual(row.event_id, thread_id) # Verify the last update in the different room - (stream_name, token, row) = received_rows.pop(0) + (stream_name, token, row) = received_thread_subscription_rows.pop(0) self.assertEqual(stream_name, ThreadSubscriptionsStream.NAME) self.assertIsInstance(row, ThreadSubscriptionsStream.ROW_TYPE) self.assertEqual(row.user_id, "@test_user:example.org") self.assertEqual(row.room_id, other_room_id) self.assertEqual(row.event_id, other_thread_root_id) - self.assertEqual([], received_rows) + self.assertEqual([], received_thread_subscription_rows) def test_multiple_users_thread_subscription_updates(self) -> None: """Test replication with thread subscription updates for multiple users""" @@ -138,18 +138,18 @@ def test_multiple_users_thread_subscription_updates(self) -> None: # We should have received all the expected rows # Filter the updates to only include thread subscription changes - received_rows = [ - upd - for upd in 
self.test_handler.received_rdata_rows - if upd[0] == ThreadSubscriptionsStream.NAME + received_thread_subscription_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == ThreadSubscriptionsStream.NAME ] # Should have one update per user - self.assertEqual(len(received_rows), len(users)) + self.assertEqual(len(received_thread_subscription_rows), len(users)) # Verify all updates for i, user_id in enumerate(users): - (stream_name, token, row) = received_rows[i] + (stream_name, token, row) = received_thread_subscription_rows[i] self.assertEqual(stream_name, ThreadSubscriptionsStream.NAME) self.assertIsInstance(row, ThreadSubscriptionsStream.ROW_TYPE) self.assertEqual(row.user_id, user_id) diff --git a/tests/replication/tcp/streams/test_to_device.py b/tests/replication/tcp/streams/test_to_device.py index cb07e93d6b1..d6fd9f91edc 100644 --- a/tests/replication/tcp/streams/test_to_device.py +++ b/tests/replication/tcp/streams/test_to_device.py @@ -21,7 +21,10 @@ import logging import synapse -from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT +from synapse.replication.tcp.streams._base import ( + _STREAM_UPDATE_TARGET_ROW_COUNT, + ToDeviceStream, +) from synapse.types import JsonDict from tests.replication._base import BaseStreamTestCase @@ -82,7 +85,12 @@ def test_to_device_stream(self) -> None: ) # replication is disconnected so we shouldn't get any updates yet - self.assertEqual([], self.test_handler.received_rdata_rows) + received_to_device_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == ToDeviceStream.NAME + ] + self.assertEqual([], received_to_device_rows) # now reconnect to pull the updates self.reconnect() @@ -90,7 +98,15 @@ def test_to_device_stream(self) -> None: # we should receive the fact that we have to_device updates # for user1 and user2 - received_rows = self.test_handler.received_rdata_rows - self.assertEqual(len(received_rows), 2) - self.assertEqual(received_rows[0][2].entity, user1) - self.assertEqual(received_rows[1][2].entity, user2) + received_to_device_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == ToDeviceStream.NAME + ] + self.assertEqual( + len(received_to_device_rows), + 2, + "Expected two rows in the to_device stream", + ) + self.assertEqual(received_to_device_rows[0][2].entity, user1) + self.assertEqual(received_to_device_rows[1][2].entity, user2) diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index e2b22991067..df91416b9be 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -19,7 +19,6 @@ # # import logging -from unittest.mock import Mock from synapse.handlers.typing import RoomMember, TypingWriterHandler from synapse.replication.tcp.streams import TypingStream @@ -27,6 +26,8 @@ from tests.replication._base import BaseStreamTestCase +logger = logging.getLogger(__name__) + USER_ID = "@feeling:blue" USER_ID_2 = "@da-ba-dee:blue" @@ -35,10 +36,6 @@ class TypingStreamTestCase(BaseStreamTestCase): - def _build_replication_data_handler(self) -> Mock: - self.mock_handler = Mock(wraps=super()._build_replication_data_handler()) - return self.mock_handler - def test_typing(self) -> None: typing = self.hs.get_typing_handler() assert isinstance(typing, TypingWriterHandler) @@ -47,51 +44,74 @@ def test_typing(self) -> None: # update to fetch. 
typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True) + # Not yet connected: no rows should yet have been received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # Reconnect self.reconnect() typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True) - - self.reactor.advance(0) + # Pull in the updates + self.replicate() # We should now see an attempt to connect to the master request = self.handle_http_replication_attempt() - self.assert_request_is_get_repl_stream_updates(request, "typing") - - self.mock_handler.on_rdata.assert_called_once() - stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0] - self.assertEqual(stream_name, "typing") - self.assertEqual(1, len(rdata_rows)) - row: TypingStream.TypingStreamRow = rdata_rows[0] - self.assertEqual(ROOM_ID, row.room_id) - self.assertEqual([USER_ID], row.user_ids) + self.assert_request_is_get_repl_stream_updates(request, TypingStream.NAME) + + # Filter the updates to only include typing changes + received_typing_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == TypingStream.NAME + ] + self.assertEqual( + len(received_typing_rows), + 1, + "Expected exactly one row for the typing stream", + ) + (stream_name, token, row) = received_typing_rows[0] + self.assertEqual(stream_name, TypingStream.NAME) + self.assertIsInstance(row, TypingStream.ROW_TYPE) + self.assertEqual(row.room_id, ROOM_ID) + self.assertEqual(row.user_ids, [USER_ID]) + # Clear out the received rows that we've checked so we can check for new ones later + self.test_handler.received_rdata_rows.clear() # Now let's disconnect and insert some data. self.disconnect() - self.mock_handler.on_rdata.reset_mock() - typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False) - self.mock_handler.on_rdata.assert_not_called() + # Not yet connected: no rows should yet have been received + self.assertEqual([], self.test_handler.received_rdata_rows) + # Now reconnect and pull the updates self.reconnect() - self.pump(0.1) + self.replicate() # We should now see an attempt to connect to the master request = self.handle_http_replication_attempt() - self.assert_request_is_get_repl_stream_updates(request, "typing") + self.assert_request_is_get_repl_stream_updates(request, TypingStream.NAME) # The from token should be the token from the last RDATA we got. assert request.args is not None self.assertEqual(int(request.args[b"from_token"][0]), token) - self.mock_handler.on_rdata.assert_called_once() - stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0] - self.assertEqual(stream_name, "typing") - self.assertEqual(1, len(rdata_rows)) - row = rdata_rows[0] - self.assertEqual(ROOM_ID, row.room_id) - self.assertEqual([], row.user_ids) + received_typing_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == TypingStream.NAME + ] + self.assertEqual( + len(received_typing_rows), + 1, + "Expected exactly one row for the typing stream", + ) + (stream_name, token, row) = received_typing_rows[0] + self.assertEqual(stream_name, TypingStream.NAME) + self.assertIsInstance(row, TypingStream.ROW_TYPE) + self.assertEqual(row.room_id, ROOM_ID) + self.assertEqual(row.user_ids, []) def test_reset(self) -> None: """ @@ -116,33 +136,47 @@ def test_reset(self) -> None: # update to fetch. 
typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True) + # Not yet connected: no rows should yet have been received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # Now reconnect to pull the updates self.reconnect() typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True) - - self.reactor.advance(0) + # Pull in the updates + self.replicate() # We should now see an attempt to connect to the master request = self.handle_http_replication_attempt() self.assert_request_is_get_repl_stream_updates(request, "typing") - self.mock_handler.on_rdata.assert_called_once() - stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0] - self.assertEqual(stream_name, "typing") - self.assertEqual(1, len(rdata_rows)) - row: TypingStream.TypingStreamRow = rdata_rows[0] - self.assertEqual(ROOM_ID, row.room_id) - self.assertEqual([USER_ID], row.user_ids) + received_typing_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == TypingStream.NAME + ] + self.assertEqual( + len(received_typing_rows), + 1, + "Expected exactly one row for the typing stream", + ) + (stream_name, token, row) = received_typing_rows[0] + self.assertEqual(stream_name, TypingStream.NAME) + self.assertIsInstance(row, TypingStream.ROW_TYPE) + self.assertEqual(row.room_id, ROOM_ID) + self.assertEqual(row.user_ids, [USER_ID]) # Push the stream forward a bunch so it can be reset. for i in range(100): typing._push_update( member=RoomMember(ROOM_ID, "@test%s:blue" % i), typing=True ) - self.reactor.advance(0) + # Pull in the updates + self.replicate() # Disconnect. self.disconnect() + self.test_handler.received_rdata_rows.clear() # Reset the typing handler self.hs.get_replication_streams()["typing"].last_token = 0 @@ -155,30 +189,34 @@ def test_reset(self) -> None: ) typing._reset() - # Reconnect. + # Now reconnect and pull the updates self.reconnect() - self.pump(0.1) + self.replicate() # We should now see an attempt to connect to the master request = self.handle_http_replication_attempt() self.assert_request_is_get_repl_stream_updates(request, "typing") - # Reset the test code. - self.mock_handler.on_rdata.reset_mock() - self.mock_handler.on_rdata.assert_not_called() - # Push additional data. typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False) - self.reactor.advance(0) - - self.mock_handler.on_rdata.assert_called_once() - stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0] - self.assertEqual(stream_name, "typing") - self.assertEqual(1, len(rdata_rows)) - row = rdata_rows[0] - self.assertEqual(ROOM_ID_2, row.room_id) - self.assertEqual([], row.user_ids) - + # Pull the updates + self.replicate() + + received_typing_rows = [ + row + for row in self.test_handler.received_rdata_rows + if row[0] == TypingStream.NAME + ] + self.assertEqual( + len(received_typing_rows), + 1, + "Expected exactly one row for the typing stream", + ) + (stream_name, token, row) = received_typing_rows[0] + self.assertEqual(stream_name, TypingStream.NAME) + self.assertIsInstance(row, TypingStream.ROW_TYPE) + self.assertEqual(row.room_id, ROOM_ID_2) + self.assertEqual(row.user_ids, []) # The token should have been reset. 
self.assertEqual(token, 1) finally: diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index e684c6c1613..9a3b44219db 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -110,13 +110,13 @@ def test_initialise_reserved_users(self) -> None: self.assertGreater(timestamp, 0) # Test that users with reserved 3pids are not removed from the MAU table - # XXX some of this is redundant. poking things into the config shouldn't - # work, and in any case it's not obvious what we expect to happen when - # we advance the reactor. - self.hs.config.server.max_mau_value = 0 + # + # The looping call registered by `start_phone_stats_home()` will run + # `reap_monthly_active_users` once the reactor time has advanced self.reactor.advance(FORTY_DAYS) - self.hs.config.server.max_mau_value = 5 + # We call this once more explicitly, presumably because the phone-home + # stats loop wasn't running in tests before this change. self.get_success(self.store.reap_monthly_active_users()) active_count = self.get_success(self.store.get_monthly_active_count()) diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py index ab21a5dde4f..1d450f82512 100644 --- a/tests/test_phone_home.py +++ b/tests/test_phone_home.py @@ -75,7 +75,7 @@ class CommonMetricsTestCase(HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.metrics_manager = hs.get_common_usage_metrics_manager() - self.get_success(self.metrics_manager.setup()) + self.metrics_manager.setup() def test_dau(self) -> None: """Tests that the daily active users count is correctly updated."""
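The test hunks above repeat the same list comprehension to narrow `received_rdata_rows` down to a single stream before asserting on it. A minimal sketch of a shared helper that could capture the pattern — the function name and its placement are hypothetical, not part of this diff:

```python
from typing import Any, List, Tuple

# (stream_name, token, row) triples, as accumulated on the test handler's
# `received_rdata_rows` list in tests/replication/_base.py.
ReceivedRow = Tuple[str, int, Any]


def rows_for_stream(
    received_rdata_rows: List[ReceivedRow], stream_name: str
) -> List[ReceivedRow]:
    """Return only the rows belonging to the given replication stream.

    Replication tests receive rows from several streams interleaved (events,
    caches, federation, ...), so assertions should first be narrowed to the
    stream under test.
    """
    return [row for row in received_rdata_rows if row[0] == stream_name]
```

Each filter in the tests above would then read as a one-liner, e.g. `received_event_rows = rows_for_stream(self.test_handler.received_rdata_rows, EventsStream.NAME)`.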